Commit 05c77280 by zhengdi Committed by Siva

[TEST][TENSORFLOW] clean up code (#3342)

parent 8a2dcf1f
...@@ -22,8 +22,6 @@ This article is a test script to test tensorflow operator with Relay. ...@@ -22,8 +22,6 @@ This article is a test script to test tensorflow operator with Relay.
""" """
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
import tvm
from tvm import relay
import tensorflow as tf import tensorflow as tf
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import graph_util from tensorflow.python.framework import graph_util
...@@ -35,8 +33,9 @@ from tensorflow.python.ops import math_ops ...@@ -35,8 +33,9 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from distutils.version import LooseVersion from distutils.version import LooseVersion
import tvm
from tvm import relay
import tvm.relay.testing.tf as tf_testing import tvm.relay.testing.tf as tf_testing
####################################################################### #######################################################################
...@@ -179,46 +178,46 @@ def test_forward_pooling(): ...@@ -179,46 +178,46 @@ def test_forward_pooling():
""" Pooling """ """ Pooling """
for pool_type in ['AVG', 'MAX']: for pool_type in ['AVG', 'MAX']:
_test_pooling(input_shape=[2, 9, 10, 2], _test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[1, 1], window_shape=[1, 1],
padding='SAME', padding='SAME',
pooling_type=pool_type, pooling_type=pool_type,
dilation_rate=[1, 1], dilation_rate=[1, 1],
strides=[1, 1]) strides=[1, 1])
_test_pooling(input_shape=[2, 10, 9, 2], _test_pooling(input_shape=[2, 10, 9, 2],
window_shape=[1, 1], window_shape=[1, 1],
padding='SAME', padding='SAME',
pooling_type=pool_type, pooling_type=pool_type,
dilation_rate=[1, 1], dilation_rate=[1, 1],
strides=[1, 1]) strides=[1, 1])
_test_pooling(input_shape=[2, 9, 10, 2], _test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[2, 1], window_shape=[2, 1],
padding='SAME', padding='SAME',
pooling_type=pool_type, pooling_type=pool_type,
dilation_rate=[1, 1], dilation_rate=[1, 1],
strides=[1, 1]) strides=[1, 1])
_test_pooling(input_shape=[2, 10, 9, 2], _test_pooling(input_shape=[2, 10, 9, 2],
window_shape=[2, 3], window_shape=[2, 3],
padding='SAME', padding='SAME',
pooling_type=pool_type, pooling_type=pool_type,
dilation_rate=[1, 1], dilation_rate=[1, 1],
strides=[2, 1]) strides=[2, 1])
# Tests involving SpaceToBatchND # Tests involving SpaceToBatchND
_test_pooling(input_shape=[1, 1, 2, 1], _test_pooling(input_shape=[1, 1, 2, 1],
window_shape=[1, 1], window_shape=[1, 1],
padding='VALID', padding='VALID',
pooling_type=pool_type, pooling_type=pool_type,
dilation_rate=[1, 2]) dilation_rate=[1, 2])
_test_pooling(input_shape=[1, 2, 1], _test_pooling(input_shape=[1, 2, 1],
window_shape=[1], window_shape=[1],
padding='VALID', padding='VALID',
pooling_type=pool_type, pooling_type=pool_type,
dilation_rate=[2]) dilation_rate=[2])
####################################################################### #######################################################################
# Convolution # Convolution
...@@ -461,24 +460,29 @@ def test_forward_squeeze(): ...@@ -461,24 +460,29 @@ def test_forward_squeeze():
# ConcatV2 # ConcatV2
# -------- # --------
def _test_concat_v2(data, dim): def _test_concat_v2(shape1, shape2, dim):
""" One iteration of ConcatV2 """ """ One iteration of ConcatV2 """
with tf.Graph().as_default(): with tf.Graph().as_default():
gen_array_ops._concat_v2(data, dim) dtype = 'float32'
in1 = tf.placeholder(shape=shape1, dtype=dtype, name='in1')
in2 = tf.placeholder(shape=shape2, dtype=dtype, name='in2')
array_ops.concat_v2([in1, in2], dim)
compare_tf_with_tvm(data, ['ConcatV2/values_0:0', 'ConcatV2/values_1:0'], np_data1 = np.random.uniform(size=shape1).astype(dtype)
'ConcatV2:0') np_data2 = np.random.uniform(size=shape2).astype(dtype)
def _test_forward_concat_v2(): compare_tf_with_tvm([np_data1, np_data2], ['in1:0', 'in2:0'], 'ConcatV2:0')
t1 = np.array([])
t2 = np.array([])
_test_concat_v2([t1, t2], 0)
t1 = np.array([[1, 2, 3], [4, 5, 6]]) def test_forward_concat_v2():
t2 = np.array([[7, 8, 9], [10, 11, 12]]) if tf.__version__ < LooseVersion('1.4.1'):
return
_test_concat_v2([t1, t2], 1) _test_concat_v2([2, 3], [2, 3], 0)
_test_concat_v2([10, 3, 5], [2, 3, 5], 0)
_test_concat_v2([2, 3], [2, 3], 1)
_test_concat_v2([5, 8], [5, 4], 1)
_test_concat_v2([2, 8, 5], [2, 8, 6], -1)
####################################################################### #######################################################################
# Sigmoid # Sigmoid
...@@ -511,8 +515,8 @@ def _test_argx(func, data, **kwargs): ...@@ -511,8 +515,8 @@ def _test_argx(func, data, **kwargs):
compare_tf_with_tvm(data, 'c0:0', 'argx0:0') compare_tf_with_tvm(data, 'c0:0', 'argx0:0')
def test_forward_argminmax(): def test_forward_argminmax():
for axis in [None,0,1,2]: for axis in [None, 0, 1, 2]:
data = np.random.uniform(size=(8,4,9)).astype('float32') data = np.random.uniform(size=(8, 4, 9)).astype('float32')
_test_argx(tf.argmax, data=data, axis=axis) _test_argx(tf.argmax, data=data, axis=axis)
_test_argx(tf.argmin, data=data, axis=axis) _test_argx(tf.argmin, data=data, axis=axis)
...@@ -530,10 +534,10 @@ def _test_reduce(func, data, **kwargs): ...@@ -530,10 +534,10 @@ def _test_reduce(func, data, **kwargs):
compare_tf_with_tvm(data, 'c0:0', 'reducex0:0') compare_tf_with_tvm(data, 'c0:0', 'reducex0:0')
def test_forward_reduce(): def test_forward_reduce():
data = np.random.uniform(size=(8,4,9)).astype('float32') data = np.random.uniform(size=(8, 4, 9)).astype('float32')
_test_reduce(tf.reduce_sum, data=data) _test_reduce(tf.reduce_sum, data=data)
_test_reduce(tf.reduce_sum, data=data, axis=0) _test_reduce(tf.reduce_sum, data=data, axis=0)
_test_reduce(tf.reduce_sum, data=data, axis=(0,1)) _test_reduce(tf.reduce_sum, data=data, axis=(0, 1))
####################################################################### #######################################################################
...@@ -597,16 +601,16 @@ def test_forward_matmul(): ...@@ -597,16 +601,16 @@ def test_forward_matmul():
# ------------ # ------------
def _test_stridedslice(ip_shape, begin, end, stride, dtype, def _test_stridedslice(ip_shape, begin, end, stride, dtype,
begin_mask=0, end_mask=0, new_axis_mask=0, begin_mask=0, end_mask=0, new_axis_mask=0,
shrink_axis_mask=0, ellipsis_mask=0): shrink_axis_mask=0, ellipsis_mask=0):
""" One iteration of a Stridedslice """ """ One iteration of a Stridedslice """
tf.reset_default_graph() tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data") in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.strided_slice(in_data, begin, end, stride, begin_mask=begin_mask, tf.strided_slice(in_data, begin, end, stride, begin_mask=begin_mask,
end_mask=end_mask, new_axis_mask=new_axis_mask, end_mask=end_mask, new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask, shrink_axis_mask=shrink_axis_mask,
ellipsis_mask=ellipsis_mask, name="strided_slice") ellipsis_mask=ellipsis_mask, name="strided_slice")
np_data = np.random.uniform(size=ip_shape).astype(dtype) np_data = np.random.uniform(size=ip_shape).astype(dtype)
compare_tf_with_tvm(np_data, 'in_data:0', 'strided_slice:0') compare_tf_with_tvm(np_data, 'in_data:0', 'strided_slice:0')
...@@ -621,26 +625,39 @@ def test_forward_stridedslice(): ...@@ -621,26 +625,39 @@ def test_forward_stridedslice():
_test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2) _test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [2, 1, 1], 'float32', ellipsis_mask=2) _test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [2, 1, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5) _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5)
_test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=4) _test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2,
_test_stridedslice((6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=5) new_axis_mask=4)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=4, new_axis_mask=2) _test_stridedslice((6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], 'float32', ellipsis_mask=2,
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3) new_axis_mask=5)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3) _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=4,
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=2) new_axis_mask=2)
_test_stridedslice((3,4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2) _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2,
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, new_axis_mask=2) new_axis_mask=3)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=1, new_axis_mask=2) _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2,
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, new_axis_mask=1) new_axis_mask=3)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0], [2, 3], [1, 1], 'float32', shrink_axis_mask=5, new_axis_mask=1) _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2,
new_axis_mask=2)
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2,
new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=1,
new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2,
new_axis_mask=1)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0], [2, 3], [1, 1], 'float32', shrink_axis_mask=5,
new_axis_mask=1)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1], _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
'float32', shrink_axis_mask=5, new_axis_mask=1, ellipsis_mask=2, begin_mask=8, end_mask=8) 'float32', shrink_axis_mask=5, new_axis_mask=1, ellipsis_mask=2,
begin_mask=8, end_mask=8)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1], _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=5) 'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2,
begin_mask=5, end_mask=5)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1], _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
'float32', shrink_axis_mask=16, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=5) 'float32', shrink_axis_mask=16, new_axis_mask=1, ellipsis_mask=2,
begin_mask=5, end_mask=5)
_test_stridedslice((3, 4, 5, 4, 5, 6), [1, 2, 0, -3], [4, 5, 3, 3], [2, 2, 1, 1], _test_stridedslice((3, 4, 5, 4, 5, 6), [1, 2, 0, -3], [4, 5, 3, 3], [2, 2, 1, 1],
'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, 'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2,
end_mask=8) begin_mask=5, end_mask=8)
####################################################################### #######################################################################
# FloorDiv, RealDiv # FloorDiv, RealDiv
...@@ -696,7 +713,7 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype): ...@@ -696,7 +713,7 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
tf.reset_default_graph() tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data") in_data = tf.placeholder(dtype, ip_shape, name="in_data")
indices = tf.placeholder("int32", indice_shape, name="indices") indices = tf.placeholder("int32", indice_shape, name="indices")
tf.gather(in_data, indices, axis=axis) out = tf.gather(in_data, indices, axis=axis)
np_data = np.random.uniform(1, 10, size=ip_shape).astype(dtype) np_data = np.random.uniform(1, 10, size=ip_shape).astype(dtype)
def _fill_indices(indice_value): def _fill_indices(indice_value):
...@@ -708,59 +725,21 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype): ...@@ -708,59 +725,21 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
return indices return indices
np_indices = _fill_indices(indice_value) np_indices = _fill_indices(indice_value)
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'GatherV2:0') compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], out.name)
def test_forward_gather(): def test_forward_gather():
'''test GatherV2 layer''' '''test Gather/GatherV2 layer'''
_test_gather((4,), (1,), 1, 0, 'int32') _test_gather((4,), (1,), 1, 0, 'int32')
_test_gather((4,), (1,), 1, 0, 'float32') _test_gather((4,), (1,), 1, 0, 'float32')
_test_gather((1, 4), (1,), [0], 0, 'int32') _test_gather((1, 4), (1,), [0], 0, 'int32')
_test_gather((4,), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'float32') _test_gather((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
_test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'int32') _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
_test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 1, 'int32') _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
_test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'float32') _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
_test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 0, 'int32') _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 0, 'int32')
_test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32') _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32')
_test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32') _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32')
def _test_gather_v1(ip_shape, indice_shape, indice_value, dtype):
""" One iteration of a Gather"""
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
indices = tf.placeholder("int32", indice_shape, name="indices")
tf.gather(in_data, indices)
np_data = np.random.uniform(size=ip_shape).astype(dtype)
def _fill_indices(indice_value):
indices = np.array(ip_shape, dtype=dtype)
if isinstance(indice_value, int):
indices = np.array([indice_value], dtype='int32')
else:
indices = np.asarray(indice_value, dtype='int32')
return indices
np_indices = _fill_indices(indice_value)
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'Gather:0')
def test_forward_gather_v1():
'''test gather layer'''
if tf.__version__ < LooseVersion('1.7'):
_test_gather_v1((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
_test_gather_v1((4,), (1,), 1, 'int32')
_test_gather_v1((4,), (1,), 1, 'float32')
_test_gather_v1((1, 4), (1,), [0], 'int32')
_test_gather_v1((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'int32')
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'int32')
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
_test_gather_v1((3, 3, 3), (1, 1, 2), [[[1, 0]]], 'int32')
_test_gather_v1((3, 3, 3), (1, 1, 2), [[[1, 0]]], 'int32')
_test_gather_v1((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 'float32')
def test_forward_gather_nd(): def test_forward_gather_nd():
"""test operator GatherNd""" """test operator GatherNd"""
np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32)
...@@ -798,7 +777,8 @@ def _test_split(in_shape, axis, num_or_size_splits, dtype): ...@@ -798,7 +777,8 @@ def _test_split(in_shape, axis, num_or_size_splits, dtype):
""" One iteration of a Split """ """ One iteration of a Split """
tf.reset_default_graph() tf.reset_default_graph()
in_data = tf.placeholder(dtype, in_shape, name="in_data") in_data = tf.placeholder(dtype, in_shape, name="in_data")
num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list)\
else num_or_size_splits
tf.split(in_data, num_or_size_splits, axis=axis) tf.split(in_data, num_or_size_splits, axis=axis)
compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)]) compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)])
...@@ -1116,10 +1096,10 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype): ...@@ -1116,10 +1096,10 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
def _get_tensorflow_output(): def _get_tensorflow_output():
with tf.Session() as sess: with tf.Session() as sess:
with variable_scope.variable_scope( with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)): "root", initializer=init_ops.constant_initializer(0.5)):
m0 = array_ops.zeros([batch_size, num_hidden]) m0 = array_ops.zeros([batch_size, num_hidden])
m1 = array_ops.zeros([batch_size, num_hidden]) m1 = array_ops.zeros([batch_size, num_hidden])
x=tf.placeholder(shape=(batch_size, input_size), dtype=dtype) x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype)
g, ((out_m0, out_m1)) = \ g, ((out_m0, out_m1)) = \
tf.contrib.rnn.LSTMBlockCell(num_hidden, tf.contrib.rnn.LSTMBlockCell(num_hidden,
forget_bias=forget_bias)(x, ((m0, m1))) forget_bias=forget_bias)(x, ((m0, m1)))
...@@ -1167,15 +1147,15 @@ def _test_pack(axis, shape, **kwargs): ...@@ -1167,15 +1147,15 @@ def _test_pack(axis, shape, **kwargs):
with tf.Graph().as_default(): with tf.Graph().as_default():
tf_a = array_ops.placeholder(shape=shape, dtype='float32', name='pl_a') tf_a = array_ops.placeholder(shape=shape, dtype='float32', name='pl_a')
tf_b = array_ops.placeholder(shape=shape, dtype='float32', name='pl_b') tf_b = array_ops.placeholder(shape=shape, dtype='float32', name='pl_b')
tf_c = tf.stack([tf_a,tf_b], axis=axis, **kwargs) tf_c = tf.stack([tf_a, tf_b], axis=axis, **kwargs)
assert tf_c.op.op_def.name == 'Pack', "tf.stack() is expected to produce 'Pack' operation" assert tf_c.op.op_def.name == 'Pack', "tf.stack() is expected to produce 'Pack' operation"
compare_tf_with_tvm([a,b], ['pl_a:0','pl_b:0'], 'stack:0') compare_tf_with_tvm([a, b], ['pl_a:0', 'pl_b:0'], 'stack:0')
def test_forward_pack(): def test_forward_pack():
for axis in range(-3,3): for axis in range(-3, 3):
_test_pack(axis, [3,2,1]) _test_pack(axis, [3, 2, 1])
for axis in range(-1,1): for axis in range(-1, 1):
_test_pack(axis, [3]) _test_pack(axis, [3])
_test_pack(0, []) _test_pack(0, [])
...@@ -1228,8 +1208,8 @@ def _test_pad(input_shape, paddings, mode, **kwargs): ...@@ -1228,8 +1208,8 @@ def _test_pad(input_shape, paddings, mode, **kwargs):
def test_forward_pad(): def test_forward_pad():
""" Pad """ """ Pad """
_test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT") _test_pad((2, 3), [[1, 1], [2, 2]], mode="CONSTANT")
_test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT", constant_values=1.0) _test_pad((2, 3), [[1, 1], [2, 2]], mode="CONSTANT", constant_values=1.0)
####################################################################### #######################################################################
# Logical operators # Logical operators
...@@ -1239,8 +1219,8 @@ def test_logical_and(): ...@@ -1239,8 +1219,8 @@ def test_logical_and():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1') in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2') in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
out = tf.logical_and(in1, in2, name='out') out = tf.logical_and(in1, in2, name='out')
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0') compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
def test_logical_or(): def test_logical_or():
...@@ -1248,8 +1228,8 @@ def test_logical_or(): ...@@ -1248,8 +1228,8 @@ def test_logical_or():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1') in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2') in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
out = tf.logical_or(in1, in2, name='out') out = tf.logical_or(in1, in2, name='out')
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0') compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
def test_logical_xor(): def test_logical_xor():
...@@ -1257,15 +1237,15 @@ def test_logical_xor(): ...@@ -1257,15 +1237,15 @@ def test_logical_xor():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1') in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2') in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
out = tf.logical_xor(in1, in2, name='out') out = tf.logical_xor(in1, in2, name='out')
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0') compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
def test_logical_not(): def test_logical_not():
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1') in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
out = tf.logical_not(in1, name='out') out = tf.logical_not(in1, name='out')
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm(in_data1, 'in1:0', 'out:0') compare_tf_with_tvm(in_data1, 'in1:0', 'out:0')
def test_forward_logical(): def test_forward_logical():
...@@ -1297,7 +1277,8 @@ def test_forward_where(): ...@@ -1297,7 +1277,8 @@ def test_forward_where():
def test_forward_inception_v3(): def test_forward_inception_v3():
'''test inception V3 model''' '''test inception V3 model'''
with tf.Graph().as_default(): with tf.Graph().as_default():
graph_def = tf_testing.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb') graph_def = tf_testing.get_workload(
'InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def) graph_def = tf_testing.ProcessGraphDefParam(graph_def)
...@@ -1326,7 +1307,7 @@ def test_forward_inception_v1(): ...@@ -1326,7 +1307,7 @@ def test_forward_inception_v1():
img = Image.frombuffer('RGB', (600, 600), img_array.tostring(), 'raw', 'RGB', 0, 1) img = Image.frombuffer('RGB', (600, 600), img_array.tostring(), 'raw', 'RGB', 0, 1)
temp = util.tempdir() temp = util.tempdir()
img_path = temp.relpath("tf-test.jpg") img_path = temp.relpath("tf-test.jpg")
img.save(img_path); img.save(img_path)
import os.path import os.path
if not tf.gfile.Exists(os.path.join(img_path)): if not tf.gfile.Exists(os.path.join(img_path)):
...@@ -1365,7 +1346,8 @@ def test_forward_mobilenet(): ...@@ -1365,7 +1346,8 @@ def test_forward_mobilenet():
graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0') tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input') tvm_output = run_tvm_graph(graph_def, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]),
rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# ResnetV2 # ResnetV2
...@@ -1374,7 +1356,8 @@ def test_forward_resnetv2(): ...@@ -1374,7 +1356,8 @@ def test_forward_resnetv2():
'''test resnet model''' '''test resnet model'''
if is_gpu_available(): if is_gpu_available():
with tf.Graph().as_default(): with tf.Graph().as_default():
graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb") graph_def = tf_testing.get_workload(
"ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def) graph_def = tf_testing.ProcessGraphDefParam(graph_def)
...@@ -1388,8 +1371,10 @@ def test_forward_resnetv2(): ...@@ -1388,8 +1371,10 @@ def test_forward_resnetv2():
if not ctx.exist: if not ctx.exist:
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
continue continue
tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device) tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output),
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) target=device)
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]),
rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# Placeholder # Placeholder
...@@ -1409,7 +1394,8 @@ def test_forward_placeholder(): ...@@ -1409,7 +1394,8 @@ def test_forward_placeholder():
graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0') tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'Placeholder') tvm_output = run_tvm_graph(graph_def, data, 'Placeholder')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]),
rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# PTB # PTB
...@@ -1438,8 +1424,10 @@ def test_forward_ptb(): ...@@ -1438,8 +1424,10 @@ def test_forward_ptb():
def _get_tvm_graph_module(graph_def): def _get_tvm_graph_module(graph_def):
#Cell inputs 'c and 'h' consist of all layers values #Cell inputs 'c and 'h' consist of all layers values
shape_dict = {'Model/Placeholder': (batch_size, num_steps), shape_dict = {'Model/Placeholder': (batch_size, num_steps),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':(num_layers, batch_size, num_hidden), 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':(num_layers, batch_size, num_hidden)} (num_layers, batch_size, num_hidden),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':
(num_layers, batch_size, num_hidden)}
mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict) mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
...@@ -1468,15 +1456,15 @@ def test_forward_ptb(): ...@@ -1468,15 +1456,15 @@ def test_forward_ptb():
model.set_input('Model/Placeholder', tvm.nd.array(input_data.astype("int32"))) model.set_input('Model/Placeholder', tvm.nd.array(input_data.astype("int32")))
model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c', model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c',
tvm.nd.array(in_state_c.astype("float32"))) tvm.nd.array(in_state_c.astype("float32")))
model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h', model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h',
tvm.nd.array(in_state_h.astype("float32"))) tvm.nd.array(in_state_h.astype("float32")))
model.set_input(**params) model.set_input(**params)
model.run() model.run()
tvm_output = model.get_output(0, tvm.nd.empty(out_sample_shape, tvm_output = model.get_output(0, tvm.nd.empty(out_sample_shape,
"float32")).asnumpy() "float32")).asnumpy()
state_output = model.get_output(1, tvm.nd.empty(out_state_shape, state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
"float32")).asnumpy() "float32")).asnumpy()
sample = tf_testing.pick_from_weight(tvm_output[0]) sample = tf_testing.pick_from_weight(tvm_output[0])
return sample, state_output return sample, state_output
...@@ -1516,13 +1504,14 @@ def test_forward_ptb(): ...@@ -1516,13 +1504,14 @@ def test_forward_ptb():
for word in seed_for_sample], for word in seed_for_sample],
in_state, params, cnt_sample) in_state, params, cnt_sample)
tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word) tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
tf_samples, tf_state = tf_testing.do_tf_sample(sess, tf_samples, tf_state = tf_testing.do_tf_sample(
[word_to_id[word] for word in seed_for_sample], sess,
in_state, cnt_sample) [word_to_id[word] for word in seed_for_sample],
in_state, cnt_sample)
tf_sample_str = _pretty_print(tf_samples, False, id_to_word) tf_sample_str = _pretty_print(tf_samples, False, id_to_word)
inpt = tvm_sample_str inpt = tvm_sample_str
tvm.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
assert(tvm_sample_str == tf_sample_str) assert tvm_sample_str == tf_sample_str
####################################################################### #######################################################################
# LRN (Local Response Normalization) # LRN (Local Response Normalization)
...@@ -1975,7 +1964,8 @@ def test_placeholder(): ...@@ -1975,7 +1964,8 @@ def test_placeholder():
out1 = tf.math.add(var1, var2, name='out1') out1 = tf.math.add(var1, var2, name='out1')
out2 = tf.math.add(out1, place1, name='out2') out2 = tf.math.add(out1, place1, name='out2')
compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0', init_global_variables=True) compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0',
init_global_variables=True)
####################################################################### #######################################################################
...@@ -1996,7 +1986,6 @@ if __name__ == '__main__': ...@@ -1996,7 +1986,6 @@ if __name__ == '__main__':
test_forward_pad() test_forward_pad()
test_forward_unpack() test_forward_unpack()
test_forward_gather() test_forward_gather()
test_forward_gather_v1()
test_forward_gather_nd() test_forward_gather_nd()
test_forward_stridedslice() test_forward_stridedslice()
test_forward_split() test_forward_split()
...@@ -2056,8 +2045,7 @@ if __name__ == '__main__': ...@@ -2056,8 +2045,7 @@ if __name__ == '__main__':
# NN # NN
test_forward_convolution() test_forward_convolution()
test_forward_pooling() test_forward_pooling()
if tf.__version__ == '1.4.1': test_forward_concat_v2()
_test_forward_concat_v2()
test_forward_lrn() test_forward_lrn()
test_forward_l2_normalize() test_forward_l2_normalize()
test_forward_space_to_batch_nd() test_forward_space_to_batch_nd()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment