Commit 168d4d1d by Pratyush Patel Committed by Tianqi Chen

[RELAY][OP] Add relay minimum op (#1840)

parent 1bfda4d3
...@@ -61,6 +61,7 @@ This level enables typical convnet models. ...@@ -61,6 +61,7 @@ This level enables typical convnet models.
tvm.relay.less tvm.relay.less
tvm.relay.less_equal tvm.relay.less_equal
tvm.relay.maximum tvm.relay.maximum
tvm.relay.minimum
**Level 5: Vision/Image Operators** **Level 5: Vision/Image Operators**
...@@ -89,4 +90,5 @@ Level 4 Definitions ...@@ -89,4 +90,5 @@ Level 4 Definitions
.. autofunction:: tvm.relay.greater_equal .. autofunction:: tvm.relay.greater_equal
.. autofunction:: tvm.relay.less .. autofunction:: tvm.relay.less
.. autofunction:: tvm.relay.less_equal .. autofunction:: tvm.relay.less_equal
.. autofunction:: tvm.relay.maximum .. autofunction:: tvm.relay.maximum
\ No newline at end of file .. autofunction:: tvm.relay.minimum
...@@ -246,6 +246,24 @@ def maximum(lhs, rhs): ...@@ -246,6 +246,24 @@ def maximum(lhs, rhs):
return _make.maximum(lhs, rhs) return _make.maximum(lhs, rhs)
def minimum(lhs, rhs):
"""Minimum with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.minimum(lhs, rhs)
def right_shift(lhs, rhs): def right_shift(lhs, rhs):
"""Right shift with numpy-style broadcasting. """Right shift with numpy-style broadcasting.
......
...@@ -45,6 +45,10 @@ RELAY_REGISTER_BINARY_OP("maximum") ...@@ -45,6 +45,10 @@ RELAY_REGISTER_BINARY_OP("maximum")
.describe("Elementwise maximum of two tensors with broadcasting") .describe("Elementwise maximum of two tensors with broadcasting")
.set_support_level(4); .set_support_level(4);
RELAY_REGISTER_BINARY_OP("minimum")
.describe("Elementwise minimum of two tensors with broadcasting")
.set_support_level(4);
// Comparisons // Comparisons
#define RELAY_REGISTER_CMP_OP(OpName) \ #define RELAY_REGISTER_CMP_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \ TVM_REGISTER_API("relay.op._make." OpName) \
......
...@@ -23,7 +23,8 @@ def test_cmp_type(): ...@@ -23,7 +23,8 @@ def test_cmp_type():
def test_binary_broadcast(): def test_binary_broadcast():
for op in [relay.right_shift, for op in [relay.right_shift,
relay.left_shift, relay.left_shift,
relay.maximum]: relay.maximum,
relay.minimum]:
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "int32")) x = ib.param("x", relay.TensorType((10, 4), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32")) y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment