Commit d6ff734b by Junru Shao Committed by Tianqi Chen

[RELAY][OP] comparison (#1824)

parent 160eaf54
......@@ -42,6 +42,15 @@ This level enables typical convnet models.
**Level 4: Broadcast and Reductions**
.. autosummary::
:nosignatures:
tvm.relay.equal
tvm.relay.not_equal
tvm.relay.greater
tvm.relay.greater_equal
tvm.relay.less
tvm.relay.less_equal
**Level 5: Vision/Image Operators**
......
......@@ -104,10 +104,115 @@ def subtract(lhs, rhs):
return _make.subtract(lhs, rhs)
def equal(lhs, rhs):
"""Broadcasted elementwise test for (lhs == rhs).
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.equal(lhs, rhs)
def not_equal(lhs, rhs):
"""Broadcasted elementwise test for (lhs != rhs).
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.not_equal(lhs, rhs)
def less(lhs, rhs):
"""Broadcasted elementwise test for (lhs < rhs).
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.less(lhs, rhs)
def less_equal(lhs, rhs):
"""Broadcasted elementwise test for (lhs <= rhs).
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.less_equal(lhs, rhs)
def greater(lhs, rhs):
"""Broadcasted elementwise test for (lhs > rhs).
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.greater(lhs, rhs)
def greater_equal(lhs, rhs):
"""Broadcasted elementwise test for (lhs >= rhs).
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.greater_equal(lhs, rhs)
def concat(*args):
"""Concatenate the input tensors along the zero axis.
......
......@@ -106,20 +106,27 @@ RELAY_REGISTER_OP("subtract")
// input2: Tensor[dtype, s2]
// output: Tensor[dtype, broadcast(s1, s2)]
// Equality Comparison
TVM_REGISTER_API("relay.op._make.equal")
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {
static const Op& op = Op::Get("equal");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
});
RELAY_REGISTER_OP("equal")
.set_num_inputs(2)
.add_argument("lhs", "Tensor", "The left hand side tensor.")
.add_argument("rhs", "Tensor", "The right hand side tensor.")
.set_support_level(1)
// Comparisons
#define RELAY_REGISTER_CMP_OP(OpName, SupportLevel) \
TVM_REGISTER_API("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.set_support_level(SupportLevel) \
.add_type_rel("BroadcastComp", BroadcastCompRel);
RELAY_REGISTER_CMP_OP("equal", 4);
RELAY_REGISTER_CMP_OP("not_equal", 4);
RELAY_REGISTER_CMP_OP("less", 4);
RELAY_REGISTER_CMP_OP("less_equal", 4);
RELAY_REGISTER_CMP_OP("greater", 4);
RELAY_REGISTER_CMP_OP("greater_equal", 4);
// Concat
TVM_REGISTER_API("relay.op._make.concat")
.set_body_typed<Expr(Expr)>([](Expr tuple) {
......
import tvm
from tvm import relay
def test_cmp_type():
for op in (relay.greater,
relay.greater_equal,
relay.less,
relay.less_equal,
relay.equal,
relay.not_equal):
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "float32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
with ib.function(x, y) as func:
ib.ret(op(x.var, y.var))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1")
if __name__ == "__main__":
test_cmp_type()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment