Commit 463e5c38 by Tatsuya Nishiyama Committed by Yizhi Liu

[FRONTEND][TENSORFLOW] fix the convertion of sum and add testcase for it (#1654)

* [TENSORFLOW] fix the convertion of sum and add testcase for it

* delete checking tyoe of axis and divide reduce test
parent a9e0567d
......@@ -444,6 +444,8 @@ def _lrn():
def _sum():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].list_output_names()[0]).asnumpy()
# convert to tuple for preventing invalid parameter format error
axis = tuple(axis)
return AttrCvt(
op_name='sum',
extras={'axis': axis},
......
......@@ -349,6 +349,26 @@ def test_forward_argminmax():
_test_argx(tf.argmin, data=data, axis=axis)
#######################################################################
# Reduce
# ------
def _test_reduce(func, data, **kwargs):
""" One iteration of a reduce operation"""
with tf.Graph().as_default():
inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
func(inp, name="reducex0", **kwargs)
compare_tf_with_tvm(data, 'c0:0', 'reducex0:0')
def test_forward_reduce():
data = np.random.uniform(size=(8,4,9)).astype('float32')
_test_reduce(tf.reduce_sum, data=data)
_test_reduce(tf.reduce_sum, data=data, axis=0)
_test_reduce(tf.reduce_sum, data=data, axis=(0,1))
#######################################################################
# Variable
# --------
......@@ -844,6 +864,7 @@ if __name__ == '__main__':
test_forward_squeeze()
test_forward_sigmoid()
test_forward_argminmax()
test_forward_reduce()
if tf.__version__ == '1.4.1':
_test_forward_concat_v2()
test_forward_multi_input()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment