Unverified Commit bb3c8151 by Siva, committed by GitHub

[FRONTEND][TENSORFLOW] Enhance with left over patches from NNVM. (#2757)

* [FRONTEND][TENSORFLOW] Enhance with left over patches from NNVM.

commit 76188a43
Author: Siva <sivar.b@huawei.com>
[NNVM][TENSORFLOW] bugfix. (#2444)

commit 6737739c
Author: Ashutosh Parkhi <ashutosh.parkhi@imgtec.com>
[Tensorflow] Support for Crop (#2285)

commit f6c3f997
Author: Alexey Romanov <alexey.v.romanov@gmail.com>
[FRONTEND][TENSORFLOW] Use input shapes directly instead of 1-element lists (#2242)

commit e5d92e1b
Author: Dominic Symes <36929632+dominicsymes@users.noreply.github.com>
[FRONTEND][TENSORFLOW] Bugfix (#2326)

commit 00d509d4
Author: Alexey Romanov <alexey.v.romanov@gmail.com>
[FRONTEND][TENSORFLOW] Support Unstack and Split (#2105)

commit df9d3ad2
Author: Siva <sivar.b@huawei.com>
[FRONTEND][TENSORFLOW] Bugfix (#2267)

commit d1a0c901
Author: Zhebin Jin <zhebin.jzb@alibaba-inc.com>
[FRONTEND][TENSORFLOW] Add Split and realdiv op support (#2123)
* Add Split and realdiv op support
* Fix the pad calculation in the case of dilated convolution (a sketch of this calculation follows the commit header below)

* review comments

* resnet fix.

* review comments
parent f63631fc
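The bullet about fixing the pad calculation for dilated convolution refers to SAME padding having to account for the dilated (effective) kernel extent. The helper below is only a hedged sketch of that standard calculation, not the frontend's actual code; the function name and example values are illustrative.

def _same_pad_1d(in_size, kernel, stride, dilation):
    """Sketch: SAME padding along one spatial dimension with dilation."""
    effective_kernel = (kernel - 1) * dilation + 1   # dilated kernel extent
    out_size = (in_size + stride - 1) // stride      # ceil(in_size / stride)
    pad_total = max((out_size - 1) * stride + effective_kernel - in_size, 0)
    pad_before = pad_total // 2
    return pad_before, pad_total - pad_before

# e.g. a 3x3 kernel with dilation 2 pads like a 5x5 kernel would:
assert _same_pad_1d(224, 3, 1, 2) == (2, 2)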
@@ -137,7 +137,7 @@ def is_gpu_available():
from tensorflow.python.client import device_lib
local_device_protos = device_lib.list_local_devices()
gpu_list = [x.name for x in local_device_protos if x.device_type == 'GPU']
if len(gpu_list) < 0:
if len(gpu_list) > 0:
print("Tensorflow GPU:", gpu_list)
return True
else:
@@ -168,7 +168,7 @@ def _test_pooling(input_shape, **kwargs):
if is_gpu_available():
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
kwargs['data_layout'] = 'NCHW'
kwargs['data_format'] = 'NCHW'
_test_pooling_iteration(input_shape, **kwargs)
def test_forward_pooling():
@@ -225,8 +225,12 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
strides = [1] + strides + [1]
dilations = [1] + dilations + [1]
if data_format == 'NHWC':
strides = [1] + strides + [1]
dilations = [1] + dilations + [1]
else:
strides = [1, 1] + strides
dilations = [1, 1] + dilations
nn_ops.conv2d(in_data,
in_filter,
@@ -898,7 +902,7 @@ def test_forward_mobilenet():
#######################################################################
# ResnetV2
# ---------
# --------
def test_forward_resnetv2():
'''test resnet model'''
if is_gpu_available():
@@ -912,8 +916,13 @@ def test_forward_resnetv2():
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
for device in ["llvm", "cuda"]:
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
continue
tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device)
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
#######################################################################
# PTB
......
@@ -127,7 +127,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
if no_gpu and device == 'cuda':
continue
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device)
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device,
out_names=out_name, num_output=len(out_name))
# since the output names from the TensorFlow and Relay runs are not exactly the same,
# only the first len(tf_output) outputs are compared
for i in range(len(tf_output)):
@@ -170,7 +171,7 @@ def _test_pooling(input_shape, **kwargs):
if is_gpu_available():
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
kwargs['data_layout'] = 'NCHW'
kwargs['data_format'] = 'NCHW'
_test_pooling_iteration(input_shape, **kwargs)
def test_forward_pooling():
@@ -227,8 +228,12 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
strides = [1] + strides + [1]
dilations = [1] + dilations + [1]
if data_format == 'NHWC':
strides = [1] + strides + [1]
dilations = [1] + dilations + [1]
else:
strides = [1, 1] + strides
dilations = [1, 1] + dilations
nn_ops.conv2d(in_data,
in_filter,
@@ -504,6 +509,84 @@ def test_forward_gather():
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
#######################################################################
# Split
# -----
def _test_split(in_shape, axis, num_or_size_splits, dtype):
""" One iteration of a Split """
np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
tf.reset_default_graph()
in_data = tf.placeholder(dtype, in_shape, name="in_data")
num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits
tf.split(in_data, num_or_size_splits, axis=axis)
compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)])
# and now test together with concat
tf.reset_default_graph()
in_data = tf.placeholder(dtype, in_shape, name="in_data")
splitted = tf.split(in_data, num_or_size_splits, axis=axis)
tf.concat(splitted, axis)
compare_tf_with_tvm([np_data], 'in_data:0', 'concat:0')
def test_forward_split():
'''test split layer'''
# rank 1
_test_split((3,), 0, 1, 'float32')
_test_split((3,), 0, 3, 'float32')
_test_split((6,), 0, 3, 'float32')
# rank 2
_test_split((6, 2), 0, 3, 'float32')
_test_split((2, 6), 1, 6, 'float32')
# rank 3
_test_split((6, 2, 4), 0, 2, 'int32')
_test_split((2, 6, 4), 1, 3, 'float32')
_test_split((2, 4, 6), 2, 1, 'float32')
# rank 4
_test_split((6, 1, 3, 5), 0, 3, 'float32')
_test_split((1, 6, 3, 5), 1, 3, 'float32')
_test_split((1, 3, 6, 5), 2, 3, 'float32')
_test_split((1, 3, 5, 6), 3, 3, 'float32')
# split along negative axis
_test_split((6, 1, 3, 5), -4, 3, 'float32')
_test_split((1, 6, 3, 5), -3, 3, 'float32')
_test_split((1, 3, 6, 5), -2, 3, 'float32')
_test_split((1, 3, 5, 6), -1, 3, 'float32')
# size_splits list
_test_split((6,), 0, [1, 2, 3], 'int32')
_test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
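# Hedged numpy analogue of the size_splits case above (e.g. _test_split((6,), 0, [1, 2, 3], 'int32')):
# tf.split with a size_splits list yields pieces of exactly those lengths, i.e. slicing at the
# cumulative offsets; the names below are illustrative only and rely on the file's numpy import.
np_x = np.arange(6, dtype='int32')
np_pieces = np.split(np_x, np.cumsum([1, 2, 3])[:-1])  # pieces of length 1, 2 and 3
assert [p.size for p in np_pieces] == [1, 2, 3]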
#######################################################################
# Unstack
# -------
def _test_unstack(ip_shape, axis, dtype):
np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.unstack(in_data, axis=axis)
compare_tf_with_tvm([np_data], ['in_data:0'], [f'unstack:{n}' for n in range(ip_shape[axis])])
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.stack(tf.unstack(in_data, axis=axis), axis=axis)
compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')
def test_forward_unstack():
'''test unstack layer'''
_test_unstack((6,), 0, 'int32')
_test_unstack((2,6), 1, 'float64')
# negative axis
_test_unstack((1,4), -1, 'int32')
_test_unstack((3,6,4), -2, 'float32')
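# Hedged numpy sketch of the round-trip checked above (e.g. _test_unstack((2, 6), 1, 'float64')):
# unstacking along axis=1 yields 6 arrays of shape (2,), and stacking them back on the same axis
# reproduces the input; variable names are illustrative only.
np_x = np.random.uniform(-5, 5, size=(2, 6))
np_pieces = [np.take(np_x, i, axis=1) for i in range(np_x.shape[1])]  # analogue of tf.unstack
np.testing.assert_array_equal(np.stack(np_pieces, axis=1), np_x)  # analogue of the tf.stack round-trip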
#######################################################################
# Multi Input to graph
@@ -576,6 +659,22 @@ def test_forward_resize_bilinear():
_test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
_test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
#######################################################################
# Crop to bounding box
# --------------------
def _test_crop(in_shape, off_h, off_w, tar_h, tar_w):
""" Crop to bounding box """
data = np.random.uniform(size=in_shape).astype('float32')
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
tf.image.crop_to_bounding_box(in_data, off_h, off_w, tar_h, tar_w)
compare_tf_with_tvm(data, 'Placeholder:0', 'crop_to_bounding_box/Slice:0')
def test_forward_crop():
""" Crop to bounding box """
_test_crop((1, 224, 224, 3), 20, 20, 120, 120)
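# crop_to_bounding_box on an NHWC tensor is just a slice, which is why the test compares against
# the 'crop_to_bounding_box/Slice:0' output. A hedged numpy analogue of the call above:
np_img = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
np_crop = np_img[:, 20:20 + 120, 20:20 + 120, :]
assert np_crop.shape == (1, 120, 120, 3)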
#######################################################################
# LSTM
@@ -804,7 +903,7 @@ def test_forward_mobilenet():
#######################################################################
# ResnetV2
# ---------
# --------
def test_forward_resnetv2():
'''test resnet model'''
if is_gpu_available():
@@ -818,8 +917,13 @@ def test_forward_resnetv2():
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
for device in ["llvm", "cuda"]:
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
continue
tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device)
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
#######################################################################
# PTB
@@ -1106,9 +1210,12 @@ if __name__ == '__main__':
test_forward_squeeze()
test_forward_pack()
test_forward_resize_bilinear()
test_forward_crop()
test_forward_pad()
test_forward_gather()
test_forward_stridedslice()
test_forward_split()
test_forward_unstack()
# Activations
test_forward_sigmoid()
......