Commit c286c138 by Thierry Moreau Committed by Tianqi Chen

[RELAY][OP] Left shift operator (#1839)

parent eec3648a
......@@ -47,6 +47,7 @@ This level enables typical convnet models.
:nosignatures:
tvm.relay.right_shift
tvm.relay.left_shift
tvm.relay.equal
tvm.relay.not_equal
tvm.relay.greater
......@@ -74,6 +75,7 @@ Level 2 Definitions
Level 4 Definitions
-------------------
.. autofunction:: tvm.relay.right_shift
.. autofunction:: tvm.relay.left_shift
.. autofunction:: tvm.relay.equal
.. autofunction:: tvm.relay.not_equal
.. autofunction:: tvm.relay.greater
......
......@@ -120,7 +120,6 @@ def subtract(lhs, rhs):
return _make.subtract(lhs, rhs)
def equal(lhs, rhs):
"""Broadcasted elementwise test for (lhs == rhs).
......@@ -247,6 +246,24 @@ def right_shift(lhs, rhs):
return _make.right_shift(lhs, rhs)
def left_shift(lhs, rhs):
"""Left shift with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.left_shift(lhs, rhs)
def concat(*args):
"""Concatenate the input tensors along the zero axis.
......
......@@ -27,16 +27,23 @@ RELAY_REGISTER_BINARY_OP("add")
.describe("Elementwise add with with broadcasting")
.set_support_level(1);
// Subtraction
RELAY_REGISTER_BINARY_OP("subtract")
.describe("Elementwise substract with broadcasting")
.set_support_level(1);
// Right shift
RELAY_REGISTER_BINARY_OP("right_shift")
.describe("Elementwise right shift with broadcasting")
.set_support_level(4);
// Left shift
RELAY_REGISTER_BINARY_OP("left_shift")
.describe("Elementwise left shift with broadcasting")
.set_support_level(4);
// Comparisons
#define RELAY_REGISTER_CMP_OP(OpName, SupportLevel) \
#define RELAY_REGISTER_CMP_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
......@@ -46,15 +53,26 @@ RELAY_REGISTER_BINARY_OP("right_shift")
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.set_support_level(SupportLevel) \
.add_type_rel("BroadcastComp", BroadcastCompRel);
RELAY_REGISTER_CMP_OP("equal", 4);
RELAY_REGISTER_CMP_OP("not_equal", 4);
RELAY_REGISTER_CMP_OP("less", 4);
RELAY_REGISTER_CMP_OP("less_equal", 4);
RELAY_REGISTER_CMP_OP("greater", 4);
RELAY_REGISTER_CMP_OP("greater_equal", 4);
.add_type_rel("BroadcastComp", BroadcastCompRel)
RELAY_REGISTER_CMP_OP("equal")
.describe("Elementwise equal compare with broadcasting")
.set_support_level(4);
RELAY_REGISTER_CMP_OP("not_equal")
.describe("Elementwise not equal with broadcasting")
.set_support_level(4);
RELAY_REGISTER_CMP_OP("less")
.describe("Elementwise less than with broadcasting")
.set_support_level(4);
RELAY_REGISTER_CMP_OP("less_equal")
.describe("Elementwise less than or equal compare with broadcasting")
.set_support_level(4);
RELAY_REGISTER_CMP_OP("greater")
.describe("Elementwise greater than compare with broadcasting")
.set_support_level(4);
RELAY_REGISTER_CMP_OP("greater_equal")
.describe("Elementwise greater than or equal compare with broadcasting")
.set_support_level(4);
} // namespace relay
} // namespace tvm
......@@ -21,7 +21,8 @@ def test_cmp_type():
def test_binary_broadcast():
for op in [relay.right_shift]:
for op in (relay.right_shift,
relay.left_shift):
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment