Commit eacc2bc4 by Stilistik Committed by Tianqi Chen

Implement tensorflow relational operators and related tests (#1714)

parent 47e57be4
......@@ -759,6 +759,15 @@ def _mean():
extras={'axis': tuple(axis.asnumpy())})(inputs[0], attr)
return _impl
def _broadcast(name):
def _impl(inputs, attr, params):
op_name = _math_name_picker(name)(attr)
return AttrCvt(
op_name=op_name,
ignores=['name', 'Tidx']
)(inputs, attr)
return _impl
# compatible operators that do NOT require any conversion.
_identity_list = []
......@@ -819,6 +828,12 @@ _convert_map = {
'Transpose' : _transpose(),
'Tanh' : AttrCvt('tanh'),
'Mean' : _mean(),
'Less' : _broadcast('less'),
'Greater' : _broadcast('greater'),
'LessEqual' : _broadcast('less_equal'),
'GreaterEqual' : _broadcast('greater_equal'),
'Equal' : _broadcast('equal'),
'NotEqual' : _broadcast('not_equal'),
}
# _convert_map_rnn defines maps of rnn operator name to
......
......@@ -378,7 +378,7 @@ def test_forward_reduce():
data = np.random.uniform(size=(8,4,9)).astype('float32')
_test_reduce(tf.reduce_sum, data=data)
_test_reduce(tf.reduce_sum, data=data, axis=0)
_test_reduce(tf.reduce_sum, data=data, axis=(0,1))
_test_reduce(tf.reduce_sum, data=data, axis=(0,1))
#######################################################################
......@@ -979,6 +979,28 @@ def test_forward_mean():
check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True)
#######################################################################
# Relational operators
# --------------------
def _test_forward_rel_op(data, func):
with tf.Graph().as_default():
in1 = tf.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in1')
in2 = tf.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in2')
op = func(in1, in2, name='op')
out = tf.cast(op, tf.int32, name='out1')
compare_tf_with_tvm([data[0], data[1]], ['in1:0', 'in2:0'], 'out1:0')
def test_forward_rel_ops():
t1 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
t2 = np.array([[9, 8, 7], [6, 5, 4], [3, 2, 1]])
_test_forward_rel_op([t1, t2], math_ops.less)
_test_forward_rel_op([t1, t2], math_ops.greater)
_test_forward_rel_op([t1, t2], math_ops.less_equal)
_test_forward_rel_op([t1, t2], math_ops.greater_equal)
_test_forward_rel_op([t1, t2], math_ops.equal)
_test_forward_rel_op([t1, t2], math_ops.not_equal)
#######################################################################
# Main
# ----
if __name__ == '__main__':
......@@ -1030,3 +1052,6 @@ if __name__ == '__main__':
# Elementwise
test_forward_ceil()
test_forward_floor()
# Relational ops
test_forward_rel_ops()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment