Commit cd717dea by Kim Committed by Tianqi Chen

[ Relay ][ Frontend ][ Tensorflow ]add op add_n to relay/frontend/tensorflow.py (#4181)

parent bafc675c
...@@ -115,6 +115,7 @@ Supported Ops ...@@ -115,6 +115,7 @@ Supported Ops
- Abs - Abs
- Add - Add
- AddN
- All - All
- Any - Any
- ArgMax - ArgMax
......
...@@ -1318,6 +1318,18 @@ def _size(): ...@@ -1318,6 +1318,18 @@ def _size():
return AttrCvt('ndarray_size', transforms={'out_type' : 'dtype'})(inputs, new_attr) return AttrCvt('ndarray_size', transforms={'out_type' : 'dtype'})(inputs, new_attr)
return _impl return _impl
def _add_n():
def _impl(inputs, attr, params):
if not isinstance(inputs, tuple):
inputs = list(inputs)
assert len(inputs) > 0, "add_n take >=1 inputs, but 0 given."
_res = inputs[0]
for each in inputs[1:]:
_res = _op.add(_res, each)
return _res
return _impl
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -1329,6 +1341,7 @@ _identity_list = [] ...@@ -1329,6 +1341,7 @@ _identity_list = []
_convert_map = { _convert_map = {
'Abs' : AttrCvt('abs'), 'Abs' : AttrCvt('abs'),
'Add' : _elemwise('add'), 'Add' : _elemwise('add'),
'AddN' : _add_n(),
'All' : _reduce('all'), 'All' : _reduce('all'),
'Any' : _reduce('any'), 'Any' : _reduce('any'),
'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMax' : _argx(_op.argmax, 'argmax'),
......
...@@ -41,11 +41,14 @@ import tvm.relay.testing.tf as tf_testing ...@@ -41,11 +41,14 @@ import tvm.relay.testing.tf as tf_testing
####################################################################### #######################################################################
# Generic run functions for TVM & tensorflow # Generic run functions for TVM & tensorflow
# ------------------------------------------ # ------------------------------------------
def convert_to_list(x): def convert_to_list(x):
if not isinstance(x, list): if not isinstance(x, list):
x = [x] x = [x]
return x return x
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.Tensor): if isinstance(o, tvm.relay.backend.vmobj.Tensor):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
...@@ -72,12 +75,14 @@ def vmobj_to_list(o): ...@@ -72,12 +75,14 @@ def vmobj_to_list(o):
elif 'tensor' in o.constructor.name_hint: elif 'tensor' in o.constructor.name_hint:
return [o.fields[0].asnumpy()] return [o.fields[0].asnumpy()]
else: else:
raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint) raise RuntimeError("Unknown object type: %s" %
o.constructor.name_hint)
elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy()] return [o.data.asnumpy()]
else: else:
raise RuntimeError("Unknown object type: %s" % type(o)) raise RuntimeError("Unknown object type: %s" % type(o))
def run_tvm_graph(graph_def, input_data, input_node, num_output=1, def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
target='llvm', out_names=None, opt_level=3, mode='graph_runtime'): target='llvm', out_names=None, opt_level=3, mode='graph_runtime'):
""" Generic function to compile on relay and execute on tvm """ """ Generic function to compile on relay and execute on tvm """
...@@ -116,16 +121,19 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, ...@@ -116,16 +121,19 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
# get outputs # get outputs
assert out_names is None or num_output == len(out_names), ( assert out_names is None or num_output == len(out_names), (
"out_names: {} num_output: {}".format(out_names, num_output)) "out_names: {} num_output: {}".format(out_names, num_output))
tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)] tvm_output_list = [m.get_output(i).asnumpy()
for i in range(num_output)]
return tvm_output_list return tvm_output_list
def run_tf_graph(sess, input_data, input_node, output_node): def run_tf_graph(sess, input_data, input_node, output_node):
""" Generic function to execute tensorflow """ """ Generic function to execute tensorflow """
input_data = convert_to_list(input_data) input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node) input_node = convert_to_list(input_node)
output_node = convert_to_list(output_node) output_node = convert_to_list(output_node)
tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node] tensor = [sess.graph.get_tensor_by_name(
output_name) for output_name in output_node]
input_dict = {e: input_data[i] for i, e in enumerate(input_node)} input_dict = {e: input_data[i] for i, e in enumerate(input_node)}
...@@ -152,7 +160,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, ...@@ -152,7 +160,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
sess, sess,
sess.graph.as_graph_def(add_shapes=True), sess.graph.as_graph_def(add_shapes=True),
out_node, out_node,
) )
tf_output = run_tf_graph(sess, in_data, in_name, out_name) tf_output = run_tf_graph(sess, in_data, in_name, out_name)
for device in ["llvm", "cuda"]: for device in ["llvm", "cuda"]:
...@@ -169,10 +177,12 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, ...@@ -169,10 +177,12 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
# since the names from tensorflow and relay runs are not exactly same, # since the names from tensorflow and relay runs are not exactly same,
# first len(tf_output) will be compared # first len(tf_output) will be compared
for i in range(len(tf_output)): for i in range(len(tf_output)):
tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) tvm.testing.assert_allclose(
tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
sess.close() sess.close()
def is_gpu_available(): def is_gpu_available():
from tensorflow.python.client import device_lib from tensorflow.python.client import device_lib
local_device_protos = device_lib.list_local_devices() local_device_protos = device_lib.list_local_devices()
...@@ -186,6 +196,8 @@ def is_gpu_available(): ...@@ -186,6 +196,8 @@ def is_gpu_available():
####################################################################### #######################################################################
# Pooling # Pooling
# ------- # -------
def _test_pooling_iteration(input_shape, **kwargs): def _test_pooling_iteration(input_shape, **kwargs):
""" One iteration of pool operation with given shapes and attributes """ """ One iteration of pool operation with given shapes and attributes """
...@@ -203,6 +215,7 @@ def _test_pooling_iteration(input_shape, **kwargs): ...@@ -203,6 +215,7 @@ def _test_pooling_iteration(input_shape, **kwargs):
compare_tf_with_tvm(x, 'Placeholder:0', out_name) compare_tf_with_tvm(x, 'Placeholder:0', out_name)
def _test_pooling(input_shape, **kwargs): def _test_pooling(input_shape, **kwargs):
_test_pooling_iteration(input_shape, **kwargs) _test_pooling_iteration(input_shape, **kwargs)
...@@ -211,6 +224,7 @@ def _test_pooling(input_shape, **kwargs): ...@@ -211,6 +224,7 @@ def _test_pooling(input_shape, **kwargs):
kwargs['data_format'] = 'NCHW' kwargs['data_format'] = 'NCHW'
_test_pooling_iteration(input_shape, **kwargs) _test_pooling_iteration(input_shape, **kwargs)
def test_forward_pooling(): def test_forward_pooling():
""" Pooling """ """ Pooling """
...@@ -260,6 +274,7 @@ def test_forward_pooling(): ...@@ -260,6 +274,7 @@ def test_forward_pooling():
# Convolution # Convolution
# ----------- # -----------
def _test_convolution(opname, tensor_in_sizes, filter_in_sizes, def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
dilations, strides, padding, data_format): dilations, strides, padding, data_format):
""" One iteration of convolution with given shapes and attributes """ """ One iteration of convolution with given shapes and attributes """
...@@ -273,7 +288,8 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes, ...@@ -273,7 +288,8 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32') in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32') in_filter = constant_op.constant(
filter_array, shape=filter_in_sizes, dtype='float32')
if data_format == 'NHWC': if data_format == 'NHWC':
strides = [1] + strides + [1] strides = [1] + strides + [1]
dilations = [1] + dilations + [1] dilations = [1] + dilations + [1]
...@@ -293,15 +309,16 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes, ...@@ -293,15 +309,16 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
'Placeholder:0', 'Conv2D:0') 'Placeholder:0', 'Conv2D:0')
else: else:
nn_ops.depthwise_conv2d_native(in_data, nn_ops.depthwise_conv2d_native(in_data,
in_filter, in_filter,
strides=strides, strides=strides,
dilations=dilations, dilations=dilations,
padding=padding, padding=padding,
data_format=data_format) data_format=data_format)
compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'), compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
'Placeholder:0', 'DepthwiseConv2dNative:0') 'Placeholder:0', 'DepthwiseConv2dNative:0')
def test_forward_convolution(): def test_forward_convolution():
if is_gpu_available(): if is_gpu_available():
_test_convolution('conv', [4, 176, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NCHW') _test_convolution('conv', [4, 176, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NCHW')
...@@ -327,13 +344,16 @@ def test_forward_convolution(): ...@@ -327,13 +344,16 @@ def test_forward_convolution():
####################################################################### #######################################################################
# BiasAdd # BiasAdd
# ----------- # -----------
def _test_biasadd(tensor_in_sizes, data_format): def _test_biasadd(tensor_in_sizes, data_format):
""" One iteration of biasadd with given shapes and attributes """ """ One iteration of biasadd with given shapes and attributes """
total_size_1 = 1 total_size_1 = 1
for s in tensor_in_sizes: for s in tensor_in_sizes:
total_size_1 *= s total_size_1 *= s
tensor_bias_sizes = [tensor_in_sizes[1]] if data_format == 'NCHW' else [tensor_in_sizes[3]] tensor_bias_sizes = [tensor_in_sizes[1]
] if data_format == 'NCHW' else [tensor_in_sizes[3]]
total_size_2 = tensor_bias_sizes[0] total_size_2 = tensor_bias_sizes[0]
# Initializes the input tensor with array containing incrementing # Initializes the input tensor with array containing incrementing
# numbers from 1. # numbers from 1.
...@@ -342,7 +362,8 @@ def _test_biasadd(tensor_in_sizes, data_format): ...@@ -342,7 +362,8 @@ def _test_biasadd(tensor_in_sizes, data_format):
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32') in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
in_bias = constant_op.constant(bias_array, shape=tensor_bias_sizes, dtype='float32') in_bias = constant_op.constant(
bias_array, shape=tensor_bias_sizes, dtype='float32')
nn_ops.bias_add(in_data, nn_ops.bias_add(in_data,
in_bias, in_bias,
data_format=data_format) data_format=data_format)
...@@ -350,6 +371,7 @@ def _test_biasadd(tensor_in_sizes, data_format): ...@@ -350,6 +371,7 @@ def _test_biasadd(tensor_in_sizes, data_format):
compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'), compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
'Placeholder:0', 'BiasAdd:0') 'Placeholder:0', 'BiasAdd:0')
def test_forward_biasadd(): def test_forward_biasadd():
if is_gpu_available(): if is_gpu_available():
_test_biasadd([4, 176, 8, 8], 'NCHW') _test_biasadd([4, 176, 8, 8], 'NCHW')
...@@ -362,15 +384,17 @@ def test_forward_biasadd(): ...@@ -362,15 +384,17 @@ def test_forward_biasadd():
_test_biasadd([4, 17, 17, 19], 'NHWC') _test_biasadd([4, 17, 17, 19], 'NHWC')
_test_biasadd([4, 3, 3, 124], 'NHWC') _test_biasadd([4, 3, 3, 124], 'NHWC')
def _test_forward_where(input_shape): def _test_forward_where(input_shape):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = tf.float32 dtype = tf.float32
t = tf.constant(np.random.choice([0, 1, -2, 3, -1, 0.1, -0.2], t = tf.constant(np.random.choice([0, 1, -2, 3, -1, 0.1, -0.2],
size=input_shape).astype(dtype.name)) size=input_shape).astype(dtype.name))
out = tf.where(t) out = tf.where(t)
compare_tf_with_tvm([], [], out.name, mode='debug') compare_tf_with_tvm([], [], out.name, mode='debug')
compare_tf_with_tvm([], [], out.name, mode='vm') compare_tf_with_tvm([], [], out.name, mode='vm')
def test_forward_argwhere(): def test_forward_argwhere():
_test_forward_where((5,)) _test_forward_where((5,))
_test_forward_where((5, 5)) _test_forward_where((5, 5))
...@@ -381,6 +405,8 @@ def test_forward_argwhere(): ...@@ -381,6 +405,8 @@ def test_forward_argwhere():
####################################################################### #######################################################################
# SpaceToBatchND # SpaceToBatchND
# -------------- # --------------
def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'): def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):
data = np.random.uniform(0, 5, size=input_shape).astype(dtype) data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
...@@ -390,6 +416,7 @@ def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'): ...@@ -390,6 +416,7 @@ def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):
compare_tf_with_tvm(data, in_data.name, out.name) compare_tf_with_tvm(data, in_data.name, out.name)
def test_forward_space_to_batch_nd(): def test_forward_space_to_batch_nd():
# test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d
_test_space_to_batch_nd( _test_space_to_batch_nd(
...@@ -436,6 +463,8 @@ def test_forward_space_to_batch_nd(): ...@@ -436,6 +463,8 @@ def test_forward_space_to_batch_nd():
####################################################################### #######################################################################
# BatchToSpaceND # BatchToSpaceND
# -------------- # --------------
def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'): def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'):
data = np.random.uniform(0, 5, size=input_shape).astype(dtype) data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
...@@ -445,6 +474,7 @@ def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'): ...@@ -445,6 +474,7 @@ def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'):
compare_tf_with_tvm(data, in_data.name, out.name) compare_tf_with_tvm(data, in_data.name, out.name)
def test_forward_batch_to_space_nd(): def test_forward_batch_to_space_nd():
# test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d
_test_batch_to_space_nd( _test_batch_to_space_nd(
...@@ -492,6 +522,7 @@ def test_forward_batch_to_space_nd(): ...@@ -492,6 +522,7 @@ def test_forward_batch_to_space_nd():
# Reshape # Reshape
# ------- # -------
def _test_reshape(data, out_shape): def _test_reshape(data, out_shape):
""" One iteration of reshape operation with given data and out shape """ """ One iteration of reshape operation with given data and out shape """
...@@ -501,6 +532,7 @@ def _test_reshape(data, out_shape): ...@@ -501,6 +532,7 @@ def _test_reshape(data, out_shape):
compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0') compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
def test_forward_reshape(): def test_forward_reshape():
_test_reshape(np.arange(6.0), [2, 3]) _test_reshape(np.arange(6.0), [2, 3])
_test_reshape(np.arange(6), [-1, 2]) _test_reshape(np.arange(6), [-1, 2])
...@@ -511,6 +543,7 @@ def test_forward_reshape(): ...@@ -511,6 +543,7 @@ def test_forward_reshape():
# DepthToSpace # DepthToSpace
# ------------ # ------------
def _test_depthtospace(data, block_size): def _test_depthtospace(data, block_size):
""" One iteration of depth_to_space operation with given data and block size """ """ One iteration of depth_to_space operation with given data and block size """
...@@ -520,6 +553,7 @@ def _test_depthtospace(data, block_size): ...@@ -520,6 +553,7 @@ def _test_depthtospace(data, block_size):
compare_tf_with_tvm(data, 'Placeholder:0', 'DepthToSpace:0') compare_tf_with_tvm(data, 'Placeholder:0', 'DepthToSpace:0')
def test_forward_depthtospace(): def test_forward_depthtospace():
_test_depthtospace(np.random.normal(size=[1, 32, 32, 4]), 2) _test_depthtospace(np.random.normal(size=[1, 32, 32, 4]), 2)
_test_depthtospace(np.random.normal(size=[1, 16, 8, 32]), 4) _test_depthtospace(np.random.normal(size=[1, 16, 8, 32]), 4)
...@@ -528,6 +562,7 @@ def test_forward_depthtospace(): ...@@ -528,6 +562,7 @@ def test_forward_depthtospace():
# SpaceToDepth # SpaceToDepth
# ------------ # ------------
def _test_spacetodepth(data, block_size): def _test_spacetodepth(data, block_size):
""" One iteration of space_to_depth operation with given data and block size """ """ One iteration of space_to_depth operation with given data and block size """
...@@ -537,6 +572,7 @@ def _test_spacetodepth(data, block_size): ...@@ -537,6 +572,7 @@ def _test_spacetodepth(data, block_size):
compare_tf_with_tvm(data, 'Placeholder:0', 'SpaceToDepth:0') compare_tf_with_tvm(data, 'Placeholder:0', 'SpaceToDepth:0')
def test_forward_spacetodepth(): def test_forward_spacetodepth():
_test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]), 2) _test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]), 2)
_test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]), 4) _test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]), 4)
...@@ -545,6 +581,7 @@ def test_forward_spacetodepth(): ...@@ -545,6 +581,7 @@ def test_forward_spacetodepth():
# Squeeze # Squeeze
# ------- # -------
def _test_squeeze(data, squeeze_dims=None): def _test_squeeze(data, squeeze_dims=None):
""" One iteration of squeeze """ """ One iteration of squeeze """
...@@ -561,6 +598,7 @@ def _test_squeeze(data, squeeze_dims=None): ...@@ -561,6 +598,7 @@ def _test_squeeze(data, squeeze_dims=None):
compare_tf_with_tvm(data, 'Placeholder:0', 'Squeeze:0') compare_tf_with_tvm(data, 'Placeholder:0', 'Squeeze:0')
def test_forward_squeeze(): def test_forward_squeeze():
""" Squeeze """ """ Squeeze """
...@@ -584,16 +622,20 @@ def test_forward_squeeze(): ...@@ -584,16 +622,20 @@ def test_forward_squeeze():
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5]) _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1]) _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])
def test_tensor_array_constructor(): def test_tensor_array_constructor():
def run(dtype_str): def run(dtype_str):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = { dtype = {
'float32': tf.float32, 'float32': tf.float32,
'int32' : tf.int32 'int32': tf.int32
}[dtype_str] }[dtype_str]
t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype) t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(
t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype) dtype_str), dtype=dtype)
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False) t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(
dtype_str), dtype=dtype)
ta1 = tf.TensorArray(dtype=dtype, size=2,
infer_shape=False, dynamic_size=False)
ta2 = ta1.write(0, t) ta2 = ta1.write(0, t)
ta3 = ta2.write(1, t2) ta3 = ta2.write(1, t2)
out = ta3.read(0) out = ta3.read(0)
...@@ -602,24 +644,29 @@ def test_tensor_array_constructor(): ...@@ -602,24 +644,29 @@ def test_tensor_array_constructor():
run('float32') run('float32')
run('int32') run('int32')
def test_tensor_array_scatter(): def test_tensor_array_scatter():
def run(dtype_str): def run(dtype_str):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = { dtype = {
'float32': tf.float32, 'float32': tf.float32,
'int32' : tf.int32 'int32': tf.int32
}[dtype_str] }[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype) t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(
dtype_str), dtype=dtype)
indices = tf.constant([2, 1, 0]) indices = tf.constant([2, 1, 0])
ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=False, dynamic_size=False) ta1 = tf.TensorArray(dtype=dtype, size=3,
infer_shape=False, dynamic_size=False)
ta2 = ta1.scatter(indices, t) ta2 = ta1.scatter(indices, t)
out0 = ta2.read(0) out0 = ta2.read(0)
out1 = ta2.read(1) out1 = ta2.read(1)
out2 = ta2.read(2) out2 = ta2.read(2)
g = tf.get_default_graph() g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug') compare_tf_with_tvm(
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug') [], [], ['TensorArrayReadV3_1:0'], mode='debug')
compare_tf_with_tvm(
[], [], ['TensorArrayReadV3_2:0'], mode='debug')
run('float32') run('float32')
run('int32') run('int32')
...@@ -636,16 +683,19 @@ def test_tensor_array_scatter(): ...@@ -636,16 +683,19 @@ def test_tensor_array_scatter():
# g = tf.get_default_graph() # g = tf.get_default_graph()
# compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug') # compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')
def test_tensor_array_split(): def test_tensor_array_split():
def run(dtype_str): def run(dtype_str):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = { dtype = {
'float32': tf.float32, 'float32': tf.float32,
'int32' : tf.int32 'int32': tf.int32
}[dtype_str] }[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype) t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [
6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False) ta1 = tf.TensorArray(dtype=dtype, size=4,
infer_shape=False, dynamic_size=False)
ta2 = ta1.split(t, split_length) ta2 = ta1.split(t, split_length)
out0 = ta2.read(0) out0 = ta2.read(0)
out1 = ta2.read(1) out1 = ta2.read(1)
...@@ -653,36 +703,45 @@ def test_tensor_array_split(): ...@@ -653,36 +703,45 @@ def test_tensor_array_split():
out3 = ta2.read(3) out3 = ta2.read(3)
g = tf.get_default_graph() g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug') compare_tf_with_tvm(
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug') [], [], ['TensorArrayReadV3_1:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug') compare_tf_with_tvm(
[], [], ['TensorArrayReadV3_2:0'], mode='debug')
compare_tf_with_tvm(
[], [], ['TensorArrayReadV3_3:0'], mode='debug')
run('float32') run('float32')
run('int32') run('int32')
def test_tensor_array_concat(): def test_tensor_array_concat():
def run(dtype_str): def run(dtype_str):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = { dtype = {
'float32': tf.float32, 'float32': tf.float32,
'int32' : tf.int32 'int32': tf.int32
}[dtype_str] }[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype) t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [
6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False) ta1 = tf.TensorArray(dtype=dtype, size=4,
infer_shape=False, dynamic_size=False)
ta2 = ta1.split(t, split_length) ta2 = ta1.split(t, split_length)
t = ta2.concat() t = ta2.concat()
compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug') compare_tf_with_tvm(
[], [], ['TensorArrayConcatV3:0'], mode='debug')
run('float32') run('float32')
run('int32') run('int32')
def test_tensor_array_size(): def test_tensor_array_size():
def run(dtype_str): def run(dtype_str):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = { dtype = {
'float32': tf.float32, 'float32': tf.float32,
'int32' : tf.int32 'int32': tf.int32
}[dtype_str] }[dtype_str]
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False) ta1 = tf.TensorArray(dtype=dtype, size=2,
infer_shape=False, dynamic_size=False)
out = ta1.size() out = ta1.size()
g = tf.get_default_graph() g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug') compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
...@@ -693,6 +752,7 @@ def test_tensor_array_size(): ...@@ -693,6 +752,7 @@ def test_tensor_array_size():
# ConcatV2 # ConcatV2
# -------- # --------
def _test_concat_v2(shape1, shape2, dim): def _test_concat_v2(shape1, shape2, dim):
""" One iteration of ConcatV2 """ """ One iteration of ConcatV2 """
...@@ -705,7 +765,9 @@ def _test_concat_v2(shape1, shape2, dim): ...@@ -705,7 +765,9 @@ def _test_concat_v2(shape1, shape2, dim):
np_data1 = np.random.uniform(size=shape1).astype(dtype) np_data1 = np.random.uniform(size=shape1).astype(dtype)
np_data2 = np.random.uniform(size=shape2).astype(dtype) np_data2 = np.random.uniform(size=shape2).astype(dtype)
compare_tf_with_tvm([np_data1, np_data2], ['in1:0', 'in2:0'], 'ConcatV2:0') compare_tf_with_tvm([np_data1, np_data2], [
'in1:0', 'in2:0'], 'ConcatV2:0')
def test_forward_concat_v2(): def test_forward_concat_v2():
if tf.__version__ < LooseVersion('1.4.1'): if tf.__version__ < LooseVersion('1.4.1'):
...@@ -721,6 +783,7 @@ def test_forward_concat_v2(): ...@@ -721,6 +783,7 @@ def test_forward_concat_v2():
# Sigmoid # Sigmoid
# ------- # -------
def _test_sigmoid(data): def _test_sigmoid(data):
""" One iteration of sigmoid """ """ One iteration of sigmoid """
...@@ -730,6 +793,7 @@ def _test_sigmoid(data): ...@@ -730,6 +793,7 @@ def _test_sigmoid(data):
compare_tf_with_tvm(data, 'Placeholder:0', 'Sigmoid:0') compare_tf_with_tvm(data, 'Placeholder:0', 'Sigmoid:0')
def test_forward_sigmoid(): def test_forward_sigmoid():
""" Sigmoid """ """ Sigmoid """
...@@ -739,14 +803,17 @@ def test_forward_sigmoid(): ...@@ -739,14 +803,17 @@ def test_forward_sigmoid():
# Argmin/Argmax # Argmin/Argmax
# ------------- # -------------
def _test_argx(func, data, **kwargs): def _test_argx(func, data, **kwargs):
with tf.Graph().as_default(): with tf.Graph().as_default():
inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0") inp = array_ops.placeholder(
shape=data.shape, dtype=data.dtype, name="c0")
func(inp, name="argx0", output_type=tf.int32, **kwargs) func(inp, name="argx0", output_type=tf.int32, **kwargs)
compare_tf_with_tvm(data, 'c0:0', 'argx0:0') compare_tf_with_tvm(data, 'c0:0', 'argx0:0')
def test_forward_argminmax(): def test_forward_argminmax():
for axis in [None, 0, 1, 2]: for axis in [None, 0, 1, 2]:
data = np.random.uniform(size=(8, 4, 9)).astype('float32') data = np.random.uniform(size=(8, 4, 9)).astype('float32')
...@@ -757,15 +824,18 @@ def test_forward_argminmax(): ...@@ -757,15 +824,18 @@ def test_forward_argminmax():
# Reduce # Reduce
# ------ # ------
def _test_reduce(func, data, **kwargs): def _test_reduce(func, data, **kwargs):
""" One iteration of a reduce operation""" """ One iteration of a reduce operation"""
with tf.Graph().as_default(): with tf.Graph().as_default():
inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0") inp = array_ops.placeholder(
shape=data.shape, dtype=data.dtype, name="c0")
func(inp, name="reducex0", **kwargs) func(inp, name="reducex0", **kwargs)
compare_tf_with_tvm(data, 'c0:0', 'reducex0:0') compare_tf_with_tvm(data, 'c0:0', 'reducex0:0')
def test_forward_reduce(): def test_forward_reduce():
data = np.random.uniform(size=(8, 4, 9)).astype('float32') data = np.random.uniform(size=(8, 4, 9)).astype('float32')
_test_reduce(tf.reduce_sum, data=data) _test_reduce(tf.reduce_sum, data=data)
...@@ -790,7 +860,9 @@ def _test_variable(data): ...@@ -790,7 +860,9 @@ def _test_variable(data):
"w", shape=[size, size], dtype=input_tensor.dtype) "w", shape=[size, size], dtype=input_tensor.dtype)
math_ops.matmul(input_tensor, w) math_ops.matmul(input_tensor, w)
compare_tf_with_tvm(data, 'Placeholder:0', 'MatMul:0', init_global_variables=True) compare_tf_with_tvm(data, 'Placeholder:0', 'MatMul:0',
init_global_variables=True)
def test_forward_variable(): def test_forward_variable():
"""Variable type op test""" """Variable type op test"""
...@@ -810,23 +882,29 @@ def _test_matmul(i, j, k, dtype, outer=None): ...@@ -810,23 +882,29 @@ def _test_matmul(i, j, k, dtype, outer=None):
for transpose_a in [False, True]: for transpose_a in [False, True]:
for transpose_b in [False, True]: for transpose_b in [False, True]:
outer = outer or [] outer = outer or []
A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init) A_shape = outer + \
B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init) (A_shape_init[::-1] if transpose_a else A_shape_init)
B_shape = outer + \
(B_shape_init[::-1] if transpose_b else B_shape_init)
with tf.Graph().as_default(): with tf.Graph().as_default():
A = tf.placeholder(shape=A_shape, dtype=dtype, name='A') A = tf.placeholder(shape=A_shape, dtype=dtype, name='A')
B = tf.placeholder(shape=B_shape, dtype=dtype, name='B') B = tf.placeholder(shape=B_shape, dtype=dtype, name='B')
result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b) result = tf.matmul(
A, B, transpose_a=transpose_a, transpose_b=transpose_b)
A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name) compare_tf_with_tvm(
[A_np, B_np], [A.name, B.name], result.name)
def test_forward_matmul(): def test_forward_matmul():
""" MatMul op test""" """ MatMul op test"""
_test_matmul(1, 3, 6, 'int32') _test_matmul(1, 3, 6, 'int32')
_test_matmul(5, 3, 1, 'float64') _test_matmul(5, 3, 1, 'float64')
def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False): def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False):
with tf.Graph().as_default(): with tf.Graph().as_default():
...@@ -839,6 +917,7 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False ...@@ -839,6 +917,7 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False
B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name) compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
def test_forward_batch_matmul(): def test_forward_batch_matmul():
""" TF op BatchMatMul, BatchMatMulV2 test""" """ TF op BatchMatMul, BatchMatMulV2 test"""
_test_batch_matmul((3, 5, 4), (3, 4, 5), 'int32') _test_batch_matmul((3, 5, 4), (3, 4, 5), 'int32')
...@@ -846,9 +925,11 @@ def test_forward_batch_matmul(): ...@@ -846,9 +925,11 @@ def test_forward_batch_matmul():
_test_batch_matmul((3, 5, 4), (3, 5, 4), 'int32', True, False) _test_batch_matmul((3, 5, 4), (3, 5, 4), 'int32', True, False)
_test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True) _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True)
_test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'int32') _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'int32')
_test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), 'float32', True, True) _test_batch_matmul((1, 2, 3, 4, 5, 6),
(1, 2, 3, 4, 6, 5), 'float32', True, True)
_test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), 'int32', True, False) _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), 'int32', True, False)
_test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), 'float32', False, True) _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6),
(2, 3, 4, 2, 3, 4, 5, 6), 'float32', False, True)
####################################################################### #######################################################################
...@@ -870,16 +951,23 @@ def _test_stridedslice(ip_shape, begin, end, stride, dtype, ...@@ -870,16 +951,23 @@ def _test_stridedslice(ip_shape, begin, end, stride, dtype,
compare_tf_with_tvm(np_data, 'in_data:0', 'strided_slice:0') compare_tf_with_tvm(np_data, 'in_data:0', 'strided_slice:0')
def test_forward_stridedslice(): def test_forward_stridedslice():
'''test StridedSlice''' '''test StridedSlice'''
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1) _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32') _test_stridedslice((3, 4, 3), [1, -1, 0],
_test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8) [4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2) _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [
_test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2) 2, 1], 'float32', ellipsis_mask=8)
_test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [2, 1, 1], 'float32', ellipsis_mask=2) _test_stridedslice((3, 4, 3), [1, 0], [4, 2], [
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5) 2, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [
2, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [
2, 1, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [
2, 1, 1], 'float32', new_axis_mask=5)
_test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, _test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2,
new_axis_mask=4) new_axis_mask=4)
_test_stridedslice((6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], 'float32', ellipsis_mask=2, _test_stridedslice((6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], 'float32', ellipsis_mask=2,
...@@ -892,7 +980,8 @@ def test_forward_stridedslice(): ...@@ -892,7 +980,8 @@ def test_forward_stridedslice():
new_axis_mask=3) new_axis_mask=3)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2,
new_axis_mask=2) new_axis_mask=2)
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2) _test_stridedslice((3, 4), [1, 0], [4, 4], [
1, 1], 'float32', shrink_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2,
new_axis_mask=2) new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=1, _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=1,
...@@ -918,6 +1007,7 @@ def test_forward_stridedslice(): ...@@ -918,6 +1007,7 @@ def test_forward_stridedslice():
# FloorDiv, RealDiv # FloorDiv, RealDiv
# ----------------- # -----------------
def _test_forward_divide(ip_shape, dtype): def _test_forward_divide(ip_shape, dtype):
np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype) np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
np_denomin = np.random.uniform(1, 100, size=ip_shape).astype(dtype) np_denomin = np.random.uniform(1, 100, size=ip_shape).astype(dtype)
...@@ -925,7 +1015,9 @@ def _test_forward_divide(ip_shape, dtype): ...@@ -925,7 +1015,9 @@ def _test_forward_divide(ip_shape, dtype):
numerator = tf.placeholder(dtype, ip_shape, name="numer") numerator = tf.placeholder(dtype, ip_shape, name="numer")
denominator = tf.placeholder(dtype, ip_shape, name="denomin") denominator = tf.placeholder(dtype, ip_shape, name="denomin")
tf.math.divide(numerator, denominator, name='RealDiv') tf.math.divide(numerator, denominator, name='RealDiv')
compare_tf_with_tvm([np_numer, np_denomin], ['numer:0', 'denomin:0'], 'RealDiv:0') compare_tf_with_tvm([np_numer, np_denomin], [
'numer:0', 'denomin:0'], 'RealDiv:0')
def _test_forward_floordiv(ip_shape, dtype): def _test_forward_floordiv(ip_shape, dtype):
np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype) np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
...@@ -934,6 +1026,7 @@ def _test_forward_floordiv(ip_shape, dtype): ...@@ -934,6 +1026,7 @@ def _test_forward_floordiv(ip_shape, dtype):
tf.math.floordiv(numerator, tf.constant(5, dtype=dtype), name='FloorDiv') tf.math.floordiv(numerator, tf.constant(5, dtype=dtype), name='FloorDiv')
compare_tf_with_tvm([np_numer], ['numer:0'], 'FloorDiv:0') compare_tf_with_tvm([np_numer], ['numer:0'], 'FloorDiv:0')
def test_forward_divide(): def test_forward_divide():
'''test FloorDiv, RealDiv''' '''test FloorDiv, RealDiv'''
_test_forward_divide((4,), 'int32') _test_forward_divide((4,), 'int32')
...@@ -951,7 +1044,9 @@ def _test_forward_truncatemod(ip_shape, dtype): ...@@ -951,7 +1044,9 @@ def _test_forward_truncatemod(ip_shape, dtype):
in_data_1 = tf.placeholder(dtype, ip_shape, name="in_data_1") in_data_1 = tf.placeholder(dtype, ip_shape, name="in_data_1")
in_data_2 = tf.placeholder(dtype, ip_shape, name="in_data_2") in_data_2 = tf.placeholder(dtype, ip_shape, name="in_data_2")
tf.truncatemod(in_data_1, in_data_2, name='truncatemod') tf.truncatemod(in_data_1, in_data_2, name='truncatemod')
compare_tf_with_tvm([np_data_1, np_data_2], ['in_data_1:0', 'in_data_2:0'], 'truncatemod:0') compare_tf_with_tvm([np_data_1, np_data_2], [
'in_data_1:0', 'in_data_2:0'], 'truncatemod:0')
def test_forward_truncatemod(): def test_forward_truncatemod():
'''test TruncateMod''' '''test TruncateMod'''
...@@ -980,7 +1075,9 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype): ...@@ -980,7 +1075,9 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
return indices return indices
np_indices = _fill_indices(indice_value) np_indices = _fill_indices(indice_value)
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], out.name) compare_tf_with_tvm([np_data, np_indices], [
'in_data:0', 'indices:0'], out.name)
def test_forward_gather(): def test_forward_gather():
'''test Gather/GatherV2 layer''' '''test Gather/GatherV2 layer'''
...@@ -995,6 +1092,7 @@ def test_forward_gather(): ...@@ -995,6 +1092,7 @@ def test_forward_gather():
_test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32') _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32')
_test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32') _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32')
def test_forward_gather_nd(): def test_forward_gather_nd():
"""test operator GatherNd""" """test operator GatherNd"""
np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32)
...@@ -1016,7 +1114,8 @@ def test_forward_bias_add(): ...@@ -1016,7 +1114,8 @@ def test_forward_bias_add():
lft_data = tf.placeholder(dtype, name="lft_data") lft_data = tf.placeholder(dtype, name="lft_data")
rgt_data = tf.placeholder(dtype, name="rgt_data") rgt_data = tf.placeholder(dtype, name="rgt_data")
tf.nn.bias_add(lft_data, rgt_data, name="BiasAdd") tf.nn.bias_add(lft_data, rgt_data, name="BiasAdd")
compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'BiasAdd:0') compare_tf_with_tvm([lh_data, rh_data], [
'lft_data:0', 'rgt_data:0'], 'BiasAdd:0')
check_bias_add((10, 8, 16, 32), (32,), dtype="int32") check_bias_add((10, 8, 16, 32), (32,), dtype="int32")
check_bias_add((10, 20), (20,), dtype="float32") check_bias_add((10, 20), (20,), dtype="float32")
...@@ -1033,7 +1132,7 @@ def _test_split(in_shape, axis, num_or_size_splits, dtype): ...@@ -1033,7 +1132,7 @@ def _test_split(in_shape, axis, num_or_size_splits, dtype):
tf.reset_default_graph() tf.reset_default_graph()
in_data = tf.placeholder(dtype, in_shape, name="in_data") in_data = tf.placeholder(dtype, in_shape, name="in_data")
num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list)\ num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list)\
else num_or_size_splits else num_or_size_splits
split = tf.split(in_data, num_or_size_splits, axis=axis) split = tf.split(in_data, num_or_size_splits, axis=axis)
relu = [tf.nn.relu(i) for i in split] relu = [tf.nn.relu(i) for i in split]
...@@ -1047,6 +1146,7 @@ def _test_split(in_shape, axis, num_or_size_splits, dtype): ...@@ -1047,6 +1146,7 @@ def _test_split(in_shape, axis, num_or_size_splits, dtype):
compare_tf_with_tvm([np_data], 'in_data:0', 'concat:0') compare_tf_with_tvm([np_data], 'in_data:0', 'concat:0')
def test_forward_split(): def test_forward_split():
'''test split layer''' '''test split layer'''
# rank 1 # rank 1
...@@ -1086,6 +1186,7 @@ def _test_forward_top_k_v2(in_shape, k): ...@@ -1086,6 +1186,7 @@ def _test_forward_top_k_v2(in_shape, k):
tf.math.top_k(in_data, k, name='TopK') tf.math.top_k(in_data, k, name='TopK')
compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0')
def test_forward_top_k_v2(): def test_forward_top_k_v2():
_test_forward_top_k_v2((3,), 1) _test_forward_top_k_v2((3,), 1)
_test_forward_top_k_v2((3,), 3) _test_forward_top_k_v2((3,), 3)
...@@ -1112,6 +1213,7 @@ def _test_unstack(ip_shape, axis, dtype): ...@@ -1112,6 +1213,7 @@ def _test_unstack(ip_shape, axis, dtype):
compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')
def test_forward_unstack(): def test_forward_unstack():
'''test unstack layer''' '''test unstack layer'''
_test_unstack((6,), 0, 'int32') _test_unstack((6,), 0, 'int32')
...@@ -1132,6 +1234,7 @@ def _test_tile(in_shape, multiples, dtype): ...@@ -1132,6 +1234,7 @@ def _test_tile(in_shape, multiples, dtype):
tf.tile(in_data, multiples=multiples, name="tile") tf.tile(in_data, multiples=multiples, name="tile")
compare_tf_with_tvm([np_data], ['in_data:0'], 'tile:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'tile:0')
def test_forward_tile(): def test_forward_tile():
'''test Tile''' '''test Tile'''
_test_tile((2, ), (3, ), "int32") _test_tile((2, ), (3, ), "int32")
...@@ -1146,10 +1249,12 @@ def test_forward_tile(): ...@@ -1146,10 +1249,12 @@ def test_forward_tile():
def _test_forward_clip_by_value(ip_shape, clip_value_min, clip_value_max, dtype): def _test_forward_clip_by_value(ip_shape, clip_value_min, clip_value_max, dtype):
tf.reset_default_graph() tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data") in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.clip_by_value(in_data, clip_value_min, clip_value_max, name="ClipByValue") tf.clip_by_value(in_data, clip_value_min,
clip_value_max, name="ClipByValue")
np_data = np.random.uniform(-100, 100, size=ip_shape).astype(dtype) np_data = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
compare_tf_with_tvm([np_data], ['in_data:0'], 'ClipByValue:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'ClipByValue:0')
def test_forward_clip_by_value(): def test_forward_clip_by_value():
'''test ClipByValue op''' '''test ClipByValue op'''
if tf.__version__ < LooseVersion('1.9'): if tf.__version__ < LooseVersion('1.9'):
...@@ -1160,6 +1265,7 @@ def test_forward_clip_by_value(): ...@@ -1160,6 +1265,7 @@ def test_forward_clip_by_value():
# Multi Input to graph # Multi Input to graph
# -------------------- # --------------------
def test_forward_multi_input(): def test_forward_multi_input():
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1') in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
...@@ -1179,6 +1285,7 @@ def test_forward_multi_input(): ...@@ -1179,6 +1285,7 @@ def test_forward_multi_input():
# Multi Output to Graph # Multi Output to Graph
# --------------------- # ---------------------
def test_forward_multi_output(): def test_forward_multi_output():
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1') in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
...@@ -1202,12 +1309,14 @@ def test_forward_multi_output(): ...@@ -1202,12 +1309,14 @@ def test_forward_multi_output():
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target='llvm', tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target='llvm',
out_names=out_node, num_output=2) out_names=out_node, num_output=2)
for i in range(len(tf_output)): for i in range(len(tf_output)):
tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) tvm.testing.assert_allclose(
tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
####################################################################### #######################################################################
# Resize Bilinear, Nearest_Neighbor # Resize Bilinear, Nearest_Neighbor
# --------------------------------- # ---------------------------------
def _test_resize_bilinear(in_shape, to_shape, align_corners): def _test_resize_bilinear(in_shape, to_shape, align_corners):
""" One iteration of resize bilinear """ """ One iteration of resize bilinear """
...@@ -1218,10 +1327,12 @@ def _test_resize_bilinear(in_shape, to_shape, align_corners): ...@@ -1218,10 +1327,12 @@ def _test_resize_bilinear(in_shape, to_shape, align_corners):
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
shape_data = constant_op.constant( shape_data = constant_op.constant(
shape_data, shape=shape_data.shape, dtype=shape_data.dtype) shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners) tf.image.resize_bilinear(
in_data, shape_data, align_corners=align_corners)
compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0') compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
def _test_resize_bilinear_from_tensor(in_shape, align_corners): def _test_resize_bilinear_from_tensor(in_shape, align_corners):
""" One iteration of resize bilinear with non-constant output shape, requires """ One iteration of resize bilinear with non-constant output shape, requires
value inference to get proper output shape.""" value inference to get proper output shape."""
...@@ -1232,7 +1343,8 @@ def _test_resize_bilinear_from_tensor(in_shape, align_corners): ...@@ -1232,7 +1343,8 @@ def _test_resize_bilinear_from_tensor(in_shape, align_corners):
in_data = array_ops.placeholder( in_data = array_ops.placeholder(
shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype) shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype)
to_shape = tf.shape(in_data)[2:] to_shape = tf.shape(in_data)[2:]
tf.image.resize_bilinear(in_data, to_shape, align_corners=align_corners) tf.image.resize_bilinear(
in_data, to_shape, align_corners=align_corners)
compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0') compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
...@@ -1247,7 +1359,8 @@ def _test_resize_nearest_neighbor(in_shape, to_shape): ...@@ -1247,7 +1359,8 @@ def _test_resize_nearest_neighbor(in_shape, to_shape):
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
shape_data = constant_op.constant( shape_data = constant_op.constant(
shape_data, shape=shape_data.shape, dtype=shape_data.dtype) shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
tf.image.resize_nearest_neighbor(in_data, shape_data, name='resize_nearest_neighbor') tf.image.resize_nearest_neighbor(
in_data, shape_data, name='resize_nearest_neighbor')
compare_tf_with_tvm(data, 'Placeholder:0', 'resize_nearest_neighbor:0') compare_tf_with_tvm(data, 'Placeholder:0', 'resize_nearest_neighbor:0')
...@@ -1278,7 +1391,8 @@ def _test_broadcast_to(in_shape, to_shape): ...@@ -1278,7 +1391,8 @@ def _test_broadcast_to(in_shape, to_shape):
shape_data, shape=shape_data.shape, dtype=shape_data.dtype) shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
tf.broadcast_to(in_data, shape_data) tf.broadcast_to(in_data, shape_data)
compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0', opt_level=0) compare_tf_with_tvm(data, 'Placeholder:0',
'BroadcastTo:0', opt_level=0)
def _test_broadcast_to_from_tensor(in_shape): def _test_broadcast_to_from_tensor(in_shape):
...@@ -1315,6 +1429,7 @@ def _test_fill(in_shape): ...@@ -1315,6 +1429,7 @@ def _test_fill(in_shape):
tf.ones(shape=in_shape, dtype='float32') tf.ones(shape=in_shape, dtype='float32')
compare_tf_with_tvm(in_shape, [], 'ones:0', opt_level=1) compare_tf_with_tvm(in_shape, [], 'ones:0', opt_level=1)
def _test_fill_from_tensor(in_shape): def _test_fill_from_tensor(in_shape):
""" Use the fill op to create a tensor of ones with non-constant shape. """ Use the fill op to create a tensor of ones with non-constant shape.
Some extra ops need to be added here to prevent the graph from Some extra ops need to be added here to prevent the graph from
...@@ -1330,6 +1445,7 @@ def _test_fill_from_tensor(in_shape): ...@@ -1330,6 +1445,7 @@ def _test_fill_from_tensor(in_shape):
y = tf.math.add(in_data, tf.reduce_mean(x), name='out1') y = tf.math.add(in_data, tf.reduce_mean(x), name='out1')
compare_tf_with_tvm(data, 'Placeholder:0', 'out1:0') compare_tf_with_tvm(data, 'Placeholder:0', 'out1:0')
def test_forward_fill(): def test_forward_fill():
""" Resize Bilinear """ """ Resize Bilinear """
...@@ -1341,13 +1457,16 @@ def test_forward_fill(): ...@@ -1341,13 +1457,16 @@ def test_forward_fill():
# Crop to bounding box # Crop to bounding box
# -------------------- # --------------------
def _test_crop(in_shape, off_h, off_w, tar_h, tar_w): def _test_crop(in_shape, off_h, off_w, tar_h, tar_w):
""" Crop to bounding box """ """ Crop to bounding box """
data = np.random.uniform(size=in_shape).astype('float32') data = np.random.uniform(size=in_shape).astype('float32')
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
tf.image.crop_to_bounding_box(in_data, off_h, off_w, tar_h, tar_w) tf.image.crop_to_bounding_box(in_data, off_h, off_w, tar_h, tar_w)
compare_tf_with_tvm(data, 'Placeholder:0', 'crop_to_bounding_box/Slice:0') compare_tf_with_tvm(data, 'Placeholder:0',
'crop_to_bounding_box/Slice:0')
def test_forward_crop(): def test_forward_crop():
""" Crop to bounding box """ """ Crop to bounding box """
...@@ -1366,19 +1485,25 @@ def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size, method=' ...@@ -1366,19 +1485,25 @@ def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size, method='
method=method, name="crop_and_resize") method=method, name="crop_and_resize")
compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0') compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0')
def test_forward_crop_and_resize(): def test_forward_crop_and_resize():
""" CropAndResize """ """ CropAndResize """
_test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, 1, 1]], [0], [5, 5]) _test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, 1, 1]], [0], [5, 5])
_test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5]) _test_forward_crop_and_resize(
_test_forward_crop_and_resize([1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5]) [1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5])
_test_forward_crop_and_resize([1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4]) _test_forward_crop_and_resize(
_test_forward_crop_and_resize([1, 41, 41, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3]) [1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5])
_test_forward_crop_and_resize(
[1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4])
_test_forward_crop_and_resize(
[1, 41, 41, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3])
_test_forward_crop_and_resize([10, 11, 11, 3], _test_forward_crop_and_resize([10, 11, 11, 3],
[[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]], [[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]],
[0, 1], [0, 1],
[5, 5]) [5, 5])
_test_forward_crop_and_resize([3, 11, 11, 3], _test_forward_crop_and_resize([3, 11, 11, 3],
[[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8],[0, 0, 1, 1]], [[0, 0, 0.9, 0.9], [
0.2, 0.2, 0.8, 0.8], [0, 0, 1, 1]],
[0, 1, 2], [0, 1, 2],
[3, 3]) [3, 3])
_test_forward_crop_and_resize([3, 11, 11, 3], _test_forward_crop_and_resize([3, 11, 11, 3],
...@@ -1397,8 +1522,10 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype): ...@@ -1397,8 +1522,10 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
tf.reset_default_graph() tf.reset_default_graph()
input_size = num_hidden input_size = num_hidden
input_data = np.full((batch_size, input_size), 1., dtype=dtype) input_data = np.full((batch_size, input_size), 1., dtype=dtype)
in_state_c = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype) in_state_c = np.full(
in_state_h = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype) (num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
in_state_h = np.full(
(num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
def _get_tensorflow_output(): def _get_tensorflow_output():
with tf.Session() as sess: with tf.Session() as sess:
...@@ -1408,8 +1535,8 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype): ...@@ -1408,8 +1535,8 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
m1 = array_ops.zeros([batch_size, num_hidden]) m1 = array_ops.zeros([batch_size, num_hidden])
x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype) x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype)
g, ((out_m0, out_m1)) = \ g, ((out_m0, out_m1)) = \
tf.contrib.rnn.LSTMBlockCell(num_hidden, tf.contrib.rnn.LSTMBlockCell(num_hidden,
forget_bias=forget_bias)(x, ((m0, m1))) forget_bias=forget_bias)(x, ((m0, m1)))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res = sess.run([g, out_m0, out_m1], { res = sess.run([g, out_m0, out_m1], {
x.name: np.array([[1., 1.]]), x.name: np.array([[1., 1.]]),
...@@ -1437,12 +1564,12 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype): ...@@ -1437,12 +1564,12 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
tvm_out = [out, out_state_c, out_state_h] tvm_out = [out, out_state_c, out_state_h]
tvm.testing.assert_allclose(tf_out[0], tvm_out[0], rtol=1e-3, atol=1e-3) tvm.testing.assert_allclose(tf_out[0], tvm_out[0], rtol=1e-3, atol=1e-3)
def test_forward_lstm(): def test_forward_lstm():
'''test LSTM block cell''' '''test LSTM block cell'''
_test_lstm_cell(1, 2, 1, 0.5, 'float32') _test_lstm_cell(1, 2, 1, 0.5, 'float32')
####################################################################### #######################################################################
# Pack # Pack
# --- # ---
...@@ -1459,6 +1586,7 @@ def _test_pack(axis, shape, **kwargs): ...@@ -1459,6 +1586,7 @@ def _test_pack(axis, shape, **kwargs):
compare_tf_with_tvm([a, b], ['pl_a:0', 'pl_b:0'], 'stack:0') compare_tf_with_tvm([a, b], ['pl_a:0', 'pl_b:0'], 'stack:0')
def test_forward_pack(): def test_forward_pack():
for axis in range(-3, 3): for axis in range(-3, 3):
_test_pack(axis, [3, 2, 1]) _test_pack(axis, [3, 2, 1])
...@@ -1478,6 +1606,7 @@ def _test_forward_unpack(in_shape, axis, dtype): ...@@ -1478,6 +1606,7 @@ def _test_forward_unpack(in_shape, axis, dtype):
tf.unstack(in_data, axis=axis, name="Unpack") tf.unstack(in_data, axis=axis, name="Unpack")
compare_tf_with_tvm([np_data], ['in_data:0'], 'Unpack:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'Unpack:0')
def test_forward_unpack(): def test_forward_unpack():
_test_forward_unpack((3,), 0, 'int32') _test_forward_unpack((3,), 0, 'int32')
_test_forward_unpack((3,), -1, 'int16') _test_forward_unpack((3,), -1, 'int16')
...@@ -1486,6 +1615,8 @@ def test_forward_unpack(): ...@@ -1486,6 +1615,8 @@ def test_forward_unpack():
####################################################################### #######################################################################
# Range # Range
# ----- # -----
def test_forward_range(): def test_forward_range():
"""test operator Range""" """test operator Range"""
tf.reset_default_graph() tf.reset_default_graph()
...@@ -1495,6 +1626,8 @@ def test_forward_range(): ...@@ -1495,6 +1626,8 @@ def test_forward_range():
####################################################################### #######################################################################
# Pad # Pad
# --- # ---
def _test_pad(input_shape, paddings, mode, **kwargs): def _test_pad(input_shape, paddings, mode, **kwargs):
""" One iteration of pad operation with given shape""" """ One iteration of pad operation with given shape"""
...@@ -1515,6 +1648,7 @@ def _test_pad(input_shape, paddings, mode, **kwargs): ...@@ -1515,6 +1648,7 @@ def _test_pad(input_shape, paddings, mode, **kwargs):
compare_tf_with_tvm(x, 'Placeholder:0', out_name) compare_tf_with_tvm(x, 'Placeholder:0', out_name)
def test_forward_pad(): def test_forward_pad():
""" Pad """ """ Pad """
_test_pad((2, 3), [[1, 1], [2, 2]], mode="CONSTANT") _test_pad((2, 3), [[1, 1], [2, 2]], mode="CONSTANT")
...@@ -1525,40 +1659,53 @@ def test_forward_pad(): ...@@ -1525,40 +1659,53 @@ def test_forward_pad():
####################################################################### #######################################################################
# Logical operators # Logical operators
# -------------------- # --------------------
def test_logical_and(): def test_logical_and():
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1') in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2') in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
out = tf.logical_and(in1, in2, name='out') out = tf.logical_and(in1, in2, name='out')
in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool') in_data1 = np.random.choice(
in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool') a=[False, True], size=(1, 4, 4, 3)).astype('bool')
in_data2 = np.random.choice(
a=[False, True], size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0') compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
def test_logical_or(): def test_logical_or():
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1') in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2') in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
out = tf.logical_or(in1, in2, name='out') out = tf.logical_or(in1, in2, name='out')
in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool') in_data1 = np.random.choice(
in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool') a=[False, True], size=(1, 4, 4, 3)).astype('bool')
in_data2 = np.random.choice(
a=[False, True], size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0') compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
def test_logical_xor(): def test_logical_xor():
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1') in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2') in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
out = tf.logical_xor(in1, in2, name='out') out = tf.logical_xor(in1, in2, name='out')
in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool') in_data1 = np.random.choice(
in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool') a=[False, True], size=(1, 4, 4, 3)).astype('bool')
in_data2 = np.random.choice(
a=[False, True], size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0') compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
def test_logical_not(): def test_logical_not():
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1') in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
out = tf.logical_not(in1, name='out') out = tf.logical_not(in1, name='out')
in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool') in_data1 = np.random.choice(
a=[False, True], size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm(in_data1, 'in1:0', 'out:0') compare_tf_with_tvm(in_data1, 'in1:0', 'out:0')
def test_forward_logical(): def test_forward_logical():
test_logical_and() test_logical_and()
test_logical_or() test_logical_or()
...@@ -1573,13 +1720,18 @@ def test_forward_where(): ...@@ -1573,13 +1720,18 @@ def test_forward_where():
''' Where: return elements depending on conditions''' ''' Where: return elements depending on conditions'''
with tf.Graph().as_default(): with tf.Graph().as_default():
with tf.Session() as sess: with tf.Session() as sess:
input1 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input1') input1 = tf.placeholder(
input2 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input2') tf.int32, shape=[1, 4, 4, 3], name='input1')
input2 = tf.placeholder(
tf.int32, shape=[1, 4, 4, 3], name='input2')
mask = input1 > input2 mask = input1 > input2
tf.where(mask, input1 + 1, input2 * 2) tf.where(mask, input1 + 1, input2 * 2)
in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32") in_data1 = np.random.uniform(
in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32") 0, 10, size=(1, 4, 4, 3)).astype("uint32")
compare_tf_with_tvm([in_data1, in_data2], ['input1:0', 'input2:0'], 'Select:0') in_data2 = np.random.uniform(
0, 10, size=(1, 4, 4, 3)).astype("uint32")
compare_tf_with_tvm([in_data1, in_data2], [
'input1:0', 'input2:0'], 'Select:0')
####################################################################### #######################################################################
...@@ -1596,17 +1748,22 @@ def test_forward_inception_v3(): ...@@ -1596,17 +1748,22 @@ def test_forward_inception_v3():
data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32') data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
with tf.Session() as sess: with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0') tf_output = run_tf_graph(
sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')
tvm_output = run_tvm_graph(graph_def, data, 'input') tvm_output = run_tvm_graph(graph_def, data, 'input')
tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(
tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# Inception V1 # Inception V1
# ------------ # ------------
def test_forward_inception_v1(): def test_forward_inception_v1():
'''test inception V1 model''' '''test inception V1 model'''
with tf.Graph().as_default(): with tf.Graph().as_default():
graph_def = tf_testing.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb") graph_def = tf_testing.get_workload(
"InceptionV1/classify_image_graph_def-with_shapes.pb")
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def) graph_def = tf_testing.ProcessGraphDefParam(graph_def)
...@@ -1615,7 +1772,8 @@ def test_forward_inception_v1(): ...@@ -1615,7 +1772,8 @@ def test_forward_inception_v1():
from tvm.contrib import util from tvm.contrib import util
img_array = np.random.uniform(size=(1, 600, 600, 3)).astype("uint8") img_array = np.random.uniform(size=(1, 600, 600, 3)).astype("uint8")
img = Image.frombuffer('RGB', (600, 600), img_array.tostring(), 'raw', 'RGB', 0, 1) img = Image.frombuffer(
'RGB', (600, 600), img_array.tostring(), 'raw', 'RGB', 0, 1)
temp = util.tempdir() temp = util.tempdir()
img_path = temp.relpath("tf-test.jpg") img_path = temp.relpath("tf-test.jpg")
img.save(img_path) img.save(img_path)
...@@ -1629,16 +1787,22 @@ def test_forward_inception_v1(): ...@@ -1629,16 +1787,22 @@ def test_forward_inception_v1():
# Extract tensorflow decoded image frame for tvm input # Extract tensorflow decoded image frame for tvm input
with tf.Session() as sess: with tf.Session() as sess:
tvm_data = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'DecodeJpeg:0') tvm_data = run_tf_graph(
sess, data, 'DecodeJpeg/contents:0', 'DecodeJpeg:0')
with tf.Session() as sess: with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0') tf_output = run_tf_graph(
tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents') sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5) tvm_output = run_tvm_graph(
graph_def, tvm_data, 'DecodeJpeg/contents')
tvm.testing.assert_allclose(
tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# Mobilenet # Mobilenet
# --------- # ---------
def test_forward_mobilenet(): def test_forward_mobilenet():
'''test mobilenet model''' '''test mobilenet model'''
# MobilenetV2 # MobilenetV2
...@@ -1663,6 +1827,8 @@ def test_forward_mobilenet(): ...@@ -1663,6 +1827,8 @@ def test_forward_mobilenet():
####################################################################### #######################################################################
# ResnetV2 # ResnetV2
# -------- # --------
def test_forward_resnetv2(): def test_forward_resnetv2():
'''test resnet model''' '''test resnet model'''
if is_gpu_available(): if is_gpu_available():
...@@ -1676,7 +1842,8 @@ def test_forward_resnetv2(): ...@@ -1676,7 +1842,8 @@ def test_forward_resnetv2():
out_node = 'ArgMax' out_node = 'ArgMax'
with tf.Session() as sess: with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0') tf_output = run_tf_graph(
sess, data, 'input_tensor:0', out_node + ':0')
for device in ["llvm", "cuda"]: for device in ["llvm", "cuda"]:
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
if not ctx.exist: if not ctx.exist:
...@@ -1690,6 +1857,8 @@ def test_forward_resnetv2(): ...@@ -1690,6 +1857,8 @@ def test_forward_resnetv2():
####################################################################### #######################################################################
# Placeholder # Placeholder
# ----------- # -----------
def test_forward_placeholder(): def test_forward_placeholder():
'''test a simple pb with Placeholder node in the end of GraphDef''' '''test a simple pb with Placeholder node in the end of GraphDef'''
with tf.Graph().as_default(): with tf.Graph().as_default():
...@@ -1703,15 +1872,19 @@ def test_forward_placeholder(): ...@@ -1703,15 +1872,19 @@ def test_forward_placeholder():
with tf.Session() as sess: with tf.Session() as sess:
# Add shapes to the graph. # Add shapes to the graph.
graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0') tf_output = run_tf_graph(
sess, data, 'Placeholder:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'Placeholder') tvm_output = run_tvm_graph(graph_def, data, 'Placeholder')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]),
rtol=1e-5, atol=1e-5) rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# PTB # PTB
# --- # ---
dir(tf.contrib) dir(tf.contrib)
def test_forward_ptb(): def test_forward_ptb():
'''test ptb model''' '''test ptb model'''
config = tf_testing.get_config() config = tf_testing.get_config()
...@@ -1722,7 +1895,7 @@ def test_forward_ptb(): ...@@ -1722,7 +1895,7 @@ def test_forward_ptb():
vocab_size = config.vocab_size vocab_size = config.vocab_size
out_sample_shape = (batch_size, vocab_size) out_sample_shape = (batch_size, vocab_size)
out_state_shape = (num_layers, 2, batch_size, num_hidden) out_state_shape = (num_layers, 2, batch_size, num_hidden)
#Sample input # Sample input
inpt = "we have no useful information on" inpt = "we have no useful information on"
cnt_sample = 20 cnt_sample = 20
...@@ -1733,18 +1906,19 @@ def test_forward_ptb(): ...@@ -1733,18 +1906,19 @@ def test_forward_ptb():
return ''.join([id2word[x] for x in items]).replace('_', ' ') return ''.join([id2word[x] for x in items]).replace('_', ' ')
def _get_tvm_graph_module(graph_def): def _get_tvm_graph_module(graph_def):
#Cell inputs 'c and 'h' consist of all layers values # Cell inputs 'c and 'h' consist of all layers values
shape_dict = {'Model/Placeholder': (batch_size, num_steps), shape_dict = {'Model/Placeholder': (batch_size, num_steps),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c': 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':
(num_layers, batch_size, num_hidden), (num_layers, batch_size, num_hidden),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h': 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':
(num_layers, batch_size, num_hidden)} (num_layers, batch_size, num_hidden)}
mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict) mod, params = relay.frontend.from_tensorflow(
graph_def, shape=shape_dict)
dtype_dict = {'Model/Placeholder': 'int32', dtype_dict = {'Model/Placeholder': 'int32',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':'float32', 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c': 'float32',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':'float32'} 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h': 'float32'}
target = 'llvm' target = 'llvm'
with relay.build_config(opt_level=0): with relay.build_config(opt_level=0):
graph, lib, params = relay.build(mod, graph, lib, params = relay.build(mod,
...@@ -1759,13 +1933,17 @@ def test_forward_ptb(): ...@@ -1759,13 +1933,17 @@ def test_forward_ptb():
samples = [] samples = []
state = in_states state = in_states
sample = None sample = None
def _get_sample(data, state): def _get_sample(data, state):
input_data = np.full((batch_size, num_steps), data, dtype="int32") input_data = np.full((batch_size, num_steps), data, dtype="int32")
in_state_tup = np.split(state, indices_or_sections=2, axis=1) in_state_tup = np.split(state, indices_or_sections=2, axis=1)
in_state_c = np.reshape(in_state_tup[0], (num_layers, batch_size, num_hidden)) in_state_c = np.reshape(
in_state_h = np.reshape(in_state_tup[1], (num_layers, batch_size, num_hidden)) in_state_tup[0], (num_layers, batch_size, num_hidden))
in_state_h = np.reshape(
in_state_tup[1], (num_layers, batch_size, num_hidden))
model.set_input('Model/Placeholder', tvm.nd.array(input_data.astype("int32"))) model.set_input('Model/Placeholder',
tvm.nd.array(input_data.astype("int32")))
model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c', model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c',
tvm.nd.array(in_state_c.astype("float32"))) tvm.nd.array(in_state_c.astype("float32")))
model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h', model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h',
...@@ -1802,16 +1980,17 @@ def test_forward_ptb(): ...@@ -1802,16 +1980,17 @@ def test_forward_ptb():
graph_def = tf_testing.ProcessGraphDefParam(graph_def) graph_def = tf_testing.ProcessGraphDefParam(graph_def)
sess = tf.Session() sess = tf.Session()
#TVM graph module creation # TVM graph module creation
params, m = _get_tvm_graph_module(graph_def) params, m = _get_tvm_graph_module(graph_def)
# Create 10 predicted statments of 20 words # Create 10 predicted statments of 20 words
cnt_stm = 0 cnt_stm = 0
while cnt_stm < 10: while cnt_stm < 10:
cnt_stm += 1 cnt_stm += 1
in_state = np.full((num_layers, 2, batch_size, num_hidden), 0, dtype="float32") in_state = np.full(
(num_layers, 2, batch_size, num_hidden), 0, dtype="float32")
seed_for_sample = inpt.split() seed_for_sample = inpt.split()
tvm_samples, tvm_state = _do_tvm_sample(m, [word_to_id[word] \ tvm_samples, tvm_state = _do_tvm_sample(m, [word_to_id[word]
for word in seed_for_sample], for word in seed_for_sample],
in_state, params, cnt_sample) in_state, params, cnt_sample)
tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word) tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
...@@ -1821,13 +2000,15 @@ def test_forward_ptb(): ...@@ -1821,13 +2000,15 @@ def test_forward_ptb():
in_state, cnt_sample) in_state, cnt_sample)
tf_sample_str = _pretty_print(tf_samples, False, id_to_word) tf_sample_str = _pretty_print(tf_samples, False, id_to_word)
inpt = tvm_sample_str inpt = tvm_sample_str
tvm.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(
tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
assert tvm_sample_str == tf_sample_str assert tvm_sample_str == tf_sample_str
####################################################################### #######################################################################
# LRN (Local Response Normalization) # LRN (Local Response Normalization)
# ---------------------------------- # ----------------------------------
def _test_lrn(ishape, size, axis, bias, alpha, beta): def _test_lrn(ishape, size, axis, bias, alpha, beta):
""" testing local response normalization """ """ testing local response normalization """
lrn_depth_radius = size / 2 lrn_depth_radius = size / 2
...@@ -1835,7 +2016,8 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta): ...@@ -1835,7 +2016,8 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
inp_array = np.random.uniform(size=ishape).astype(np.float32) inp_array = np.random.uniform(size=ishape).astype(np.float32)
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="lrn0_data") in1 = tf.placeholder(shape=inp_array.shape,
dtype=inp_array.dtype, name="lrn0_data")
nn_ops.local_response_normalization(in1, nn_ops.local_response_normalization(in1,
name="lrn", name="lrn",
depth_radius=lrn_depth_radius, depth_radius=lrn_depth_radius,
...@@ -1845,6 +2027,7 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta): ...@@ -1845,6 +2027,7 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
compare_tf_with_tvm(inp_array, 'lrn0_data:0', 'lrn:0') compare_tf_with_tvm(inp_array, 'lrn0_data:0', 'lrn:0')
def test_forward_lrn(): def test_forward_lrn():
_test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5) _test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
...@@ -1852,6 +2035,7 @@ def test_forward_lrn(): ...@@ -1852,6 +2035,7 @@ def test_forward_lrn():
# l2_normalize # l2_normalize
# ------------ # ------------
def _test_l2_normalize(ishape, eps, axis): def _test_l2_normalize(ishape, eps, axis):
""" testing l2 normalize (uses max, sum, square, sqrt frontend operators)""" """ testing l2 normalize (uses max, sum, square, sqrt frontend operators)"""
...@@ -1867,17 +2051,21 @@ def _test_l2_normalize(ishape, eps, axis): ...@@ -1867,17 +2051,21 @@ def _test_l2_normalize(ishape, eps, axis):
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'l2_normalize:0') compare_tf_with_tvm(inp_array, 'Placeholder:0', 'l2_normalize:0')
def test_forward_l2_normalize(): def test_forward_l2_normalize():
_test_l2_normalize((1, 3, 20, 20), 0.001, (0,)) _test_l2_normalize((1, 3, 20, 20), 0.001, (0,))
####################################################################### #######################################################################
# transpose # transpose
# --------- # ---------
def _test_forward_transpose(ishape, axes=None): def _test_forward_transpose(ishape, axes=None):
data = np.random.uniform(size=ishape).astype(np.float32) data = np.random.uniform(size=ishape).astype(np.float32)
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name="transpose_data") in1 = tf.placeholder(
shape=data.shape, dtype=data.dtype, name="transpose_data")
if axes is None: if axes is None:
tf.transpose(in1) tf.transpose(in1)
...@@ -1886,6 +2074,7 @@ def _test_forward_transpose(ishape, axes=None): ...@@ -1886,6 +2074,7 @@ def _test_forward_transpose(ishape, axes=None):
compare_tf_with_tvm(data, 'transpose_data:0', 'transpose:0') compare_tf_with_tvm(data, 'transpose_data:0', 'transpose:0')
def test_forward_transpose(): def test_forward_transpose():
_test_forward_transpose((2, 3, 4), (1, 2, 0)) _test_forward_transpose((2, 3, 4), (1, 2, 0))
_test_forward_transpose((2, 3, 4)) _test_forward_transpose((2, 3, 4))
...@@ -1903,6 +2092,7 @@ def test_forward_ceil(): ...@@ -1903,6 +2092,7 @@ def test_forward_ceil():
tf.ceil(in1) tf.ceil(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Ceil:0') compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Ceil:0')
def test_forward_floor(): def test_forward_floor():
ishape = (1, 3, 10, 10) ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(size=ishape).astype(np.float32) inp_array = np.random.uniform(size=ishape).astype(np.float32)
...@@ -1911,6 +2101,7 @@ def test_forward_floor(): ...@@ -1911,6 +2101,7 @@ def test_forward_floor():
tf.floor(in1) tf.floor(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Floor:0') compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Floor:0')
def test_forward_relu(): def test_forward_relu():
ishape = (1, 3, 10, 10) ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32) inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
...@@ -1919,6 +2110,7 @@ def test_forward_relu(): ...@@ -1919,6 +2110,7 @@ def test_forward_relu():
tf.nn.relu(in1) tf.nn.relu(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Relu:0') compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Relu:0')
def test_forward_leaky_relu(): def test_forward_leaky_relu():
ishape = (1, 3, 10, 10) ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32) inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
...@@ -1927,6 +2119,7 @@ def test_forward_leaky_relu(): ...@@ -1927,6 +2119,7 @@ def test_forward_leaky_relu():
tf.nn.leaky_relu(in1, alpha=0.4) tf.nn.leaky_relu(in1, alpha=0.4)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'LeakyRelu:0') compare_tf_with_tvm(inp_array, 'Placeholder:0', 'LeakyRelu:0')
def test_forward_elu(): def test_forward_elu():
ishape = (1, 3, 10, 10) ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32) inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
...@@ -1935,6 +2128,7 @@ def test_forward_elu(): ...@@ -1935,6 +2128,7 @@ def test_forward_elu():
tf.nn.elu(in1) tf.nn.elu(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Elu:0') compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Elu:0')
def test_forward_selu(): def test_forward_selu():
ishape = (1, 3, 10, 10) ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32) inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
...@@ -1943,6 +2137,7 @@ def test_forward_selu(): ...@@ -1943,6 +2137,7 @@ def test_forward_selu():
tf.nn.selu(in1) tf.nn.selu(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Selu:0') compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Selu:0')
def test_forward_tanh(): def test_forward_tanh():
ishape = (1, 3, 10, 10) ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32) inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
...@@ -1979,6 +2174,7 @@ def test_forward_round(): ...@@ -1979,6 +2174,7 @@ def test_forward_round():
tf.round(in_data, name="round") tf.round(in_data, name="round")
compare_tf_with_tvm([np_data], ['in_data:0'], 'round:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'round:0')
def test_forward_abs(): def test_forward_abs():
"""test operator Abs""" """test operator Abs"""
np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32)
...@@ -1987,6 +2183,7 @@ def test_forward_abs(): ...@@ -1987,6 +2183,7 @@ def test_forward_abs():
tf.math.abs(in_data, name="abs") tf.math.abs(in_data, name="abs")
compare_tf_with_tvm([np_data], ['in_data:0'], 'abs:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'abs:0')
def _test_forward_zeros_like(in_shape, dtype): def _test_forward_zeros_like(in_shape, dtype):
np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype) np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
tf.reset_default_graph() tf.reset_default_graph()
...@@ -1994,6 +2191,7 @@ def _test_forward_zeros_like(in_shape, dtype): ...@@ -1994,6 +2191,7 @@ def _test_forward_zeros_like(in_shape, dtype):
tf.zeros_like(in_data, name="zeros_like") tf.zeros_like(in_data, name="zeros_like")
compare_tf_with_tvm([np_data], ['in_data:0'], 'zeros_like:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'zeros_like:0')
def test_forward_zeros_like(): def test_forward_zeros_like():
if tf.__version__ < LooseVersion('1.2'): if tf.__version__ < LooseVersion('1.2'):
_test_forward_zeros_like((2, 3), "int32") _test_forward_zeros_like((2, 3), "int32")
...@@ -2002,6 +2200,7 @@ def test_forward_zeros_like(): ...@@ -2002,6 +2200,7 @@ def test_forward_zeros_like():
_test_forward_zeros_like((2, 3, 11), "float32") _test_forward_zeros_like((2, 3, 11), "float32")
_test_forward_zeros_like((2, 3, 11), "float64") _test_forward_zeros_like((2, 3, 11), "float64")
def test_forward_erf(): def test_forward_erf():
ishape = (1, 3, 10, 10) ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32) inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
...@@ -2010,15 +2209,20 @@ def test_forward_erf(): ...@@ -2010,15 +2209,20 @@ def test_forward_erf():
tf.math.erf(in1) tf.math.erf(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Erf:0') compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Erf:0')
def test_forward_squared_difference(): def test_forward_squared_difference():
ishape = (1, 3, 10, 14) ishape = (1, 3, 10, 14)
inp_array_a = np.random.uniform(-5, 5, size=ishape).astype(np.float32) inp_array_a = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
inp_array_b = np.random.uniform(-5, 5, size=ishape).astype(np.float32) inp_array_b = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array_a.shape, dtype=inp_array_a.dtype, name="in1") in1 = tf.placeholder(shape=inp_array_a.shape,
in2 = tf.placeholder(shape=inp_array_b.shape, dtype=inp_array_b.dtype, name="in2") dtype=inp_array_a.dtype, name="in1")
in2 = tf.placeholder(shape=inp_array_b.shape,
dtype=inp_array_b.dtype, name="in2")
out = tf.math.squared_difference(in1, in2) out = tf.math.squared_difference(in1, in2)
compare_tf_with_tvm([inp_array_a, inp_array_b], [in1.name, in2.name], out.name) compare_tf_with_tvm([inp_array_a, inp_array_b], [
in1.name, in2.name], out.name)
def _test_forward_reverse_v2(in_shape, axis, dtype): def _test_forward_reverse_v2(in_shape, axis, dtype):
np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype) np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
...@@ -2027,6 +2231,7 @@ def _test_forward_reverse_v2(in_shape, axis, dtype): ...@@ -2027,6 +2231,7 @@ def _test_forward_reverse_v2(in_shape, axis, dtype):
tf.reverse(in_data, axis=[axis], name="reverse") tf.reverse(in_data, axis=[axis], name="reverse")
compare_tf_with_tvm([np_data], ['in_data:0'], 'reverse:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'reverse:0')
def test_forward_reverse_v2(): def test_forward_reverse_v2():
"""test ReverseV2""" """test ReverseV2"""
_test_forward_reverse_v2((2, 3), 0, "int32") _test_forward_reverse_v2((2, 3), 0, "int32")
...@@ -2035,6 +2240,7 @@ def test_forward_reverse_v2(): ...@@ -2035,6 +2240,7 @@ def test_forward_reverse_v2():
_test_forward_reverse_v2((2, 3, 5), -1, "float64") _test_forward_reverse_v2((2, 3, 5), -1, "float64")
_test_forward_reverse_v2((2, 3, 5), -3, "float64") _test_forward_reverse_v2((2, 3, 5), -3, "float64")
def test_forward_sign(): def test_forward_sign():
"""test Sign""" """test Sign"""
np_data = np.random.uniform(-10, 10, size=(5, 7, 11)).astype(np.float32) np_data = np.random.uniform(-10, 10, size=(5, 7, 11)).astype(np.float32)
...@@ -2043,6 +2249,7 @@ def test_forward_sign(): ...@@ -2043,6 +2249,7 @@ def test_forward_sign():
tf.sign(in_data, name="sign") tf.sign(in_data, name="sign")
compare_tf_with_tvm([np_data], ['in_data:0'], 'sign:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'sign:0')
def test_forward_square(): def test_forward_square():
"""test operator Square """ """test operator Square """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
...@@ -2051,6 +2258,7 @@ def test_forward_square(): ...@@ -2051,6 +2258,7 @@ def test_forward_square():
tf.square(in_data, name="square") tf.square(in_data, name="square")
compare_tf_with_tvm([np_data], ['in_data:0'], 'square:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'square:0')
def test_forward_pow_exp(): def test_forward_pow_exp():
"""test Pow and Exp """ """test Pow and Exp """
np_in1 = np.random.uniform(-2, 2, size=(5, 7, 11)).astype(np.float32) np_in1 = np.random.uniform(-2, 2, size=(5, 7, 11)).astype(np.float32)
...@@ -2063,6 +2271,7 @@ def test_forward_pow_exp(): ...@@ -2063,6 +2271,7 @@ def test_forward_pow_exp():
compare_tf_with_tvm([np_in1, np_in2], ['in1:0', 'in2:0'], 'pow:0') compare_tf_with_tvm([np_in1, np_in2], ['in1:0', 'in2:0'], 'pow:0')
compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0') compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0')
def test_forward_log(): def test_forward_log():
"""test operator Log """ """test operator Log """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
...@@ -2071,6 +2280,7 @@ def test_forward_log(): ...@@ -2071,6 +2280,7 @@ def test_forward_log():
tf.log(in_data, name="log") tf.log(in_data, name="log")
compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')
def test_forward_log1p(): def test_forward_log1p():
"""test operator Log1p """ """test operator Log1p """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
...@@ -2079,6 +2289,7 @@ def test_forward_log1p(): ...@@ -2079,6 +2289,7 @@ def test_forward_log1p():
tf.log1p(in_data, name="log1p") tf.log1p(in_data, name="log1p")
compare_tf_with_tvm([np_data], ['in_data:0'], 'log1p:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'log1p:0')
def test_forward_cos(): def test_forward_cos():
"""test operator cos """ """test operator cos """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
...@@ -2087,6 +2298,7 @@ def test_forward_cos(): ...@@ -2087,6 +2298,7 @@ def test_forward_cos():
tf.cos(in_data, name="cos") tf.cos(in_data, name="cos")
compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0')
def test_forward_sin(): def test_forward_sin():
"""test operator sin """ """test operator sin """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
...@@ -2095,14 +2307,17 @@ def test_forward_sin(): ...@@ -2095,14 +2307,17 @@ def test_forward_sin():
tf.sin(in_data, name="sin") tf.sin(in_data, name="sin")
compare_tf_with_tvm([np_data], ['in_data:0'], 'sin:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'sin:0')
def test_forward_negative(): def test_forward_negative():
"""test tf operator Neg """ """test tf operator Neg """
np_data = np.random.uniform(-100, 255, size=(224, 224, 3)).astype(np.float32) np_data = np.random.uniform(-100, 255,
size=(224, 224, 3)).astype(np.float32)
tf.reset_default_graph() tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (224, 224, 3), name="in_data") in_data = tf.placeholder(tf.float32, (224, 224, 3), name="in_data")
tf.negative(in_data, name="negative") tf.negative(in_data, name="negative")
compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0')
def test_forward_log_softmax(): def test_forward_log_softmax():
"""test operator LogSoftmax""" """test operator LogSoftmax"""
np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32)
...@@ -2111,6 +2326,7 @@ def test_forward_log_softmax(): ...@@ -2111,6 +2326,7 @@ def test_forward_log_softmax():
tf.math.log_softmax(in_data, name="LogSoftmax") tf.math.log_softmax(in_data, name="LogSoftmax")
compare_tf_with_tvm([np_data], ['in_data:0'], 'LogSoftmax:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'LogSoftmax:0')
def test_forward_softplus(): def test_forward_softplus():
"""test operator Softplus""" """test operator Softplus"""
np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32) np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32)
...@@ -2119,6 +2335,7 @@ def test_forward_softplus(): ...@@ -2119,6 +2335,7 @@ def test_forward_softplus():
tf.nn.softplus(in_data, name="softplus") tf.nn.softplus(in_data, name="softplus")
compare_tf_with_tvm([np_data], ['in_data:0'], 'softplus:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'softplus:0')
def test_forward_rsqrt(): def test_forward_rsqrt():
"""test Rsqrt """ """test Rsqrt """
np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
...@@ -2127,6 +2344,7 @@ def test_forward_rsqrt(): ...@@ -2127,6 +2344,7 @@ def test_forward_rsqrt():
tf.rsqrt(in_data, name="rsqrt") tf.rsqrt(in_data, name="rsqrt")
compare_tf_with_tvm([np_data], ['in_data:0'], 'rsqrt:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'rsqrt:0')
def test_forward_sqrt(): def test_forward_sqrt():
"""test Sqrt """ """test Sqrt """
np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
...@@ -2135,6 +2353,7 @@ def test_forward_sqrt(): ...@@ -2135,6 +2353,7 @@ def test_forward_sqrt():
tf.sqrt(in_data, name="sqrt") tf.sqrt(in_data, name="sqrt")
compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0')
def _test_forward_right_shift(in_shape, dtype): def _test_forward_right_shift(in_shape, dtype):
"""test operator RightShift""" """test operator RightShift"""
lh_data = np.random.randint(1, 3, size=in_shape).astype(dtype) lh_data = np.random.randint(1, 3, size=in_shape).astype(dtype)
...@@ -2143,12 +2362,15 @@ def _test_forward_right_shift(in_shape, dtype): ...@@ -2143,12 +2362,15 @@ def _test_forward_right_shift(in_shape, dtype):
lft_data = tf.placeholder(dtype, in_shape, name="lft_data") lft_data = tf.placeholder(dtype, in_shape, name="lft_data")
rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data") rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data")
tf.bitwise.right_shift(lft_data, rgt_data, name="RightShift") tf.bitwise.right_shift(lft_data, rgt_data, name="RightShift")
compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'RightShift:0') compare_tf_with_tvm([lh_data, rh_data], [
'lft_data:0', 'rgt_data:0'], 'RightShift:0')
def test_forward_right_shift(): def test_forward_right_shift():
_test_forward_right_shift((7,), 'int32') _test_forward_right_shift((7,), 'int32')
_test_forward_right_shift((3, 11), 'int16') _test_forward_right_shift((3, 11), 'int16')
def _test_forward_left_shift(in_shape, dtype): def _test_forward_left_shift(in_shape, dtype):
"""test operator LeftShift""" """test operator LeftShift"""
lh_data = np.random.randint(100, 1000000, size=in_shape).astype(dtype) lh_data = np.random.randint(100, 1000000, size=in_shape).astype(dtype)
...@@ -2157,7 +2379,9 @@ def _test_forward_left_shift(in_shape, dtype): ...@@ -2157,7 +2379,9 @@ def _test_forward_left_shift(in_shape, dtype):
lft_data = tf.placeholder(dtype, in_shape, name="lft_data") lft_data = tf.placeholder(dtype, in_shape, name="lft_data")
rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data") rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data")
tf.bitwise.left_shift(lft_data, rgt_data, name="LeftShift") tf.bitwise.left_shift(lft_data, rgt_data, name="LeftShift")
compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'LeftShift:0') compare_tf_with_tvm([lh_data, rh_data], [
'lft_data:0', 'rgt_data:0'], 'LeftShift:0')
def test_forward_left_shift(): def test_forward_left_shift():
_test_forward_left_shift((10,), 'int32') _test_forward_left_shift((10,), 'int32')
...@@ -2166,13 +2390,16 @@ def test_forward_left_shift(): ...@@ -2166,13 +2390,16 @@ def test_forward_left_shift():
####################################################################### #######################################################################
# Mean # Mean
# ---- # ----
def test_forward_mean(): def test_forward_mean():
def check_mean(ishape, **kwargs): def check_mean(ishape, **kwargs):
inp_array = np.random.uniform(size=ishape).astype(np.float32) inp_array = np.random.uniform(size=ishape).astype(np.float32)
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype) in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
tf.keras.backend.mean(in1, **kwargs) tf.keras.backend.mean(in1, **kwargs)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Mean:0', no_gpu=True) compare_tf_with_tvm(inp_array, 'Placeholder:0',
'Mean:0', no_gpu=True)
check_mean((10, 8, 16, 32)) check_mean((10, 8, 16, 32))
check_mean((10, 8, 16, 32), axis=(2, 3)) check_mean((10, 8, 16, 32), axis=(2, 3))
...@@ -2181,6 +2408,8 @@ def test_forward_mean(): ...@@ -2181,6 +2408,8 @@ def test_forward_mean():
####################################################################### #######################################################################
# Size # Size
# ---- # ----
def test_forward_size(): def test_forward_size():
def check_size(ishape): def check_size(ishape):
np_input = np.random.uniform(size=ishape).astype(np.float32) np_input = np.random.uniform(size=ishape).astype(np.float32)
...@@ -2190,7 +2419,8 @@ def test_forward_size(): ...@@ -2190,7 +2419,8 @@ def test_forward_size():
tf_input_shape[0] = None tf_input_shape[0] = None
with tf.Graph().as_default(): with tf.Graph().as_default():
input = tf.placeholder(shape=tf_input_shape, dtype=np_input.dtype, name='input') input = tf.placeholder(shape=tf_input_shape,
dtype=np_input.dtype, name='input')
tf.size(input, name='size') tf.size(input, name='size')
compare_tf_with_tvm([np_input], ['input:0'], 'size:0') compare_tf_with_tvm([np_input], ['input:0'], 'size:0')
...@@ -2200,6 +2430,8 @@ def test_forward_size(): ...@@ -2200,6 +2430,8 @@ def test_forward_size():
####################################################################### #######################################################################
# All, Any, Max, Min # All, Any, Max, Min
# ------------- # -------------
def test_forward_reduce_all(): def test_forward_reduce_all():
"""Test the All operator.""" """Test the All operator."""
np_data = np.random.choice([True, False], size=(5, 7, 11)) np_data = np.random.choice([True, False], size=(5, 7, 11))
...@@ -2208,32 +2440,28 @@ def test_forward_reduce_all(): ...@@ -2208,32 +2440,28 @@ def test_forward_reduce_all():
tf.reduce_all(in_data, name="all") tf.reduce_all(in_data, name="all")
compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')
def test_forward_reduce_any():
"""Test the Any operator."""
np_data = np.random.choice([True, False], size=(5, 7, 11))
tf.reset_default_graph()
in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
tf.reduce_any(in_data, name="any")
compare_tf_with_tvm([np_data], ['in_data:0'], 'any:0')
def test_forward_reduce_max(): def test_forward_reduce_max():
def check_max(ishape, axis, keepdims, dtype): def check_max(ishape, axis, keepdims, dtype):
tf.reset_default_graph() tf.reset_default_graph()
np_data = np.random.uniform(size=ishape).astype(dtype) np_data = np.random.uniform(size=ishape).astype(dtype)
in_data = tf.placeholder(dtype, name="in_data") in_data = tf.placeholder(dtype, name="in_data")
tf.math.reduce_max(in_data, axis=axis, keepdims=keepdims, name="reduce_max") tf.math.reduce_max(in_data, axis=axis,
keepdims=keepdims, name="reduce_max")
compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0')
check_max((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32") check_max((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32")
check_max((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32") check_max((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32")
check_max((10, 8, 16, 32), axis=(1, 2), keepdims=True, dtype='float32') check_max((10, 8, 16, 32), axis=(1, 2), keepdims=True, dtype='float32')
def test_forward_reduce_min(): def test_forward_reduce_min():
def check_min(ishape, axis, keepdims, dtype): def check_min(ishape, axis, keepdims, dtype):
tf.reset_default_graph() tf.reset_default_graph()
np_data = np.random.uniform(size=ishape).astype(dtype) np_data = np.random.uniform(size=ishape).astype(dtype)
in_data = tf.placeholder(dtype, name="in_data") in_data = tf.placeholder(dtype, name="in_data")
tf.math.reduce_min(in_data, axis=axis, keepdims=keepdims, name="reduce_max") tf.math.reduce_min(in_data, axis=axis,
keepdims=keepdims, name="reduce_max")
compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0')
check_min((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32") check_min((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32")
...@@ -2243,14 +2471,19 @@ def test_forward_reduce_min(): ...@@ -2243,14 +2471,19 @@ def test_forward_reduce_min():
####################################################################### #######################################################################
# Relational operators # Relational operators
# -------------------- # --------------------
def _test_forward_rel_op(data, func): def _test_forward_rel_op(data, func):
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in1') in1 = tf.placeholder(
in2 = tf.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in2') shape=data[0].shape, dtype=data[0].dtype, name='in1')
in2 = tf.placeholder(
shape=data[1].shape, dtype=data[1].dtype, name='in2')
op = func(in1, in2, name='op') op = func(in1, in2, name='op')
out = tf.cast(op, tf.int32, name='out1') out = tf.cast(op, tf.int32, name='out1')
compare_tf_with_tvm([data[0], data[1]], ['in1:0', 'in2:0'], 'out1:0') compare_tf_with_tvm([data[0], data[1]], ['in1:0', 'in2:0'], 'out1:0')
def test_forward_rel_ops(): def test_forward_rel_ops():
t1 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) t1 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
t2 = np.array([[9, 8, 7], [6, 5, 4], [3, 2, 1]]) t2 = np.array([[9, 8, 7], [6, 5, 4], [3, 2, 1]])
...@@ -2264,11 +2497,14 @@ def test_forward_rel_ops(): ...@@ -2264,11 +2497,14 @@ def test_forward_rel_ops():
####################################################################### #######################################################################
# ExpandDims # ExpandDims
# ---------- # ----------
def _test_forward_expand_dims(data, axis): def _test_forward_expand_dims(data, axis):
in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name='in1') in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name='in1')
out = tf.expand_dims(in1, axis) out = tf.expand_dims(in1, axis)
compare_tf_with_tvm([data], [in1.name], out.name) compare_tf_with_tvm([data], [in1.name], out.name)
def test_forward_expand_dims(): def test_forward_expand_dims():
_test_forward_expand_dims(np.int32(1), 0) _test_forward_expand_dims(np.int32(1), 0)
_test_forward_expand_dims(np.array([1]), 0) _test_forward_expand_dims(np.array([1]), 0)
...@@ -2288,6 +2524,7 @@ def _test_forward_reduce_prod(shape, axis, keepdims): ...@@ -2288,6 +2524,7 @@ def _test_forward_reduce_prod(shape, axis, keepdims):
out = tf.math.reduce_prod(in1, axis, keepdims) out = tf.math.reduce_prod(in1, axis, keepdims)
compare_tf_with_tvm(inp_array1, in1.name, out.name) compare_tf_with_tvm(inp_array1, in1.name, out.name)
def test_forward_reduce_prod(): def test_forward_reduce_prod():
_test_forward_reduce_prod((5,), 0, False) _test_forward_reduce_prod((5,), 0, False)
_test_forward_reduce_prod((5, 5), 0, False) _test_forward_reduce_prod((5, 5), 0, False)
...@@ -2309,11 +2546,13 @@ def test_forward_maximum(): ...@@ -2309,11 +2546,13 @@ def test_forward_maximum():
lft_data = tf.placeholder(dtype, name="lft_data") lft_data = tf.placeholder(dtype, name="lft_data")
rgt_data = tf.placeholder(dtype, name="rgt_data") rgt_data = tf.placeholder(dtype, name="rgt_data")
tf.math.maximum(lft_data, rgt_data, name="maximum") tf.math.maximum(lft_data, rgt_data, name="maximum")
compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'maximum:0') compare_tf_with_tvm([lh_data, rh_data], [
'lft_data:0', 'rgt_data:0'], 'maximum:0')
check_maximum((10, 8, 16, 32), (1,), dtype="int32") check_maximum((10, 8, 16, 32), (1,), dtype="int32")
check_maximum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32") check_maximum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32")
def test_forward_minimum(): def test_forward_minimum():
"""test Op Minimum""" """test Op Minimum"""
def check_minimum(lh_shape, rh_shape, dtype): def check_minimum(lh_shape, rh_shape, dtype):
...@@ -2323,7 +2562,8 @@ def test_forward_minimum(): ...@@ -2323,7 +2562,8 @@ def test_forward_minimum():
lft_data = tf.placeholder(dtype, name="lft_data") lft_data = tf.placeholder(dtype, name="lft_data")
rgt_data = tf.placeholder(dtype, name="rgt_data") rgt_data = tf.placeholder(dtype, name="rgt_data")
tf.math.minimum(lft_data, rgt_data, name="minimum") tf.math.minimum(lft_data, rgt_data, name="minimum")
compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'minimum:0') compare_tf_with_tvm([lh_data, rh_data], [
'lft_data:0', 'rgt_data:0'], 'minimum:0')
check_minimum((10, 8, 16, 32), (1,), dtype="int32") check_minimum((10, 8, 16, 32), (1,), dtype="int32")
check_minimum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32") check_minimum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32")
...@@ -2339,7 +2579,8 @@ def test_placeholder(): ...@@ -2339,7 +2579,8 @@ def test_placeholder():
var2 = array_ops.placeholder_with_default(var1, None, name='place1') var2 = array_ops.placeholder_with_default(var1, None, name='place1')
in_data2 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32) in_data2 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32)
place1 = array_ops.placeholder(shape=in_data1.shape, dtype=in_data1.dtype, name='in2') place1 = array_ops.placeholder(
shape=in_data1.shape, dtype=in_data1.dtype, name='in2')
out1 = tf.math.add(var1, var2, name='out1') out1 = tf.math.add(var1, var2, name='out1')
out2 = tf.math.add(out1, place1, name='out2') out2 = tf.math.add(out1, place1, name='out2')
...@@ -2350,13 +2591,17 @@ def test_placeholder(): ...@@ -2350,13 +2591,17 @@ def test_placeholder():
####################################################################### #######################################################################
# OneHot # OneHot
# ---------------------- # ----------------------
def _test_forward_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype): def _test_forward_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype):
inp_array1 = np.random.randint(0, 5, size=indices_shape) inp_array1 = np.random.randint(0, 5, size=indices_shape)
with tf.Graph().as_default(): with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype) in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype)
out = tf.one_hot(in1, depth, on_value, off_value, axis, dtype=out_dtype) out = tf.one_hot(in1, depth, on_value, off_value,
axis, dtype=out_dtype)
compare_tf_with_tvm(inp_array1, in1.name, out.name) compare_tf_with_tvm(inp_array1, in1.name, out.name)
def test_forward_one_hot(): def test_forward_one_hot():
_test_forward_one_hot((3,), 3, 1, 0, -1, "int32") _test_forward_one_hot((3,), 3, 1, 0, -1, "int32")
_test_forward_one_hot((3,), 3, 1.0, 0.0, -1, "float32") _test_forward_one_hot((3,), 3, 1.0, 0.0, -1, "float32")
...@@ -2365,6 +2610,40 @@ def test_forward_one_hot(): ...@@ -2365,6 +2610,40 @@ def test_forward_one_hot():
_test_forward_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32") _test_forward_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
_test_forward_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") _test_forward_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
#######################################################################
# AddN
# ----------------------
def _test_forward_add_n(inputs):
tf.reset_default_graph()
with tf.Graph().as_default():
temp = []
for each in inputs:
temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype))
output = tf.add_n(temp)
compare_tf_with_tvm([each for each in inputs], [
each.name for each in temp], output.name)
def test_forward_add_n():
x = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
y = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
z = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
m, n, o = x.astype(np.float32), y.astype(np.float32), z.astype(np.float32)
in0 = x
in1 = [x, y]
in2 = (x, y, z)
in3 = m
in4 = [m, n]
in5 = (m, n, o)
_test_forward_add_n(in0)
_test_forward_add_n(in1)
_test_forward_add_n(in2)
_test_forward_add_n(in3)
_test_forward_add_n(in4)
_test_forward_add_n(in5)
####################################################################### #######################################################################
# Main # Main
...@@ -2433,6 +2712,7 @@ if __name__ == '__main__': ...@@ -2433,6 +2712,7 @@ if __name__ == '__main__':
test_forward_zeros_like() test_forward_zeros_like()
test_forward_erf() test_forward_erf()
test_forward_squared_difference() test_forward_squared_difference()
test_forward_add_n()
# Reductions # Reductions
test_forward_argminmax() test_forward_argminmax()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment