Commit 2239508b by abergeron Committed by Tianqi Chen

[Relay] Add logical operators (#2743)

parent 695647db
...@@ -366,7 +366,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_and) ...@@ -366,7 +366,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_and)
.describe(R"code(Elementwise compute the logical AND .describe(R"code(Elementwise compute the logical AND
)code") )code")
.set_support_level(1) .set_support_level(4)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -378,7 +378,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_or) ...@@ -378,7 +378,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_or)
.describe(R"code(Elementwise compute the logical OR .describe(R"code(Elementwise compute the logical OR
)code") )code")
.set_support_level(1) .set_support_level(4)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -413,7 +413,7 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(logical_not) ...@@ -413,7 +413,7 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(logical_not)
.describe(R"code(Elementwise compute the logical NOT .describe(R"code(Elementwise compute the logical NOT
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_support_level(3) .set_support_level(4)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
......
...@@ -849,6 +849,11 @@ def _softmax(): ...@@ -849,6 +849,11 @@ def _softmax():
transforms={'axis': ('axis', 1)})([inputs[0]], attr) transforms={'axis': ('axis', 1)})([inputs[0]], attr)
return _impl return _impl
def _logical(name):
def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr)
return _impl
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -909,6 +914,9 @@ _convert_map = { ...@@ -909,6 +914,9 @@ _convert_map = {
'Transpose' : _transpose(), 'Transpose' : _transpose(),
'Tanh' : AttrCvt('tanh'), 'Tanh' : AttrCvt('tanh'),
'Mean' : _mean(), 'Mean' : _mean(),
'LogicalAnd' : _logical('logical_and'),
'LogicalOr' : _logical('logical_or'),
'LogicalNot' : _logical('logical_not'),
'Less' : _broadcast('less'), 'Less' : _broadcast('less'),
'Greater' : _broadcast('greater'), 'Greater' : _broadcast('greater'),
'LessEqual' : _broadcast('less_equal'), 'LessEqual' : _broadcast('less_equal'),
......
...@@ -18,6 +18,7 @@ register_schedule("trunc", schedule_broadcast) ...@@ -18,6 +18,7 @@ register_schedule("trunc", schedule_broadcast)
register_schedule("round", schedule_broadcast) register_schedule("round", schedule_broadcast)
register_schedule("abs", schedule_broadcast) register_schedule("abs", schedule_broadcast)
register_schedule("tanh", schedule_broadcast) register_schedule("tanh", schedule_broadcast)
register_schedule("logical_not", schedule_broadcast)
register_schedule("negative", schedule_broadcast) register_schedule("negative", schedule_broadcast)
register_schedule("copy", schedule_broadcast) register_schedule("copy", schedule_broadcast)
...@@ -27,6 +28,8 @@ register_schedule("multiply", schedule_broadcast) ...@@ -27,6 +28,8 @@ register_schedule("multiply", schedule_broadcast)
register_schedule("divide", schedule_broadcast) register_schedule("divide", schedule_broadcast)
register_schedule("power", schedule_injective) register_schedule("power", schedule_injective)
register_schedule("mod", schedule_broadcast) register_schedule("mod", schedule_broadcast)
register_schedule("logical_and", schedule_broadcast)
register_schedule("logical_or", schedule_broadcast)
register_schedule("equal", schedule_broadcast) register_schedule("equal", schedule_broadcast)
register_schedule("not_equal", schedule_broadcast) register_schedule("not_equal", schedule_broadcast)
register_schedule("less", schedule_broadcast) register_schedule("less", schedule_broadcast)
......
...@@ -191,6 +191,22 @@ def negative(data): ...@@ -191,6 +191,22 @@ def negative(data):
return _make.negative(data) return _make.negative(data)
def logical_not(data):
"""Compute element-wise logical not of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.logical_not(data)
def add(lhs, rhs): def add(lhs, rhs):
"""Addition with numpy-style broadcasting. """Addition with numpy-style broadcasting.
...@@ -307,6 +323,42 @@ def mod(lhs, rhs): ...@@ -307,6 +323,42 @@ def mod(lhs, rhs):
return _make.mod(lhs, rhs) return _make.mod(lhs, rhs)
def logical_and(lhs, rhs):
"""logical AND with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.logical_and(lhs, rhs)
def logical_or(lhs, rhs):
"""logical OR with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.logical_or(lhs, rhs)
def equal(lhs, rhs): def equal(lhs, rhs):
"""Broadcasted elementwise test for (lhs == rhs). """Broadcasted elementwise test for (lhs == rhs).
......
...@@ -82,6 +82,18 @@ RELAY_REGISTER_BINARY_OP("mod") ...@@ -82,6 +82,18 @@ RELAY_REGISTER_BINARY_OP("mod")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod));
RELAY_REGISTER_BINARY_OP("logical_and")
.describe("Elementwise logical AND with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_and));
RELAY_REGISTER_BINARY_OP("logical_or")
.describe("Elementwise logical OR with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));
RELAY_REGISTER_CMP_OP("equal") RELAY_REGISTER_CMP_OP("equal")
.describe("Elementwise equal compare with broadcasting") .describe("Elementwise equal compare with broadcasting")
.set_support_level(4) .set_support_level(4)
......
...@@ -178,5 +178,16 @@ RELAY_REGISTER_UNARY_OP("negative") ...@@ -178,5 +178,16 @@ RELAY_REGISTER_UNARY_OP("negative")
.set_support_level(3) .set_support_level(3)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative));
RELAY_REGISTER_UNARY_OP("logical_not")
.describe(R"code(Returns the logical inverse of input array, computed element-wise.
.. math::
~(x)
)code" TVM_ADD_FILELINE)
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not));
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -682,6 +682,49 @@ def test_forward_pad(): ...@@ -682,6 +682,49 @@ def test_forward_pad():
_test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT") _test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT")
_test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT", constant_values=1.0) _test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT", constant_values=1.0)
#######################################################################
# Logical operators
# --------------------
def test_logical_and():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
out = tf.logical_and(in1, in2, name='out')
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
def test_logical_or():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
out = tf.logical_or(in1, in2, name='out')
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
def test_logical_xor():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
out = tf.logical_xor(in1, in2, name='out')
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
def test_logical_not():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
out = tf.logical_not(in1, name='out')
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm(in_data1, 'in1:0', 'out:0')
def test_forward_logical():
test_logical_and()
test_logical_or()
test_logical_xor()
test_logical_not()
####################################################################### #######################################################################
# Inception V3 # Inception V3
...@@ -1109,5 +1152,4 @@ if __name__ == '__main__': ...@@ -1109,5 +1152,4 @@ if __name__ == '__main__':
# Relational ops # Relational ops
test_forward_rel_ops() test_forward_rel_ops()
test_forward_logical()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment