Commit c286c138 by Thierry Moreau Committed by Tianqi Chen

[RELAY][OP] Left shift operator (#1839)

parent eec3648a
...@@ -47,6 +47,7 @@ This level enables typical convnet models. ...@@ -47,6 +47,7 @@ This level enables typical convnet models.
:nosignatures: :nosignatures:
tvm.relay.right_shift tvm.relay.right_shift
tvm.relay.left_shift
tvm.relay.equal tvm.relay.equal
tvm.relay.not_equal tvm.relay.not_equal
tvm.relay.greater tvm.relay.greater
...@@ -74,6 +75,7 @@ Level 2 Definitions ...@@ -74,6 +75,7 @@ Level 2 Definitions
Level 4 Definitions Level 4 Definitions
------------------- -------------------
.. autofunction:: tvm.relay.right_shift .. autofunction:: tvm.relay.right_shift
.. autofunction:: tvm.relay.left_shift
.. autofunction:: tvm.relay.equal .. autofunction:: tvm.relay.equal
.. autofunction:: tvm.relay.not_equal .. autofunction:: tvm.relay.not_equal
.. autofunction:: tvm.relay.greater .. autofunction:: tvm.relay.greater
......
...@@ -120,7 +120,6 @@ def subtract(lhs, rhs): ...@@ -120,7 +120,6 @@ def subtract(lhs, rhs):
return _make.subtract(lhs, rhs) return _make.subtract(lhs, rhs)
def equal(lhs, rhs): def equal(lhs, rhs):
"""Broadcasted elementwise test for (lhs == rhs). """Broadcasted elementwise test for (lhs == rhs).
...@@ -247,6 +246,24 @@ def right_shift(lhs, rhs): ...@@ -247,6 +246,24 @@ def right_shift(lhs, rhs):
return _make.right_shift(lhs, rhs) return _make.right_shift(lhs, rhs)
def left_shift(lhs, rhs):
"""Left shift with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.left_shift(lhs, rhs)
def concat(*args): def concat(*args):
"""Concatenate the input tensors along the zero axis. """Concatenate the input tensors along the zero axis.
......
...@@ -27,16 +27,23 @@ RELAY_REGISTER_BINARY_OP("add") ...@@ -27,16 +27,23 @@ RELAY_REGISTER_BINARY_OP("add")
.describe("Elementwise add with with broadcasting") .describe("Elementwise add with with broadcasting")
.set_support_level(1); .set_support_level(1);
// Subtraction
RELAY_REGISTER_BINARY_OP("subtract") RELAY_REGISTER_BINARY_OP("subtract")
.describe("Elementwise substract with broadcasting") .describe("Elementwise substract with broadcasting")
.set_support_level(1); .set_support_level(1);
// Right shift
RELAY_REGISTER_BINARY_OP("right_shift") RELAY_REGISTER_BINARY_OP("right_shift")
.describe("Elementwise right shift with broadcasting") .describe("Elementwise right shift with broadcasting")
.set_support_level(4); .set_support_level(4);
// Left shift
RELAY_REGISTER_BINARY_OP("left_shift")
.describe("Elementwise left shift with broadcasting")
.set_support_level(4);
// Comparisons // Comparisons
#define RELAY_REGISTER_CMP_OP(OpName, SupportLevel) \ #define RELAY_REGISTER_CMP_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \ TVM_REGISTER_API("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \ .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \ static const Op& op = Op::Get(OpName); \
...@@ -46,15 +53,26 @@ RELAY_REGISTER_BINARY_OP("right_shift") ...@@ -46,15 +53,26 @@ RELAY_REGISTER_BINARY_OP("right_shift")
.set_num_inputs(2) \ .set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \ .add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \ .add_argument("rhs", "Tensor", "The right hand side tensor.") \
.set_support_level(SupportLevel) \ .add_type_rel("BroadcastComp", BroadcastCompRel)
.add_type_rel("BroadcastComp", BroadcastCompRel);
RELAY_REGISTER_CMP_OP("equal")
RELAY_REGISTER_CMP_OP("equal", 4); .describe("Elementwise equal compare with broadcasting")
RELAY_REGISTER_CMP_OP("not_equal", 4); .set_support_level(4);
RELAY_REGISTER_CMP_OP("less", 4); RELAY_REGISTER_CMP_OP("not_equal")
RELAY_REGISTER_CMP_OP("less_equal", 4); .describe("Elementwise not equal with broadcasting")
RELAY_REGISTER_CMP_OP("greater", 4); .set_support_level(4);
RELAY_REGISTER_CMP_OP("greater_equal", 4); RELAY_REGISTER_CMP_OP("less")
.describe("Elementwise less than with broadcasting")
.set_support_level(4);
RELAY_REGISTER_CMP_OP("less_equal")
.describe("Elementwise less than or equal compare with broadcasting")
.set_support_level(4);
RELAY_REGISTER_CMP_OP("greater")
.describe("Elementwise greater than compare with broadcasting")
.set_support_level(4);
RELAY_REGISTER_CMP_OP("greater_equal")
.describe("Elementwise greater than or equal compare with broadcasting")
.set_support_level(4);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -21,7 +21,8 @@ def test_cmp_type(): ...@@ -21,7 +21,8 @@ def test_cmp_type():
def test_binary_broadcast(): def test_binary_broadcast():
for op in [relay.right_shift]: for op in (relay.right_shift,
relay.left_shift):
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "int32")) x = ib.param("x", relay.TensorType((10, 4), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32")) y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment