Unverified commit bb3c8151 by Siva, committed by GitHub

[FRONTEND][TENSORFLOW] Enhance with left over patches from NNVM. (#2757)

* [FRONTEND][TENSORFLOW] Enhance with left over patches from NNVM.

commit 76188a43
Author: Siva <sivar.b@huawei.com>
[NNVM][TENSORFLOW] bugfix. (#2444)

commit 6737739c
Author: Ashutosh Parkhi <ashutosh.parkhi@imgtec.com>
[Tensorflow] Support for Crop (#2285)

commit f6c3f997
Author: Alexey Romanov <alexey.v.romanov@gmail.com>
[FRONTEND][TENSORFLOW] Use input shapes directly instead of 1-element lists (#2242)

commit e5d92e1b
Author: Dominic Symes <36929632+dominicsymes@users.noreply.github.com>
[FRONTEND][TENSORFLOW] Bugfix (#2326)

commit 00d509d4
Author: Alexey Romanov <alexey.v.romanov@gmail.com>
[FRONTEND][TENSORFLOW] Support Unstack and Split (#2105)

commit df9d3ad2
Author: Siva <sivar.b@huawei.com>
[FRONTEND][TENSORFLOW] Bugfix (#2267)

commit d1a0c901
Author: Zhebin Jin <zhebin.jzb@alibaba-inc.com>
[FRONTEND][TENSORFLOW]Add Split and realdiv op support (#2123)
* Add Split and realdiv op support
* Fix the pad calculation in the case of dilated convolution

* review comments

* resnet fix.

* review comments
parent f63631fc
@@ -137,7 +137,7 @@ def is_gpu_available():
     from tensorflow.python.client import device_lib
     local_device_protos = device_lib.list_local_devices()
     gpu_list = [x.name for x in local_device_protos if x.device_type == 'GPU']
-    if len(gpu_list) < 0:
+    if len(gpu_list) > 0:
         print("Tensorflow GPU:", gpu_list)
         return True
     else:
@@ -168,7 +168,7 @@ def _test_pooling(input_shape, **kwargs):
     if is_gpu_available():
         input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
-        kwargs['data_layout'] = 'NCHW'
+        kwargs['data_format'] = 'NCHW'
     _test_pooling_iteration(input_shape, **kwargs)
 def test_forward_pooling():
@@ -225,8 +225,12 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
         in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
-        strides = [1] + strides + [1]
-        dilations = [1] + dilations + [1]
+        if data_format == 'NHWC':
+            strides = [1] + strides + [1]
+            dilations = [1] + dilations + [1]
+        else:
+            strides = [1, 1] + strides
+            dilations = [1, 1] + dilations
         nn_ops.conv2d(in_data,
                       in_filter,
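(For orientation, not part of the patch: a minimal sketch of how the 4-element stride and dilation vectors line up with the two layouts accepted by tf.nn.conv2d; the variable names here are illustrative only.)

# strides/dilations follow the data layout of tf.nn.conv2d:
#   NHWC -> [batch, height, width, channel]
#   NCHW -> [batch, channel, height, width]
spatial_strides = [2, 2]                      # stride along H and W
strides_nhwc = [1] + spatial_strides + [1]    # [1, 2, 2, 1]
strides_nchw = [1, 1] + spatial_strides       # [1, 1, 2, 2]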
@@ -898,7 +902,7 @@ def test_forward_mobilenet():
 #######################################################################
 # ResnetV2
-# ---------
+# --------
 def test_forward_resnetv2():
     '''test resnet model'''
     if is_gpu_available():
@@ -912,7 +916,12 @@ def test_forward_resnetv2():
         with tf.Session() as sess:
             tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0')
-            tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32')
+            for device in ["llvm", "cuda"]:
+                ctx = tvm.context(device, 0)
+                if not ctx.exist:
+                    print("Skip because %s is not enabled" % device)
+                    continue
+                tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device)
-            tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
+                tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
 #######################################################################
...
@@ -81,6 +81,7 @@ class AttrCvt(object):
         self._ignores.append('_node_name')
         self._ignores.append('is_training')
         self._ignores.append('_target_layout')
+        self._ignores.append('_input_0d_mismatch')
         # apply custom check
         if self._custom_check:
@@ -227,7 +228,7 @@ def _pooling(name):
         attr['data_format'] = attr['data_format'].decode("utf-8")
         flip_layout = False
-        input_shape = attr['_input_shapes'][inputs[0]][0]
+        input_shape = attr['_input_shapes'][inputs[0]]
         if attr['data_format'] == 'NHWC':
             attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
@@ -239,7 +240,7 @@ def _pooling(name):
             raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
-            tmp_shape = attr['_input_shapes'][inputs[0]][0]
+            tmp_shape = attr['_input_shapes'][inputs[0]]
             input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
             inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2))
             attr['data_format'] = "NCHW"
@@ -292,13 +293,13 @@ def _conv(opname):
         # NCHW Layout require weights transpose
         if attr['data_format'] == 'NCHW':
-            tmp_shape = attr['_input_shapes'][inputs[1]][0]
+            tmp_shape = attr['_input_shapes'][inputs[1]]
             tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
             inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
-            attr['_input_shapes'][inputs[1]] = [tmp_shape]
-        input_shape = attr['_input_shapes'][inputs[0]][0]
-        weights_shape = attr['_input_shapes'][inputs[1]][0]
+            attr['_input_shapes'][inputs[1]] = tmp_shape
+        input_shape = attr['_input_shapes'][inputs[0]]
+        weights_shape = attr['_input_shapes'][inputs[1]]
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
             input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
@@ -323,7 +324,7 @@ def _conv(opname):
             attr['channels'] = input_shape[3] * depth_mult
             if 'dilations' in attr:
-                attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
+                attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
             attr['strides'] = (attr['strides'][1], attr['strides'][2])
         elif attr['data_format'] == 'NCHW':
             depth_mult, _, kernel_h, kernel_w = weights_shape
@@ -360,8 +361,13 @@ def _conv(opname):
                 in_h = input_shape[2]
                 in_w = input_shape[3]
-            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
-            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
+            dilation_h = attr['dilations'][0]
+            dilation_w = attr['dilations'][1]
+            dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+            dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+            pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
+            pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
             if attr['data_format'] == 'NHWC':
                 inputs[0] = _op.nn.pad(data=inputs[0],
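(A note on the change above, not part of the diff: with dilation, the kernel effectively spans (kernel - 1) * dilation + 1 input positions, and SAME padding has to be computed against that span rather than the raw kernel size. A rough sketch of the rule, assuming `_get_pad_pair` follows TensorFlow's SAME-padding formula; the helper's body is not shown in this hunk.)

def get_pad_pair(in_size, dilated_kernel, stride):
    # TensorFlow 'SAME' padding: output size is ceil(in_size / stride)
    if in_size % stride == 0:
        pad = max(dilated_kernel - stride, 0)
    else:
        pad = max(dilated_kernel - (in_size % stride), 0)
    return [pad // 2, pad - pad // 2]

# e.g. in_size=32, kernel=3, dilation=2 -> dilated kernel 5, stride 1 -> pads [2, 2]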
@@ -425,8 +431,7 @@ def _expand_dims():
         dim_input = inputs.pop(1)
         axis = params[dim_input.name_hint]
         params.pop(dim_input.name_hint)
-        return AttrCvt(op_name="expand_dims", ignores=['Tdim'],
-                       extras={'axis': int(axis.asnumpy()[0])})(inputs, attr)
+        return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
     return _impl
 def _resize_bilinear():
@@ -461,6 +466,11 @@ def _matmul():
     return _impl
+
+def _undef():
+    def _impl(inputs, attr, params):
+        return _sym.__undef__()
+    return _impl
 def _identity():
     def _impl(inputs, attr, params):
         return inputs[0]
@@ -489,10 +499,26 @@ def _concat():
 def _pack():
     def _impl(inputs, attr, params):
         axis = int(attr["axis"])
-        inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
+        inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
         return _op.concatenate(inputs_reshaped, axis)
     return _impl
+
+def _slice():
+    def _impl(inputs, attr, params):
+        begin = params.pop(_get_name_hint(inputs[1])).asnumpy().tolist()
+        size = params.pop(_get_name_hint(inputs[2])).asnumpy().tolist()
+        data_shape = attr['_input_shapes'][inputs[0]]
+        data_dim = len(data_shape)
+        end = size
+        for i in range(data_dim):
+            if size[i] == -1:
+                end[i] = data_shape[i] - begin[i]
+            else:
+                end[i] += begin[i]
+        return _op.strided_slice(inputs[0], begin=begin, end=size)
+    return _impl
+
 def _reshape():
     def _impl(inputs, attr, params):
         try:
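(An aside on the new Slice converter, not part of the diff: TensorFlow's Slice takes a (begin, size) pair, while strided_slice wants (begin, end), with size == -1 meaning "everything from begin to the end of that dimension". A small worked example with illustrative values:)

data_shape = [2, 5, 8]
begin = [0, 1, 0]
size = [-1, 3, -1]
end = [d - b if s == -1 else b + s
       for d, b, s in zip(data_shape, begin, size)]
# end == [2, 4, 8] -> strided_slice(data, begin=[0, 1, 0], end=[2, 4, 8])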
@@ -596,7 +622,7 @@ def _relu6():
 def _shape():
     def _impl(inputs, attr, params):
-        return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32')
+        return np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
     return _impl
 def _fill():
@@ -671,7 +697,7 @@ def _stridedSlice():
         new_axis_mask = int(attr.get('new_axis_mask', 0))
         shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
         data_shape = attr['_input_shapes'][inputs[0]]
-        data_dim = len(data_shape[0])
+        data_dim = len(data_shape)
         stride_dim = len(stride)
         def _transform_mask(stride_dim, ellipsis_mask):
@@ -702,7 +728,7 @@ def _stridedSlice():
                                     + new_axes_after_ellipsis), data_dim)
                     for i in range(final_index, to_index):
                         m_begin[final_index] = 0
-                        m_end[final_index] = data_shape[0][final_index]
+                        m_end[final_index] = data_shape[final_index]
                         m_stride[final_index] = 1
                         fshape_indices.append(final_index)
                         final_index += 1
@@ -712,19 +738,19 @@ def _stridedSlice():
                     if final_index == len(m_begin):
                         break
                     if mask & begin_mask:
-                        m_begin[final_index] = data_shape[0][final_index] \
+                        m_begin[final_index] = data_shape[final_index] \
                             if stride[index] < 0 else 0
                     elif begin[index]:
                         m_begin[final_index] = begin[index]
                     if mask & end_mask:
                         m_end[final_index] = 0 if stride[index] < 0 \
-                            else data_shape[0][final_index]
+                            else data_shape[final_index]
                     elif end[index]:
                         m_end[final_index] = end[index]
                     m_stride[final_index] = stride[index]
                     if mask & shrink_axis_mask:
                         #Tensorflow make axis with shrink_axis_mask as dimension 1
-                        m_begin[final_index] = data_shape[0][final_index] + begin[index] \
+                        m_begin[final_index] = data_shape[final_index] + begin[index] \
                             if begin[index] < 0 else begin[index]
                         m_end[final_index] = begin[index] + 1
                         m_stride[final_index] = 1
@@ -752,6 +778,9 @@ def _stridedSlice():
                     pass
                 else:
                     final_output.append(out_shape[gather_index])
+        # Prevent 0-dim tensors which are not accepted by Relay
+        if not final_output:
+            final_output.append(1)
         return _op.reshape(out, newshape=tuple(final_output))
     return _impl
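(Illustration only, not part of the diff: the guard above covers slices where every output dimension is squeezed away, e.g. taking a single element of a 1-D tensor with shrink_axis_mask. TensorFlow returns a 0-dim tensor there, while this converter keeps the result at rank 1, roughly like the NumPy sketch below.)

import numpy as np
x = np.arange(4)                      # shape (4,)
scalar = x[1]                         # strided slice with shrink_axis_mask -> 0-dim in TF
wrapped = np.reshape(scalar, (1,))    # what the converter emits instead, shape (1,)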
@@ -789,11 +818,10 @@ def _transpose():
 def _rank():
     def _impl(inputs, attr, params):
-        input_shapes = attr['_input_shapes'][inputs[0]]
-        assert len(inputs) == 1
+        input_shape = attr['_input_shapes'][inputs[0]]
         name = attr["_node_name"]
-        params[name] = tvm.nd.array([len(input_shapes[0])])
+        params[name] = tvm.nd.array([len(input_shape)])
         return [_expr.var(name,
                           shape=params[name].shape,
                           dtype='int32')]
@@ -844,6 +872,72 @@ def _broadcast(name):
         )(inputs, attr)
     return _impl
+
+def _split(has_size_vector):
+    # TF documentation https://www.tensorflow.org/api_docs/python/tf/split
+    def _impl(inputs, attr, params):
+        try:
+            # order and number of inputs are different:
+            # if has_size_vector:
+            #     https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v
+            # else:
+            #     https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split
+            # in addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow,
+            # we can only support constants
+            if has_size_vector:
+                input_node_index = 0
+                input_axis_index = 2
+                size_splits_input_name = _get_name_hint(inputs[1])
+                size_splits = params[size_splits_input_name].asnumpy()
+                section_beginnings = np.cumsum(size_splits)[:-1]
+                indices_or_sections = tuple(section_beginnings)
+            else:
+                input_node_index = 1
+                input_axis_index = 0
+                indices_or_sections = attr['num_split']
+            input_node = inputs[input_node_index]
+            axis_input_name = _get_name_hint(inputs[input_axis_index])
+            axis_input_value = params[axis_input_name].asnumpy()[0]
+        except (IndexError, KeyError):
+            raise TypeError( \
+                "Unsupported argument for split: `axis` and `num_or_size_splits` " \
+                "should be constants")
+        return _op.split(input_node,
+                         indices_or_sections=indices_or_sections,
+                         axis=int(axis_input_value))
+    return _impl
+
+def _unpack():
+    def _impl(inputs, attr, params):
+        input_node = inputs[0]
+        axis = attr['axis']
+        input_shape = attr['_input_shapes'][input_node]
+        axis_length = input_shape[axis]
+        if axis_length < 0:
+            raise TypeError("Unstack with unknown axis length")
+        splitted = _op.split(input_node,
+                             indices_or_sections=axis_length,
+                             axis=axis)
+                             #name=attr.get('_node_name', 'unstack'))
+        if axis == 0:
+            axis = None
+        else:
+            axis = [axis]
+        return _expr.TupleWrapper(
+            _expr.Tuple([_op.squeeze(split_item, axis=axis) \
+                for split_item in splitted]), len(splitted))
+    return _impl
+
+def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
+    if data in attr['_input_0d_mismatch']:
+        return data if num_newaxis == 1 else \
+            AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
+                    extras={'axis': int(axis), 'num_newaxis': int(num_newaxis-1)})([data], attr)
+    return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
+                   extras={'axis': int(axis), 'num_newaxis': int(num_newaxis)})([data], attr)
 def _softmax():
     def _impl(inputs, attr, params):
         return AttrCvt(op_name='softmax',
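(Illustration only, not from the diff: for SplitV the `size_splits` vector is turned into the section boundaries that _op.split expects, via a cumulative sum. With illustrative values:)

import numpy as np
size_splits = np.array([1, 4, 1])                  # e.g. tf.split(value, [1, 4, 1], axis=-2)
section_beginnings = np.cumsum(size_splits)[:-1]   # array([1, 5])
indices_or_sections = tuple(section_beginnings)    # (1, 5): cut before indices 1 and 5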
@@ -885,11 +979,13 @@ _convert_map = {
     'Add'                               : _elemwise('add'),
     'Sub'                               : _elemwise('subtract'),
     'Mul'                               : _elemwise('multiply'),
+    'RealDiv'                           : _elemwise('div'),
     'Maximum'                           : _elemwise('maximum'),
     'Minimum'                           : _elemwise('minimum'),
     'Sum'                               : _sum(),
     'Square'                            : _square(),
     'Pack'                              : _pack(),
+    'Slice'                             : _slice(),
     'LeakyRelu'                         : AttrCvt('leaky_relu'),
     'Relu'                              : AttrCvt('relu'),
     'Reshape'                           : _reshape(),
@@ -924,6 +1020,9 @@ _convert_map = {
     'GreaterEqual'                      : _broadcast('greater_equal'),
     'Equal'                             : _broadcast('equal'),
     'NotEqual'                          : _broadcast('not_equal'),
+    'Split'                             : _split(False),
+    'SplitV'                            : _split(True),
+    'Unpack'                            : _unpack(),
 }
 def _LSTMBlockCell():
@@ -958,8 +1057,8 @@ def _LSTMBlockCell():
         forget_bias = attr.pop('forget_bias')
         input_shape = attr['_input_shapes'][inputs[0]]
         weight_shape = attr['_input_shapes'][inputs[3]]
-        batch_size, input_size = input_shape[0][0], input_shape[0][1]
-        num_hidden_layers = weight_shape[0][1]
+        batch_size, input_size = input_shape[0], input_shape[1]
+        num_hidden_layers = weight_shape[1]
         num_hidden = num_hidden_layers // 4
         in_data = _op.reshape(in_data,
@@ -1087,8 +1186,8 @@ class RecurrentNetworks(object):
             input_shape = attr['_input_shapes'][inputs[0]]
             weight_shape = attr['_input_shapes'][inputs[3]]
-            batch_size = input_shape[0][0]
-            num_hidden = weight_shape[0][1] // 4
+            batch_size = input_shape[0]
+            num_hidden = weight_shape[1] // 4
             if layer == 0:
                 #Create initial states placeholder in case of first layer
@@ -1183,6 +1282,8 @@ class GraphProto(object):
         self._output_shapes = {}
         self._num_param = 0
         self._num_rnn_layer = False
+        self._outputs_are_0d = {}
+        self._input_shapes = {}
     def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         """Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -1259,6 +1360,7 @@ class GraphProto(object):
             # Operator name 'Const' is treated as a parameter to build params dict.
             input_shapes = {}
+            input_0d_mismatch = set()
             attr = self._parse_attr(node.attr)
             # Variable converted to Const will not have only value attr
@@ -1267,6 +1369,8 @@ class GraphProto(object):
             elif shape and node.name in shape:
                 # Give priority to user argument.
                 self._output_shapes[node.name] = [shape[node.name]]
+            elif node.op == 'Placeholder':
+                self._output_shapes[node.name] = [self._input_shapes[node.name]]
             elif '_output_shapes' in attr:
                 self._output_shapes[node.name] = \
                     [tensor_util.TensorShapeProtoToList(tshape) \
@@ -1274,8 +1378,13 @@ class GraphProto(object):
             else:
                 # Keep the list indexable to avoid key error.
                 # Actual value will be filled after node creation.
+                # Will infer shapes if the graph is not frozen with add_shapes=True
                 self._output_shapes[node.name] = [None]
+            self._outputs_are_0d[node.name] = [ \
+                not tshape if isinstance(tshape, list) else False \
+                for tshape in self._output_shapes[node.name]]
             if node.op == "Placeholder":
                 self._output_shapes[node.name] = [self._input_shapes[node.name]]
                 self._nodes[node.name] = [_expr.var(node.name,
@@ -1315,10 +1424,33 @@ class GraphProto(object):
             # Fill shapes for all inputs in a list
             inputs = []
             for i in node.input:
-                if i in self._nodes:
-                    inputs.append(self._nodes[i][0])
-                    input_shapes[self._nodes[i][0]] = self._output_shapes[i]
+                # Some TensorFlow operators internally maintain execution layers
+                # and their output name includes the layer number along with
+                # graph node name. E.g. the node name is 'Model/RNN/cell_0/RnnCell', but the
+                # output tensor name is 'Model/RNN/cell_0/RnnCell:0'. In this case,
+                # the number has to be ignored for single-output nodes.
+                # On the other hand, for multi-output nodes the number is the output index,
+                # and the lack of the number implies 0.
+                tensor_name = i.split(':')
+                node_name = tensor_name[0]
+                if node_name in self._nodes:
+                    in_sym = self._nodes[node_name]
+                    if isinstance(in_sym, _expr.TupleWrapper):
+                        tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
+                        in_sym = [in_sym[tensor_slot]]
+                        input_shape = self._output_shapes[node_name][tensor_slot]
+                    else:
+                        tensor_slot = 0
+                        input_shape = self._output_shapes[node_name][0]
+                    inputs.append(in_sym[0])
+                    input_shapes[in_sym[0]] = input_shape
+                    # This means the node is 1d in Relay and 0d in TF.
+                    # See `_expand_dims_0d_aware`.
+                    if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
+                        input_0d_mismatch.add(in_sym[0])
             attr['_input_shapes'] = input_shapes
+            attr['_input_0d_mismatch'] = input_0d_mismatch
             op = self._convert_operator(node.op, inputs, attr, graph)
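(Illustration only, not from the diff: the name handling above splits an optional output index off the tensor name, defaulting to slot 0 when no index is present.)

for name in ['Model/RNN/cell_0/RnnCell', 'Model/RNN/cell_0/RnnCell:0', 'split:2']:
    parts = name.split(':')
    node_name = parts[0]
    output_slot = int(parts[1]) if len(parts) > 1 else 0
    print(node_name, output_slot)
# Model/RNN/cell_0/RnnCell 0
# Model/RNN/cell_0/RnnCell 0
# split 2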
@@ -1340,23 +1472,36 @@ class GraphProto(object):
                 # Infer shapes even without specifying "add_shapes=True"
                 if output_shapes == [None]:
-                    out_type = ir_pass.infer_type(self._nodes[node.name][0])
-                    self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]
+                    out_shapes = []
+                    for node_item in self._nodes[node.name]:
+                        out_type = ir_pass.infer_type(node_item)
+                        out_shapes.append(get_const_tuple(out_type.checked_type.shape))
+                    self._output_shapes[node.name] = out_shapes
                 if self._output_shapes[node.name] and shape and node.name in shape:
                     assert self._output_shapes[node.name] == list(shape[node.name])
             # Infer shapes if passed explicitely
             node_output = self._nodes[node.name]
-            out_type = ir_pass.infer_type(node_output[0])
-            self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]
+            if shape and (not self._output_shapes[node.name][0]
+                          or -1 in self._output_shapes[node.name][0]):
+                out_shapes = []
+                for node_item in node_output:
+                    out_type = ir_pass.infer_type(node_item)
+                    out_shapes.append(get_const_tuple(out_type.checked_type.shape))
+                self._output_shapes[node.name] = out_shapes
         out = []
         if outputs is None:
             out = op
         else:
-            out = [self._nodes[out_name][0] for out_name in outputs]
+            for out_name in outputs:
+                if ":" in out_name:
+                    out_name, out_num = out_name.split(":")
+                    out_num = int(out_num)
+                    out.append(self._nodes[out_name][out_num])
+                else:
+                    out.append(self._nodes[out_name][0])
         #Add the RNN outputs also with 'head' nodes of the relay graph
         if self._num_rnn_layer:
...
@@ -127,7 +127,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
         if no_gpu and device == 'cuda':
             continue
-        tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device)
+        tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device,
+                                   out_names=out_name, num_output=len(out_name))
         # since the names from tensorflow and relay runs are not exactly same,
         # first len(tf_output) will be compared
         for i in range(len(tf_output)):
@@ -170,7 +171,7 @@ def _test_pooling(input_shape, **kwargs):
     if is_gpu_available():
         input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
-        kwargs['data_layout'] = 'NCHW'
+        kwargs['data_format'] = 'NCHW'
     _test_pooling_iteration(input_shape, **kwargs)
 def test_forward_pooling():
@@ -227,8 +228,12 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
         in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
-        strides = [1] + strides + [1]
-        dilations = [1] + dilations + [1]
+        if data_format == 'NHWC':
+            strides = [1] + strides + [1]
+            dilations = [1] + dilations + [1]
+        else:
+            strides = [1, 1] + strides
+            dilations = [1, 1] + dilations
         nn_ops.conv2d(in_data,
                       in_filter,
@@ -504,6 +509,84 @@ def test_forward_gather():
     _test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
     _test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
+
+#######################################################################
+# Split
+# -----
+
+def _test_split(in_shape, axis, num_or_size_splits, dtype):
+    np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
+
+    """ One iteration of a Split """
+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, in_shape, name="in_data")
+    num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits
+    tf.split(in_data, num_or_size_splits, axis=axis)
+
+    compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)])
+
+    # and now test together with concat
+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, in_shape, name="in_data")
+    splitted = tf.split(in_data, num_or_size_splits, axis=axis)
+    tf.concat(splitted, axis)
+
+    compare_tf_with_tvm([np_data], 'in_data:0', 'concat:0')
+
+def test_forward_split():
+    '''test split layer'''
+    # rank 1
+    _test_split((3,), 0, 1, 'float32')
+    _test_split((3,), 0, 3, 'float32')
+    _test_split((6,), 0, 3, 'float32')
+    # rank 2
+    _test_split((6, 2), 0, 3, 'float32')
+    _test_split((2, 6), 1, 6, 'float32')
+    # rank 3
+    _test_split((6, 2, 4), 0, 2, 'int32')
+    _test_split((2, 6, 4), 1, 3, 'float32')
+    _test_split((2, 4, 6), 2, 1, 'float32')
+    # rank 4
+    _test_split((6, 1, 3, 5), 0, 3, 'float32')
+    _test_split((1, 6, 3, 5), 1, 3, 'float32')
+    _test_split((1, 3, 6, 5), 2, 3, 'float32')
+    _test_split((1, 3, 5, 6), 3, 3, 'float32')
+    # split along negative axis
+    _test_split((6, 1, 3, 5), -4, 3, 'float32')
+    _test_split((1, 6, 3, 5), -3, 3, 'float32')
+    _test_split((1, 3, 6, 5), -2, 3, 'float32')
+    _test_split((1, 3, 5, 6), -1, 3, 'float32')
+    # size_splits list
+    _test_split((6,), 0, [1, 2, 3], 'int32')
+    _test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
+
+#######################################################################
+# Unstack
+# -------
+
+def _test_unstack(ip_shape, axis, dtype):
+    np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
+
+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    tf.unstack(in_data, axis=axis)
+
+    compare_tf_with_tvm([np_data], ['in_data:0'], [f'unstack:{n}' for n in range(ip_shape[axis])])
+
+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    tf.stack(tf.unstack(in_data, axis=axis), axis=axis)
+
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')
+
+def test_forward_unstack():
+    '''test unstack layer'''
+    _test_unstack((6,), 0, 'int32')
+    _test_unstack((2,6), 1, 'float64')
+    # negative axis
+    _test_unstack((1,4), -1, 'int32')
+    _test_unstack((3,6,4), -2, 'float32')
+
 #######################################################################
 # Multi Input to graph
@@ -576,6 +659,22 @@ def test_forward_resize_bilinear():
     _test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
     _test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
+
+#######################################################################
+# Crop to bounding box
+# --------------------
+
+def _test_crop(in_shape, off_h, off_w, tar_h, tar_w):
+    """ Crop to bounding box """
+    data = np.random.uniform(size=in_shape).astype('float32')
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        tf.image.crop_to_bounding_box(in_data, off_h, off_w, tar_h, tar_w)
+        compare_tf_with_tvm(data, 'Placeholder:0', 'crop_to_bounding_box/Slice:0')
+
+def test_forward_crop():
+    """ Crop to bounding box """
+    _test_crop((1, 224, 224, 3), 20, 20, 120, 120)
+
 #######################################################################
 # LSTM
@@ -804,7 +903,7 @@ def test_forward_mobilenet():
 #######################################################################
 # ResnetV2
-# ---------
+# --------
 def test_forward_resnetv2():
     '''test resnet model'''
     if is_gpu_available():
@@ -818,7 +917,12 @@ def test_forward_resnetv2():
         with tf.Session() as sess:
             tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0')
-            tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32')
+            for device in ["llvm", "cuda"]:
+                ctx = tvm.context(device, 0)
+                if not ctx.exist:
+                    print("Skip because %s is not enabled" % device)
+                    continue
+                tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device)
-            tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
+                tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
 #######################################################################
@@ -1106,9 +1210,12 @@ if __name__ == '__main__':
     test_forward_squeeze()
     test_forward_pack()
     test_forward_resize_bilinear()
+    test_forward_crop()
     test_forward_pad()
     test_forward_gather()
     test_forward_stridedslice()
+    test_forward_split()
+    test_forward_unstack()
     # Activations
     test_forward_sigmoid()
...