Commit af69f873 by Ashutosh Parkhi, committed by Tianqi Chen

[Tensorflow, NNVM, TOPI] Support for logical operators (#2453)

parent 1ca0393e
...@@ -68,6 +68,9 @@ List of operators
topi.not_equal
topi.greater_equal
topi.less_equal
topi.logical_and
topi.logical_or
topi.logical_not
topi.arange
topi.layout_transform
topi.image.resize
......
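For reference, a minimal sketch of exercising the new TOPI ops from the Python API (the shapes and the llvm target are illustrative, not part of this change):

import tvm
import topi

# Boolean placeholders; "bool" maps to tvm::UInt(1), the new type code 11 below.
A = tvm.placeholder((4, 4), dtype="bool", name="A")
B = tvm.placeholder((4, 4), dtype="bool", name="B")

C = topi.logical_and(A, B)  # elementwise A && B
D = topi.logical_or(A, B)   # elementwise A || B
E = topi.logical_not(A)     # elementwise !A

s = tvm.create_schedule([C.op, D.op, E.op])
f = tvm.build(s, [A, B, C, D, E], "llvm")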
...@@ -35,6 +35,9 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.exp
nnvm.symbol.log
nnvm.symbol.sqrt
nnvm.symbol.logical_and
nnvm.symbol.logical_or
nnvm.symbol.logical_not
nnvm.symbol.elemwise_add
nnvm.symbol.elemwise_sub
nnvm.symbol.elemwise_mul
...@@ -172,6 +175,9 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.exp
.. autofunction:: nnvm.symbol.log
.. autofunction:: nnvm.symbol.sqrt
.. autofunction:: nnvm.symbol.logical_and
.. autofunction:: nnvm.symbol.logical_or
.. autofunction:: nnvm.symbol.logical_not
.. autofunction:: nnvm.symbol.elemwise_add
.. autofunction:: nnvm.symbol.elemwise_sub
.. autofunction:: nnvm.symbol.elemwise_mul
......
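A quick illustration of the symbol-level API documented above (variable names are arbitrary):

import nnvm.symbol as sym

x = sym.Variable("x")
y = sym.Variable("y")

a = sym.logical_and(x, y)  # support level 1, per the registrations below
o = sym.logical_or(x, y)   # support level 1
n = sym.logical_not(x)     # support level 3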
...@@ -39,6 +39,7 @@ DTYPE_TO_TCODE = {
    "uint16": 8,
    "uint32": 9,
    "uint64": 10,
    "bool": 11,
}
TCODE_TO_DTYPE = {
...@@ -54,6 +55,7 @@ TCODE_TO_DTYPE = {
    8: "uint16",
    9: "uint32",
    10: "uint64",
    11: "bool",
}
def set_dtype_inputs(g, dtype):
......
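These tables let the compiler round-trip the new boolean dtype through graph attributes; a small sanity sketch (module path as in the NNVM Python package):

from nnvm.compiler.graph_attr import DTYPE_TO_TCODE, TCODE_TO_DTYPE

assert DTYPE_TO_TCODE["bool"] == 11
assert TCODE_TO_DTYPE[11] == "bool"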
...@@ -867,6 +867,11 @@ def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
    return _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis)

def _logical(name):
    def _impl(inputs, attr, params):
        return AttrCvt(op_name=name)(inputs, attr)
    return _impl
# compatible operators that do NOT require any conversion.
_identity_list = []
...@@ -929,6 +934,9 @@ _convert_map = {
'Transpose' : _transpose(),
'Tanh' : AttrCvt('tanh'),
'Mean' : _mean(),
'LogicalAnd' : _logical('logical_and'),
'LogicalOr' : _logical('logical_or'),
'LogicalNot' : _logical('logical_not'),
'Less' : _broadcast('less'),
'Greater' : _broadcast('greater'),
'LessEqual' : _broadcast('less_equal'),
......
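The _logical helper only renames a TensorFlow node to its NNVM counterpart via AttrCvt, e.g. 'LogicalAnd' -> 'logical_and'. A sketch of importing a graph that uses these ops (placeholder shapes are illustrative):

import tensorflow as tf
import nnvm

with tf.Graph().as_default() as g:
    a = tf.placeholder(tf.bool, shape=(2, 2), name='a')
    b = tf.placeholder(tf.bool, shape=(2, 2), name='b')
    c = tf.logical_and(a, b, name='c')  # converted via _logical('logical_and')
    graph_def = g.as_graph_def()

sym, params = nnvm.frontend.from_tensorflow(graph_def)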
...@@ -140,6 +140,18 @@ reg.register_schedule("__lshift_scalar__", _fschedule_broadcast)
reg.register_pattern("__rshift_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rshift_scalar__", _fschedule_broadcast)
# logical_and
reg.register_pattern("logical_and", OpPattern.ELEMWISE)
reg.register_schedule("logical_and", _fschedule_broadcast)
# logical_or
reg.register_pattern("logical_or", OpPattern.ELEMWISE)
reg.register_schedule("logical_or", _fschedule_broadcast)
# logical_not
reg.register_pattern("logical_not", OpPattern.ELEMWISE)
reg.register_schedule("logical_not", _fschedule_broadcast)
# elemwise_add
reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
reg.register_schedule("elemwise_add", _fschedule_broadcast)
......
...@@ -40,6 +40,7 @@ int GetTypeFlag(tvm::Type type) {
if (type == tvm::UInt(16)) return 8;
if (type == tvm::UInt(32)) return 9;
if (type == tvm::UInt(64)) return 10;
if (type == tvm::UInt(1)) return 11;
LOG(FATAL) << "cannot convert " << type;
return 0;
}
...@@ -68,6 +69,8 @@ Type GetTVMType(int type_flag) {
return tvm::UInt(32);
case 10:
return tvm::UInt(64);
case 11:
return tvm::UInt(1);
default:
LOG(FATAL) << "unknown type_flag=" << type_flag;
return Float(32);
......
...@@ -361,6 +361,31 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_pow)
return Array<Tensor>{ topi::power(inputs[0], inputs[1]) };
});
// logical
NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_and)
.describe(R"code(Elementwise compute the logical AND
)code")
.set_support_level(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::logical_and(inputs[0], inputs[1]) };
});

NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_or)
.describe(R"code(Elementwise compute the logical OR
)code")
.set_support_level(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::logical_or(inputs[0], inputs[1]) };
});
// negative
NNVM_REGISTER_ELEMWISE_UNARY_OP(negative)
.describe(R"code(Elementwise numeric negative
...@@ -383,6 +408,19 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(negative)
};
});
// logical NOT
NNVM_REGISTER_ELEMWISE_UNARY_OP(logical_not)
.describe(R"code(Elementwise compute the logical NOT
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::logical_not(inputs[0]) };
});
// copy
NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
.describe(R"code(Copy tensor to another one.
......
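With the compute functions registered, the new ops compile end to end through the NNVM graph compiler; a sketch, assuming boolean inputs (shapes illustrative):

import nnvm.compiler
import nnvm.symbol as sym

x = sym.Variable("x")
y = sym.Variable("y")
z = sym.logical_and(x, y)

graph, lib, _ = nnvm.compiler.build(
    z, "llvm",
    shape={"x": (2, 2), "y": (2, 2)},
    dtype={"x": "bool", "y": "bool"})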
...@@ -777,6 +777,48 @@ def test_forward_pad():
    _test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT")
    _test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT", constant_values=1.0)
#######################################################################
# Logical operators
# --------------------
def test_logical_and():
    with tf.Graph().as_default():
        in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
        in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
        out = tf.logical_and(in1, in2, name='out')
        in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
        in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
        compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')

def test_logical_or():
    with tf.Graph().as_default():
        in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
        in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
        out = tf.logical_or(in1, in2, name='out')
        in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
        in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
        compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')

def test_logical_xor():
    with tf.Graph().as_default():
        in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
        in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
        out = tf.logical_xor(in1, in2, name='out')
        in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
        in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
        compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')

def test_logical_not():
    with tf.Graph().as_default():
        in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
        out = tf.logical_not(in1, name='out')
        in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
        compare_tf_with_tvm(in_data1, 'in1:0', 'out:0')

def test_forward_logical():
    test_logical_and()
    test_logical_or()
    test_logical_xor()
    test_logical_not()
#######################################################################
# Inception V3
...@@ -1205,3 +1247,4 @@ if __name__ == '__main__':
    # Relational ops
    test_forward_rel_ops()
    test_forward_logical()
...@@ -93,6 +93,33 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
  return topi::OpName(A, B); \
}
/*!
* \fn logical_and
* \brief Compute A && B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(logical_and, { return a && b; });
TOPI_DEFINE_OP_OVERLOAD(operator&&, logical_and);
/*!
* \fn logical_or
* \brief Compute A || B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or);
/*!
 * \fn add
......
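Because these are defined through TOPI_DEFINE_BCAST_OP, the usual auto-broadcasting rules apply; for instance (a sketch via the Python bindings, shapes illustrative):

import tvm
import topi

A = tvm.placeholder((2, 3), dtype="bool", name="A")
B = tvm.placeholder((3,), dtype="bool", name="B")

# B is broadcast across the first axis of A, just as with the arithmetic ops.
C = topi.logical_and(A, B)  # result shape (2, 3)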
...@@ -72,6 +72,23 @@ inline Tensor negative(const Tensor& x,
}

/*!
* \brief Creates an operation that returns the logical NOT of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the logical NOT operation
*/
inline Tensor logical_not(const Tensor& x,
                          std::string name = "tensor",
                          std::string tag = kElementWise) {
  return compute(x->shape, [&](const Array<Var>& i) {
    return !x(i);
  }, name, tag);
}
/*!
 * \brief Creates an operation that clips each element of a tensor to
 * the interval [a_min, a_max]
 *
......
...@@ -112,6 +112,8 @@ TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum);
TOPI_REGISTER_BCAST_OP("topi.minimum", topi::minimum);
TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift); TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater); TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
TOPI_REGISTER_BCAST_OP("topi.less", topi::less); TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
......
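TOPI_REGISTER_BCAST_OP exposes each op as a global packed function addressable by the registered name; a sketch of looking them up from Python:

import tvm

logical_and = tvm.get_global_func("topi.logical_and")
logical_or = tvm.get_global_func("topi.logical_or")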