Commit 9bb16872 by Yong Wu, committed by Tianqi Chen

[Relay][Frontend] Add a bunch of ops in tf converter (#3270)

parent c9e96d9f
@@ -777,12 +777,12 @@ def _sum():
                       ignores=['name', 'Tidx'])([inputs[0]], attr)
     return _impl

-def _reduce_all():
+def _reduce(op):
     def _impl(inputs, attr, params):
         axis = params.pop(inputs[1].name_hint).asnumpy()
         axis = tuple(axis)
         return AttrCvt(
-            op_name='all',
+            op_name=op,
             extras={'axis': axis},
             transforms={'keep_dims':'keepdims'},
             ignores=['name', 'Tidx'])([inputs[0]], attr)
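Generalizing _reduce_all() into _reduce(op) lets one converter factory serve All, Max, and Min below. A minimal standalone sketch of the same closure pattern, using NumPy stand-ins instead of TVM's AttrCvt machinery (all names here are illustrative):

import numpy as np

def make_reduce_converter(op_name):
    """Return a converter closure for one named reduce op."""
    np_op = {'all': np.all, 'max': np.max, 'min': np.min}[op_name]
    def convert(data, axis=None, keep_dims=False):
        # Mirrors the keep_dims -> keepdims attribute rename done by AttrCvt.
        return np_op(data, axis=axis, keepdims=keep_dims)
    return convert

reduce_all = make_reduce_converter('all')
print(reduce_all(np.array([[True, False], [True, True]]), axis=1))  # [False  True]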
@@ -807,6 +807,14 @@ def _gather():
                                  'Taxis', '_class'])(new_input, attr)
     return _impl

+def _gather_nd():
+    """GatherNd"""
+    def _impl(inputs, attr, params):
+        return AttrCvt(op_name="gather_nd",
+                       ignores=['Tindices', 'Tparams',
+                                'Taxis', '_class'])(inputs, attr)
+    return _impl
+
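For reference, gather_nd picks one element (or slice) per row of the index tensor. A NumPy sketch of the semantics the new converter maps onto, matching the test_forward_gather_nd case added below (gather_nd here is a hypothetical helper, not TVM API):

import numpy as np

def gather_nd(data, indices):
    # Each row of `indices` is a full coordinate into `data`.
    return np.array([data[tuple(idx)] for idx in indices])

data = np.array([[1., 2.], [3., 4.]])
print(gather_nd(data, [[1, 0], [0, 1]]))  # [3. 2.]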
 def _stridedSlice():
     def _impl(inputs, attr, params):
         """Strided Slice.
@@ -971,15 +979,18 @@ def _rank():
 def _range():
     def _impl(inputs, attr, params):
-        start = _get_num_param(params, inputs[0])
-        limit = _get_num_param(params, inputs[1])
-        delta = _get_num_param(params, inputs[2])
-
-        name = attr["_node_name"]
-        params[name] = tvm.nd.array([start, limit, delta])
-        return [_expr.var(name,
-                          shape=params[name].shape,
-                          dtype='int32')]
+        start = params.pop(inputs[0].name_hint).asnumpy()[0]
+        limit = params.pop(inputs[1].name_hint).asnumpy()[0] \
+            if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0]
+        delta = params.pop(inputs[2].name_hint).asnumpy()[0]
+        dtype = attr['dtype'].name if 'dtype' in attr else "int32"
+        return AttrCvt(
+            op_name="arange",
+            ignores=['Tidx'],
+            extras={'start': start,
+                    "stop": limit,
+                    'step': delta,
+                    'dtype': dtype})([], attr)
     return _impl
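The rewritten converter lowers TF's Range(start, limit, delta) to a single arange op with the three scalars baked in as attributes. NumPy's arange mirrors the semantics, e.g. for the tf.range(1, 18, 3) case in the test added below:

import numpy as np

print(np.arange(1, 18, 3, dtype="int32"))  # [ 1  4  7 10 13 16]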
 def _elu():
@@ -1099,6 +1110,13 @@ def _topk():
                        extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr)
     return _impl

+def _floordiv():
+    def _impl(inputs, attr, params):
+        assert len(inputs) == 2
+        div = AttrCvt('divide')(inputs, attr)
+        return _get_relay_op('floor')(div)
+    return _impl
+
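FloorDiv is converted by composing two existing ops, divide followed by floor, rather than by a dedicated kernel. A NumPy sketch of that composition (exact for floats; integer inputs would need true division first):

import numpy as np

def floordiv(x, y):
    # divide, then floor -- the same two-op composition as _floordiv()
    return np.floor(np.true_divide(x, y))

print(floordiv(np.array([7.0, -7.0]), 2.0))  # [ 3. -4.]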
 def _logical(name):
     def _impl(inputs, attr, params):
         return AttrCvt(op_name=name)(inputs, attr)
@@ -1207,8 +1225,9 @@ _identity_list = []
 # for 1 to N mapping(composed), use custom callable functions
 # for N to 1 mapping, currently not supported(?)
 _convert_map = {
+    'Abs'                               : AttrCvt('abs'),
     'Add'                               : _elemwise('add'),
-    'All'                               : _reduce_all(),
+    'All'                               : _reduce('all'),
     'ArgMax'                            : _argx(_op.argmax, 'argmax'),
     'ArgMin'                            : _argx(_op.argmin, 'argmin'),
     'AvgPool'                           : _pooling('avg_pool'),
@@ -1232,26 +1251,33 @@ _convert_map = {
     'ExpandDims'                        : _expand_dims(),
     'Fill'                              : _fill(),
     'Floor'                             : AttrCvt('floor'),
+    'FloorDiv'                          : _floordiv(),
     'FusedBatchNorm'                    : _fused_batch_norm(),
     'FusedBatchNormV2'                  : _fused_batch_norm(),
     'Gather'                            : _gather(),
+    'GatherNd'                          : _gather_nd(),
     'GatherV2'                          : _gather(),
     'Greater'                           : _broadcast('greater'),
     'GreaterEqual'                      : _broadcast('greater_equal'),
     'Identity'                          : _identity(),
     'LeakyRelu'                         : AttrCvt('leaky_relu'),
+    'LeftShift'                         : AttrCvt('left_shift'),
     'Less'                              : _broadcast('less'),
     'LessEqual'                         : _broadcast('less_equal'),
     'Log'                               : AttrCvt('log'),
     'LogicalAnd'                        : _logical('logical_and'),
     'LogicalOr'                         : _logical('logical_or'),
     'LogicalNot'                        : _logical('logical_not'),
+    'LogSoftmax'                        : AttrCvt('log_softmax'),
     'LRN'                               : _lrn(),
     'MatMul'                            : _matmul(),
+    'Max'                               : _reduce('max'),
     'MaxPool'                           : _pooling('max_pool'),
     'Maximum'                           : _elemwise('maximum'),
     'Mean'                              : _mean(),
+    'Min'                               : _reduce('min'),
     'Minimum'                           : _elemwise('minimum'),
+    'Mod'                               : _elemwise('mod'),
     'Mul'                               : _elemwise('multiply'),
     'Neg'                               : AttrCvt('negative'),
     'NotEqual'                          : _broadcast('not_equal'),
@@ -1269,6 +1295,7 @@ _convert_map = {
     'ResizeBilinear'                    : _resize_bilinear(),
     'ResizeBicubic'                     : _resize_bilinear(),
     'ReverseV2'                         : _reverse_v2(),
+    'RightShift'                        : AttrCvt('right_shift'),
     'Round'                             : AttrCvt('round'),
     'Rsqrt'                             : _rsqrt(),
     'Select'                            : _where(),
@@ -1292,7 +1319,9 @@ _convert_map = {
     'Tile'                              : _tile(),
     'TopKV2'                            : _topk(),
     'Transpose'                         : _transpose(),
+    'TruncateMod'                       : _elemwise('mod'),
     'Unpack'                            : _unpack(),
+    'ZerosLike'                         : AttrCvt('zeros_like'),
 }
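The table above is a straight name-to-converter dispatch: each TF op name maps to a callable taking the node's inputs and attributes. A toy version of the pattern, with NumPy lambdas standing in for the Relay converters (names illustrative):

import numpy as np

convert_map = {
    'Abs': lambda inputs, attr: np.abs(inputs[0]),
    'Add': lambda inputs, attr: inputs[0] + inputs[1],
}

def convert_op(op_name, inputs, attr=None):
    if op_name not in convert_map:
        raise NotImplementedError("Operator {} not implemented.".format(op_name))
    return convert_map[op_name](inputs, attr)

print(convert_op('Abs', [np.array([-1.0, 2.0])]))  # [1. 2.]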
@@ -64,6 +64,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
                                  layout=layout,
                                  shape=shape_dict,
                                  outputs=out_names)
+
     with relay.build_config(opt_level=opt_level):
         graph, lib, params = relay.build(sym, target, target_host, params)
@@ -642,10 +643,53 @@ def test_forward_stridedslice():
                         'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5,
                         end_mask=8)

+#######################################################################
+# FloorDiv, RealDiv
+# -----------------
+def _test_forward_divide(ip_shape, dtype):
+    np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
+    np_denomin = np.random.uniform(1, 100, size=ip_shape).astype(dtype)
+    tf.reset_default_graph()
+    numerator = tf.placeholder(dtype, ip_shape, name="numer")
+    denominator = tf.placeholder(dtype, ip_shape, name="denomin")
+    tf.math.divide(numerator, denominator, name='RealDiv')
+    compare_tf_with_tvm([np_numer, np_denomin], ['numer:0', 'denomin:0'], 'RealDiv:0')
+
+def _test_forward_floordiv(ip_shape, dtype):
+    np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
+    tf.reset_default_graph()
+    numerator = tf.placeholder(dtype, ip_shape, name="numer")
+    tf.math.floordiv(numerator, tf.constant(5, dtype=dtype), name='FloorDiv')
+    compare_tf_with_tvm([np_numer], ['numer:0'], 'FloorDiv:0')
+
+def test_forward_divide():
+    '''test FloorDiv, RealDiv'''
+    _test_forward_divide((4,), 'int32')
+    _test_forward_divide((4, 3, 7), 'float32')
+    _test_forward_floordiv((4, 3, 7), 'float32')
+
+
 #######################################################################
-# Gather, GatherV2
-# ----------------
+# TruncateMod
+# -----------
+def _test_forward_truncatemod(ip_shape, dtype):
+    np_data_1 = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
+    np_data_2 = np.random.uniform(1, 10, size=ip_shape).astype(dtype)
+    tf.reset_default_graph()
+    in_data_1 = tf.placeholder(dtype, ip_shape, name="in_data_1")
+    in_data_2 = tf.placeholder(dtype, ip_shape, name="in_data_2")
+    tf.truncatemod(in_data_1, in_data_2, name='truncatemod')
+    compare_tf_with_tvm([np_data_1, np_data_2], ['in_data_1:0', 'in_data_2:0'], 'truncatemod:0')
+
+def test_forward_truncatemod():
+    '''test TruncateMod'''
+    _test_forward_truncatemod((4, 3, 7), 'int32')
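A note on semantics: TruncateMod follows C rules, so the result takes the sign of the dividend (the quotient is truncated toward zero). NumPy's fmod implements the same rule and is a quick way to sanity-check the test above:

import numpy as np

print(np.fmod(np.array([7, -7]), 3))  # [ 1 -1]  (sign follows dividend)
print(np.mod(np.array([7, -7]), 3))   # [1 2]    (floor modulo, for contrast)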
+
+#######################################################################
+# Gather, GatherV2, GatherNd
+# --------------------------
 def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
     """ One iteration of a GatherV2 """
@@ -718,6 +762,33 @@ def test_forward_gather_v1():
     _test_gather_v1((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 'float32')

+def test_forward_gather_nd():
+    """test operator GatherNd"""
+    np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.float32, (2, 2), name="in_data")
+    tf.gather_nd(in_data, indices=[[1, 0], [0, 1]], name="gather_nd")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0')
+
+
+#######################################################################
+# BiasAdd
+# -------
+def test_forward_bias_add():
+    """test Op BiasAdd"""
+    def check_bias_add(lh_shape, rh_shape, dtype):
+        tf.reset_default_graph()
+        lh_data = np.random.uniform(size=lh_shape).astype(dtype)
+        rh_data = np.random.uniform(size=rh_shape).astype(dtype)
+        lft_data = tf.placeholder(dtype, name="lft_data")
+        rgt_data = tf.placeholder(dtype, name="rgt_data")
+        tf.nn.bias_add(lft_data, rgt_data, name="BiasAdd")
+        compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'BiasAdd:0')
+
+    check_bias_add((10, 8, 16, 32), (32,), dtype="int32")
+    check_bias_add((10, 20), (20,), dtype="float32")
+
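BiasAdd is a restricted Add: the bias must be 1-D and is broadcast along the channel axis (the last axis for the default NHWC data format). A NumPy stand-in for what the test exercises:

import numpy as np

def bias_add_nhwc(value, bias):
    assert bias.ndim == 1 and value.shape[-1] == bias.shape[0]
    return value + bias  # broadcasts over the trailing channel axis

x = np.zeros((10, 20), dtype="float32")
print(bias_add_nhwc(x, np.arange(20, dtype="float32")).shape)  # (10, 20)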
 #######################################################################
 # Split
 # -----
@@ -1109,6 +1180,32 @@ def test_forward_pack():
         _test_pack(axis, [3])
     _test_pack(0, [])

+
+#######################################################################
+# Unpack
+# ------
+def _test_forward_unpack(in_shape, axis, dtype):
+    """test operator Unpack"""
+    np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, in_shape, name="in_data")
+    tf.unstack(in_data, axis=axis, name="Unpack")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'Unpack:0')
+
+def test_forward_unpack():
+    _test_forward_unpack((3,), 0, 'int32')
+    _test_forward_unpack((3,), -1, 'int16')
+    _test_forward_unpack((21, 23, 3), 2, 'float32')
+
+
+#######################################################################
+# Range
+# -----
+def test_forward_range():
+    """test operator Range"""
+    tf.reset_default_graph()
+    tf.range(1, 18, 3, name="range")
+    compare_tf_with_tvm([], [], 'range:0')
+
 #######################################################################
 # Pad
 # ---
@@ -1182,7 +1279,7 @@ def test_forward_logical():
 #######################################################################
 # Where, Select
 # -------------
-def test_where():
+def test_forward_where():
     ''' Where: return elements depending on conditions'''
     with tf.Graph().as_default():
         with tf.Session() as sess:
@@ -1553,6 +1650,22 @@ def test_forward_tanh():
     tf.nn.tanh(in1)
     compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Tanh:0')

+
+#######################################################################
+# Softmax
+# -------
+def test_forward_softmax():
+    """test operator Softmax"""
+    def check_softmax(in_shape, axis, dtype):
+        np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype)
+        tf.reset_default_graph()
+        in_data = tf.placeholder(dtype, in_shape, name="in_data")
+        tf.nn.softmax(in_data, axis=axis, name="Softmax")
+        compare_tf_with_tvm([np_data], ['in_data:0'], 'Softmax:0')
+
+    check_softmax((2, 3, 5), 2, "float32")
+    check_softmax((2, 3, 5), -1, "float32")
+
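For a non-default axis the converter must reproduce the usual definition; a NumPy sketch, numerically stabilized by subtracting the max along the reduced axis:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

print(softmax(np.array([[1., 2., 3.]])).round(3))  # [[0.09  0.245 0.665]]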
 #######################################################################
 # Tensor
 # ------
@@ -1565,6 +1678,29 @@ def test_forward_round():
     tf.round(in_data, name="round")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'round:0')

+def test_forward_abs():
+    """test operator Abs"""
+    np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.float32, (9, 11), name="in_data")
+    tf.math.abs(in_data, name="abs")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'abs:0')
+
+def _test_forward_zeros_like(in_shape, dtype):
+    np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, in_shape, name="in_data")
+    tf.zeros_like(in_data, name="zeros_like")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'zeros_like:0')
+
+def test_forward_zeros_like():
+    if tf.__version__ < LooseVersion('1.2'):
+        _test_forward_zeros_like((2, 3), "int32")
+        _test_forward_zeros_like((2, 3, 5), "int8")
+        _test_forward_zeros_like((2, 3, 5, 7), "uint16")
+    _test_forward_zeros_like((2, 3, 11), "float32")
+    _test_forward_zeros_like((2, 3, 11), "float64")
+
 def _test_forward_reverse_v2(in_shape, axis, dtype):
     np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
     tf.reset_default_graph()
@@ -1588,6 +1724,14 @@ def test_forward_sign():
     tf.sign(in_data, name="sign")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'sign:0')

+def test_forward_square():
+    """test operator Square"""
+    np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+    tf.square(in_data, name="square")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'square:0')
+
 def test_forward_pow_exp():
     """test Pow and Exp"""
     np_in1 = np.random.uniform(-2, 2, size=(5, 7, 11)).astype(np.float32)
@@ -1616,6 +1760,14 @@ def test_forward_negative():
     tf.negative(in_data, name="negative")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0')

+def test_forward_log_softmax():
+    """test operator LogSoftmax"""
+    np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.float32, (9, 11), name="in_data")
+    tf.math.log_softmax(in_data, name="LogSoftmax")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'LogSoftmax:0')
+
 def test_forward_softplus():
     """test operator Softplus"""
     np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32)
@@ -1640,6 +1792,34 @@ def test_forward_sqrt():
     tf.sqrt(in_data, name="sqrt")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0')

+def _test_forward_right_shift(in_shape, dtype):
+    """test operator RightShift"""
+    lh_data = np.random.randint(1, 3, size=in_shape).astype(dtype)
+    rh_data = np.random.randint(1, 8, size=in_shape).astype(dtype)
+    tf.reset_default_graph()
+    lft_data = tf.placeholder(dtype, in_shape, name="lft_data")
+    rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data")
+    tf.bitwise.right_shift(lft_data, rgt_data, name="RightShift")
+    compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'RightShift:0')
+
+def test_forward_right_shift():
+    _test_forward_right_shift((7,), 'int32')
+    _test_forward_right_shift((3, 11), 'int16')
+
+def _test_forward_left_shift(in_shape, dtype):
+    """test operator LeftShift"""
+    lh_data = np.random.randint(100, 1000000, size=in_shape).astype(dtype)
+    rh_data = np.random.randint(1, 3, size=in_shape).astype(dtype)
+    tf.reset_default_graph()
+    lft_data = tf.placeholder(dtype, in_shape, name="lft_data")
+    rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data")
+    tf.bitwise.left_shift(lft_data, rgt_data, name="LeftShift")
+    compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'LeftShift:0')
+
+def test_forward_left_shift():
+    _test_forward_left_shift((10,), 'int32')
+    _test_forward_left_shift((224, 224, 3), 'int16')
+
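For reference, these are plain bitwise shifts; the narrow random ranges above keep results within the dtype's width. A quick NumPy check of the semantics:

import numpy as np

x = np.array([100, 1000], dtype="int32")
print(np.left_shift(x, 2))   # [ 400 4000]  == x * 2**2
print(np.right_shift(x, 2))  # [ 25 250]    == x // 2**2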
 #######################################################################
 # Mean
 # ----
@@ -1652,13 +1832,13 @@ def test_forward_mean():
         compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Mean:0', no_gpu=True)

     check_mean((10, 8, 16, 32))
-    check_mean((10, 8, 16, 32), axis=(2,3))
-    check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True)
+    check_mean((10, 8, 16, 32), axis=(2, 3))
+    check_mean((10, 8, 16, 32), axis=(1, 2), keepdims=True)

 #######################################################################
-# All
-# ---
-def test_forward_all():
+# All, Max, Min
+# -------------
+def test_forward_reduce_all():
     """Test the All operator."""
     np_data = np.random.choice([True, False], size=(5, 7, 11))
     tf.reset_default_graph()
@@ -1666,6 +1846,30 @@ def test_forward_all():
     tf.reduce_all(in_data, name="all")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')

+def test_forward_reduce_max():
+    def check_max(ishape, axis, keepdims, dtype):
+        tf.reset_default_graph()
+        np_data = np.random.uniform(size=ishape).astype(dtype)
+        in_data = tf.placeholder(dtype, name="in_data")
+        tf.math.reduce_max(in_data, axis=axis, keepdims=keepdims, name="reduce_max")
+        compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0')
+
+    check_max((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32")
+    check_max((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32")
+    check_max((10, 8, 16, 32), axis=(1, 2), keepdims=True, dtype="float32")
+
+def test_forward_reduce_min():
+    def check_min(ishape, axis, keepdims, dtype):
+        tf.reset_default_graph()
+        np_data = np.random.uniform(size=ishape).astype(dtype)
+        in_data = tf.placeholder(dtype, name="in_data")
+        tf.math.reduce_min(in_data, axis=axis, keepdims=keepdims, name="reduce_min")
+        compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_min:0')
+
+    check_min((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32")
+    check_min((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32")
+    check_min((10, 8, 16, 32), axis=(1, 2), keepdims=True, dtype="float32")
+
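The keep_dims -> keepdims attribute translation these tests exercise determines whether reduced axes survive as size-1 dimensions:

import numpy as np

x = np.arange(24, dtype="float32").reshape(2, 3, 4)
print(np.max(x, axis=(1, 2)).shape)                 # (2,)
print(np.max(x, axis=(1, 2), keepdims=True).shape)  # (2, 1, 1)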
 #######################################################################
 # Relational operators
 # --------------------
@@ -1724,6 +1928,38 @@ def test_forward_reduce_prod():
 #######################################################################
+# Maximum, Minimum
+# ----------------
+def test_forward_maximum():
+    """test Op Maximum"""
+    def check_maximum(lh_shape, rh_shape, dtype):
+        tf.reset_default_graph()
+        lh_data = np.random.uniform(size=lh_shape).astype(dtype)
+        rh_data = np.random.uniform(size=rh_shape).astype(dtype)
+        lft_data = tf.placeholder(dtype, name="lft_data")
+        rgt_data = tf.placeholder(dtype, name="rgt_data")
+        tf.math.maximum(lft_data, rgt_data, name="maximum")
+        compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'maximum:0')
+
+    check_maximum((10, 8, 16, 32), (1,), dtype="int32")
+    check_maximum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32")
+
+def test_forward_minimum():
+    """test Op Minimum"""
+    def check_minimum(lh_shape, rh_shape, dtype):
+        tf.reset_default_graph()
+        lh_data = np.random.uniform(size=lh_shape).astype(dtype)
+        rh_data = np.random.uniform(size=rh_shape).astype(dtype)
+        lft_data = tf.placeholder(dtype, name="lft_data")
+        rgt_data = tf.placeholder(dtype, name="rgt_data")
+        tf.math.minimum(lft_data, rgt_data, name="minimum")
+        compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'minimum:0')
+
+    check_minimum((10, 8, 16, 32), (1,), dtype="int32")
+    check_minimum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32")
+
+
+#######################################################################
 # PlaceholderWithDefault
 # ----------------------
 def test_placeholder():
@@ -1740,6 +1976,7 @@ def test_placeholder():
     compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0', init_global_variables=True)

+
 #######################################################################
 # Main
 # ----
@@ -1756,14 +1993,22 @@ if __name__ == '__main__':
     test_forward_fill()
     test_forward_crop()
     test_forward_pad()
+    test_forward_unpack()
     test_forward_gather()
     test_forward_gather_v1()
+    test_forward_gather_nd()
     test_forward_stridedslice()
     test_forward_split()
     test_forward_unstack()
     test_forward_tile()
     test_forward_top_k_v2()
     test_forward_clip_by_value()
+    test_forward_maximum()
+    test_forward_minimum()
+    test_forward_range()
+    test_forward_right_shift()
+    test_forward_left_shift()
+    test_forward_truncatemod()

     # Activations
     test_forward_sigmoid()
@@ -1780,17 +2025,26 @@ if __name__ == '__main__':
     test_forward_sign()
     test_forward_log()
     test_forward_negative()
+    test_forward_divide()
+    test_forward_abs()
     test_forward_softplus()
     test_forward_sqrt()
     test_forward_rsqrt()
     test_forward_expand_dims()
+    test_forward_square()
+    test_forward_softmax()
+    test_forward_log_softmax()
+    test_forward_bias_add()
+    test_forward_zeros_like()

     # Reductions
     test_forward_argminmax()
     test_forward_reduce()
     test_forward_mean()
     test_forward_reduce_prod()
-    test_forward_all()
+    test_forward_reduce_all()
+    test_forward_reduce_max()
+    test_forward_reduce_min()

     # General
     test_forward_multi_input()
@@ -1826,7 +2080,7 @@ if __name__ == '__main__':
     # Relational ops
     test_forward_rel_ops()
     test_forward_logical()
-    test_where()
+    test_forward_where()
     test_forward_matmul()
     # TODO missing tests: rank, range