Commit 463e5c38 by Tatsuya Nishiyama Committed by Yizhi Liu

[FRONTEND][TENSORFLOW] fix the convertion of sum and add testcase for it (#1654)

* [TENSORFLOW] fix the convertion of sum and add testcase for it

* delete checking tyoe of axis and divide reduce test
parent a9e0567d
...@@ -444,6 +444,8 @@ def _lrn(): ...@@ -444,6 +444,8 @@ def _lrn():
def _sum(): def _sum():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
axis = params.pop(inputs[1].list_output_names()[0]).asnumpy() axis = params.pop(inputs[1].list_output_names()[0]).asnumpy()
# convert to tuple for preventing invalid parameter format error
axis = tuple(axis)
return AttrCvt( return AttrCvt(
op_name='sum', op_name='sum',
extras={'axis': axis}, extras={'axis': axis},
......
...@@ -349,6 +349,26 @@ def test_forward_argminmax(): ...@@ -349,6 +349,26 @@ def test_forward_argminmax():
_test_argx(tf.argmin, data=data, axis=axis) _test_argx(tf.argmin, data=data, axis=axis)
####################################################################### #######################################################################
# Reduce
# ------
def _test_reduce(func, data, **kwargs):
""" One iteration of a reduce operation"""
with tf.Graph().as_default():
inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
func(inp, name="reducex0", **kwargs)
compare_tf_with_tvm(data, 'c0:0', 'reducex0:0')
def test_forward_reduce():
data = np.random.uniform(size=(8,4,9)).astype('float32')
_test_reduce(tf.reduce_sum, data=data)
_test_reduce(tf.reduce_sum, data=data, axis=0)
_test_reduce(tf.reduce_sum, data=data, axis=(0,1))
#######################################################################
# Variable # Variable
# -------- # --------
...@@ -844,6 +864,7 @@ if __name__ == '__main__': ...@@ -844,6 +864,7 @@ if __name__ == '__main__':
test_forward_squeeze() test_forward_squeeze()
test_forward_sigmoid() test_forward_sigmoid()
test_forward_argminmax() test_forward_argminmax()
test_forward_reduce()
if tf.__version__ == '1.4.1': if tf.__version__ == '1.4.1':
_test_forward_concat_v2() _test_forward_concat_v2()
test_forward_multi_input() test_forward_multi_input()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment