Commit b63267b9 by Zhao Wu Committed by Siva

[TFLite] Convert TFLite NCHW to NHWC (#3141)

* Convert TFLite NCHW to NHWC

* Minor comment fix
parent 4b1d3d87
...@@ -209,44 +209,10 @@ class OperatorConverter(object): ...@@ -209,44 +209,10 @@ class OperatorConverter(object):
reshape_options = ReshapeOptions() reshape_options = ReshapeOptions()
reshape_options.Init(op_options.Bytes, op_options.Pos) reshape_options.Init(op_options.Bytes, op_options.Pos)
target_shape = reshape_options.NewShapeAsNumpy() target_shape = reshape_options.NewShapeAsNumpy()
input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())
in_expr = self.get_expr(input_tensor_idx) in_expr = self.get_expr(input_tensor_idx)
if input_shape_length in (1, 2):
# The rule is channel first (after N but before H, W).
# length of 1 means N*H*W*C, do nothing.
# length of 2 means N*H*W, C, do nothing.
pass
elif input_shape_length == 3:
# convert N C H*W to N H*W C
in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
elif input_shape_length == 4:
# convert input to N H W C, then reshape to target shape,
# finally convert back if necessary
in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
else:
msg = 'Input shape length {} for operator Reshape is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
out = _op.reshape(in_expr, newshape=tuple(target_shape)) out = _op.reshape(in_expr, newshape=tuple(target_shape))
# The rule is channel first.
# 1: N*H*W*C
# 2: N*H*W, C
# 3: N H W C, reshape to N H*W C, transpose to N C H*W
# 4: N H W C, transpose to N C H W
# add more if we need target shapes in future
if len(target_shape) == 1 or len(target_shape) == 2:
pass
elif len(target_shape) == 3:
out = _op.transpose(out, axes=(0, 2, 1))
elif len(target_shape) == 4:
out = _op.transpose(out, axes=(0, 3, 1, 2))
else:
raise tvm.error.OpAttributeInvalid(
'Length of target shape must be between 1 and 5 for operator Reshape.')
return out return out
def convert_softmax(self, op): def convert_softmax(self, op):
...@@ -269,7 +235,7 @@ class OperatorConverter(object): ...@@ -269,7 +235,7 @@ class OperatorConverter(object):
return out return out
def convert_concatenation(self, op): def convert_concatenation(self, op):
""" convert TFLite concatenation""" """Convert TFLite concatenation"""
try: try:
from tflite.Operator import Operator from tflite.Operator import Operator
from tflite.ConcatenationOptions import ConcatenationOptions from tflite.ConcatenationOptions import ConcatenationOptions
...@@ -292,15 +258,6 @@ class OperatorConverter(object): ...@@ -292,15 +258,6 @@ class OperatorConverter(object):
concatenation_options.Init(op_options.Bytes, op_options.Pos) concatenation_options.Init(op_options.Bytes, op_options.Pos)
concatenation_axis = concatenation_options.Axis() concatenation_axis = concatenation_options.Axis()
fused_activation_fn = concatenation_options.FusedActivationFunction() fused_activation_fn = concatenation_options.FusedActivationFunction()
input_shape_length = len(input_tensors[0].tensor.ShapeAsNumpy())
# TFLite is N H W C, our layout is N C H W
if input_shape_length <= 4:
axis_convert_map = [0] + list(range(2, input_shape_length)) + [1]
concatenation_axis = axis_convert_map[concatenation_axis]
else:
raise NotImplementedError("Not support input shape length {} of concatenatio : "
.format(str(input_shape_length)))
# with axis in N H W C # with axis in N H W C
out = _op.concatenate(in_exprs, axis=concatenation_axis) out = _op.concatenate(in_exprs, axis=concatenation_axis)
...@@ -336,20 +293,6 @@ class OperatorConverter(object): ...@@ -336,20 +293,6 @@ class OperatorConverter(object):
rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
dtype=rhs_type_str) dtype=rhs_type_str)
# In this case, we have to be careful about formatting.
input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy())
if input_shape_length in (1, 2):
pass
elif input_shape_length == 3:
# N H*W C to N C H*W
rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1))
elif input_shape_length == 4:
# N H W C to N C H W
rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2))
else:
msg = 'Input shape length {} for operator ADD is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
out = _op.add(lhs_expr, rhs_expr) out = _op.add(lhs_expr, rhs_expr)
return out return out
...@@ -440,46 +383,10 @@ class OperatorConverter(object): ...@@ -440,46 +383,10 @@ class OperatorConverter(object):
squeeze_options = SqueezeOptions() squeeze_options = SqueezeOptions()
squeeze_options.Init(op_options.Bytes, op_options.Pos) squeeze_options.Init(op_options.Bytes, op_options.Pos)
squeeze_axis = squeeze_options.SqueezeDimsAsNumpy() squeeze_axis = squeeze_options.SqueezeDimsAsNumpy()
input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())
output_shape_length = len(output_tensors[0].tensor.ShapeAsNumpy())
in_expr = self.get_expr(input_tensor_idx) in_expr = self.get_expr(input_tensor_idx)
# TFLite is N H W C, our layout is N C H W
if input_shape_length in (1, 2):
# The rule is channel first (after N but before H, W).
# length of 1 means N*H*W*C, do nothing.
# length of 2 means N*H*W, C, do nothing.
pass
elif input_shape_length == 3:
# convert N C H*W to N H*W C
in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
elif input_shape_length == 4:
# convert input to N H W C, then reshape to target shape,
# finally convert back if necessary
in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
else:
msg = 'Input shape length {} for operator Squeeze is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
out = _op.squeeze(in_expr, axis=tuple(squeeze_axis)) out = _op.squeeze(in_expr, axis=tuple(squeeze_axis))
# The rule is channel first.
# 1: N*H*W*C
# 2: N*H*W, C
# 3: N H W C, reshape to N H*W C, transpose to N C H*W
# 4: N H W C, transpose to N C H W
# add more if we need target shapes in future
if output_shape_length in (1, 2):
pass
elif output_shape_length == 3:
out = _op.transpose(out, axes=(0, 2, 1))
elif output_shape_length == 4:
out = _op.transpose(out, axes=(0, 3, 1, 2))
else:
msg = 'Output shape length {} for operator Squeeze is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(output_shape_length))
return out return out
def convert_fused_activation_function(self, in_expr, fused_activation_fn): def convert_fused_activation_function(self, in_expr, fused_activation_fn):
...@@ -562,13 +469,16 @@ class OperatorConverter(object): ...@@ -562,13 +469,16 @@ class OperatorConverter(object):
params = {'kernel_size': [kernel_h, kernel_w], params = {'kernel_size': [kernel_h, kernel_w],
'strides': [stride_h, stride_w], 'strides': [stride_h, stride_w],
'dilation': [dilation_h, dilation_w], 'dilation': [dilation_h, dilation_w],
'padding': [0, 0]} 'padding': [0, 0],
'data_layout': 'NHWC'}
if is_depthwise_conv: if is_depthwise_conv:
params['channels'] = int(in_channels * multiplier) params['channels'] = int(in_channels * multiplier)
params['groups'] = int(in_channels) params['groups'] = int(in_channels)
params['kernel_layout'] = 'HWOI'
else: else:
params['channels'] = int(output_channels) params['channels'] = int(output_channels)
params['kernel_layout'] = 'HWIO'
# weight tensor type should be UINT8 (quantization) or FLOAT32 # weight tensor type should be UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type() weight_tensor_type = weight_tensor.tensor.Type()
...@@ -578,12 +488,9 @@ class OperatorConverter(object): ...@@ -578,12 +488,9 @@ class OperatorConverter(object):
in_expr = self.get_expr(input_tensor_idx) in_expr = self.get_expr(input_tensor_idx)
weight_value = self.get_tensor_value(weight_tensor) weight_value = self.get_tensor_value(weight_tensor)
if is_depthwise_conv: # TFLite is OC/M KH KW IC, we require KH KW IC OC/M
# TFLite is M KH KW IC, we require IC M KH KW # M means multiplier in depthwise convolution
weight_value = weight_value.transpose((3, 0, 1, 2)) weight_value = weight_value.transpose((1, 2, 3, 0))
else:
# TFLite is OC KH KW IC, we require OC IC KH kW
weight_value = weight_value.transpose((0, 3, 1, 2))
weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
...@@ -592,9 +499,10 @@ class OperatorConverter(object): ...@@ -592,9 +499,10 @@ class OperatorConverter(object):
elif padding == Padding.SAME: elif padding == Padding.SAME:
pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h) pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w) pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), (0, 0), in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
(pad_top, pad_bottom), (pad_top, pad_bottom),
(pad_left, pad_right))) (pad_left, pad_right),
(0, 0)))
else: else:
raise tvm.error.OpAttributeUnimplemented( raise tvm.error.OpAttributeUnimplemented(
'Padding format {} is not supported for operator Conv.'.format(padding)) 'Padding format {} is not supported for operator Conv.'.format(padding))
...@@ -610,7 +518,8 @@ class OperatorConverter(object): ...@@ -610,7 +518,8 @@ class OperatorConverter(object):
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type) bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor), bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor),
dtype=bias_tensor_type_str) dtype=bias_tensor_type_str)
out = _op.nn.bias_add(out, bias_expr) channel_axis = 3
out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)
# If we have fused activations # If we have fused activations
if fused_activation_fn != ActivationFunctionType.NONE: if fused_activation_fn != ActivationFunctionType.NONE:
...@@ -648,7 +557,8 @@ class OperatorConverter(object): ...@@ -648,7 +557,8 @@ class OperatorConverter(object):
params = {'pool_size': (filter_h, filter_w), params = {'pool_size': (filter_h, filter_w),
'strides': (stride_h, stride_w), 'strides': (stride_h, stride_w),
'padding': [0, 0]} 'padding': [0, 0],
'layout': 'NHWC'}
in_expr = self.get_expr(input_tensor_idx) in_expr = self.get_expr(input_tensor_idx)
......
...@@ -116,12 +116,10 @@ def run_tflite_graph(tflite_model_buf, input_data): ...@@ -116,12 +116,10 @@ def run_tflite_graph(tflite_model_buf, input_data):
return tflite_output return tflite_output
def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, def compare_tflite_with_tvm(in_data, in_name, input_tensors,
output_tensors, output_need_transpose=False, output_tensors, init_global_variables=False):
init_global_variables=False):
"""Generic function to generate and compare TFLite and TVM output""" """Generic function to generate and compare TFLite and TVM output"""
tflite_in_data = convert_to_list(tflite_in_data) in_data = convert_to_list(in_data)
tvm_in_data = convert_to_list(tvm_in_data)
in_name = convert_to_list(in_name) in_name = convert_to_list(in_name)
in_node = [0] * len(in_name) in_node = [0] * len(in_name)
for i in range(len(in_name)): for i in range(len(in_name)):
...@@ -134,7 +132,7 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, ...@@ -134,7 +132,7 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors,
converter = tf.contrib.lite.TFLiteConverter.from_session( converter = tf.contrib.lite.TFLiteConverter.from_session(
sess, input_tensors, output_tensors) sess, input_tensors, output_tensors)
tflite_model_buffer = converter.convert() tflite_model_buffer = converter.convert()
tflite_output = run_tflite_graph(tflite_model_buffer, tflite_in_data) tflite_output = run_tflite_graph(tflite_model_buffer, in_data)
for device in ["llvm"]: for device in ["llvm"]:
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
...@@ -142,25 +140,9 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, ...@@ -142,25 +140,9 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors,
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
continue continue
tvm_output = run_tvm_graph(tflite_model_buffer, tvm_in_data, in_node, target=device) tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device)
for i in range(len(tflite_output)): for i in range(len(tflite_output)):
if output_need_transpose: tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
dim = len(tvm_output[i].shape)
if dim == 3:
# N C H*W to N H*W C
axes = (0, 2, 1)
elif dim == 4:
# N C H W to N H W C
axes = (0, 2, 3, 1)
else:
raise NotImplementedError("Not support input shape {} of transpose : ".
format(str(dim)))
tvm.testing.assert_allclose(tflite_output[i],
np.transpose(tvm_output[i], axes=axes),
atol=1e-5, rtol=1e-5)
else:
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i],
atol=1e-5, rtol=1e-5)
sess.close() sess.close()
...@@ -173,14 +155,12 @@ def _test_pooling_iteration(input_shape, **kwargs): ...@@ -173,14 +155,12 @@ def _test_pooling_iteration(input_shape, **kwargs):
x = -np.arange( x = -np.arange(
np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1 np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
tvm_data = np.transpose(x, axes=(0, 3, 1, 2))
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=input_shape, dtype='float32') in_data = array_ops.placeholder(shape=input_shape, dtype='float32')
out = nn_ops.pool(in_data, **kwargs) out = nn_ops.pool(in_data, **kwargs)
compare_tflite_with_tvm(x, tvm_data, 'Placeholder:0', [in_data], [out], compare_tflite_with_tvm(x,'Placeholder:0', [in_data], [out])
output_need_transpose=True)
def _test_pooling(input_shape, **kwargs): def _test_pooling(input_shape, **kwargs):
...@@ -258,13 +238,8 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, ...@@ -258,13 +238,8 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
strides=strides, strides=strides,
padding=padding, padding=padding,
data_format=data_format) data_format=data_format)
# TFLite is NHWC, TVM is NCHW data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
tflite_data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])
tvm_data_array = np.transpose(tflite_data_array, axes=(0, 3, 1, 2))
# TFLite output is NHWC, TVM is NCHW, we need transpose
compare_tflite_with_tvm(tflite_data_array, tvm_data_array,
'Placeholder:0', [in_data], [out],
output_need_transpose=True)
def test_forward_convolution(): def test_forward_convolution():
...@@ -286,22 +261,11 @@ def test_forward_convolution(): ...@@ -286,22 +261,11 @@ def test_forward_convolution():
def _test_reshape(data, out_shape): def _test_reshape(data, out_shape):
""" One iteration of reshape operation with given data and out shape """ """ One iteration of reshape operation with given data and out shape """
# see relay/frontend/tflite.py convert_reshape more detail of channel first rule
if len(data.shape) == 1 or len(data.shape) == 2:
tvm_data = data
elif len(data.shape) == 3:
tvm_data = np.transpose(data, axes=(0, 2, 1))
elif len(data.shape) == 4:
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
else:
raise NotImplementedError("Not support input shape {} of reshape : ".
format(str(len(data))))
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = array_ops.reshape(in_data, out_shape) out = array_ops.reshape(in_data, out_shape)
compare_tflite_with_tvm(data, tvm_data, 'Placeholder:0', [in_data], [out]) compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
def test_forward_reshape(): def test_forward_reshape():
...@@ -319,18 +283,6 @@ def _test_concatenation(data, axis): ...@@ -319,18 +283,6 @@ def _test_concatenation(data, axis):
""" One iteration of concatenation """ """ One iteration of concatenation """
assert len(data) >= 1 assert len(data) >= 1
need_transpose = False
if len(data[0].shape) == 1 or len(data[0].shape) == 2:
tvm_data = data
elif len(data[0].shape) == 3:
#need_transpose = True
tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data]
elif len(data[0].shape) == 4:
need_transpose = True
tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data]
else:
raise NotImplementedError("Not support input shape {} of reshape : ".
format(str(len(data))))
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = [ in_data = [
...@@ -339,7 +291,7 @@ def _test_concatenation(data, axis): ...@@ -339,7 +291,7 @@ def _test_concatenation(data, axis):
out = array_ops.concat(in_data, axis=axis) out = array_ops.concat(in_data, axis=axis)
name = ["in_{}:0".format(idx) for idx in range(len(data))] name = ["in_{}:0".format(idx) for idx in range(len(data))]
compare_tflite_with_tvm(data, tvm_data, name, in_data, [out], need_transpose) compare_tflite_with_tvm(data, name, in_data, [out])
def test_forward_concatenation(): def test_forward_concatenation():
...@@ -366,33 +318,19 @@ def _test_add(data): ...@@ -366,33 +318,19 @@ def _test_add(data):
""" One iteration of add """ """ One iteration of add """
assert len(data) == 2 assert len(data) == 2
need_transpose = False
if len(data[0].shape) == 1 or len(data[0].shape) == 2:
tvm_data = data
elif len(data[0].shape) == 3:
need_transpose = True
tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data]
elif len(data[0].shape) == 4:
need_transpose = True
tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data]
else:
raise NotImplementedError("Not support input shape {} of add : ".
format(str(len(data.shape))))
# Test with two tensors # Test with two tensors
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'), in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'),
array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')] array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')]
out = math_ops.add(in_data[0], in_data[1]) out = math_ops.add(in_data[0], in_data[1])
compare_tflite_with_tvm(data, tvm_data, ['in_0:0','in_1:0'], compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
in_data, [out], need_transpose)
# Test with tensor and constant # Test with tensor and constant
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')] in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
out = math_ops.add(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype)) out = math_ops.add(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
compare_tflite_with_tvm([data[0]], [tvm_data[0]], ['in:0'], compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
in_data, [out], need_transpose)
def test_forward_add(): def test_forward_add():
...@@ -415,19 +353,6 @@ def _test_squeeze(data, squeeze_dims=None): ...@@ -415,19 +353,6 @@ def _test_squeeze(data, squeeze_dims=None):
if squeeze_dims is None: if squeeze_dims is None:
squeeze_dims = [] squeeze_dims = []
# see relay/frontend/tflite.py convert_squeeze more detail of channel first rule
if len(data.shape) == 1 or len(data.shape) == 2:
tvm_data = data
elif len(data.shape) == 3:
tvm_data = np.transpose(data, axes=(0, 2, 1))
elif len(data.shape) == 4:
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
else:
raise NotImplementedError("Not support input shape {} of reshape : ".
format(str(len(data.shape))))
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
...@@ -436,7 +361,7 @@ def _test_squeeze(data, squeeze_dims=None): ...@@ -436,7 +361,7 @@ def _test_squeeze(data, squeeze_dims=None):
else: else:
out = array_ops.squeeze(in_data) out = array_ops.squeeze(in_data)
compare_tflite_with_tvm(data, tvm_data, 'Placeholder:0', [in_data], [out]) compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
def test_forward_squeeze(): def test_forward_squeeze():
...@@ -453,7 +378,7 @@ def _test_softmax(data): ...@@ -453,7 +378,7 @@ def _test_softmax(data):
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = nn_ops.softmax(in_data) out = nn_ops.softmax(in_data)
compare_tflite_with_tvm(data, data, 'Placeholder:0', [in_data], [out]) compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
def test_forward_softmax(): def test_forward_softmax():
""" Softmax """ """ Softmax """
...@@ -496,10 +421,8 @@ def _test_fully_connected(tensor_in_sizes, filter_in_sizes, bias_in_size=None): ...@@ -496,10 +421,8 @@ def _test_fully_connected(tensor_in_sizes, filter_in_sizes, bias_in_size=None):
in_bias = constant_op.constant(bias_array, shape=bias_in_size, dtype='float32') in_bias = constant_op.constant(bias_array, shape=bias_in_size, dtype='float32')
out = nn_ops.bias_add(out, in_bias) out = nn_ops.bias_add(out, in_bias)
tflite_data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
tvm_data_array = np.transpose(tflite_data_array, axes=(0, 3, 1, 2)) compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])
compare_tflite_with_tvm(tflite_data_array, tvm_data_array,
'Placeholder:0', [in_data], [out])
def test_forward_fully_connected(): def test_forward_fully_connected():
...@@ -523,9 +446,8 @@ def test_forward_mobilenet_v1(): ...@@ -523,9 +446,8 @@ def test_forward_mobilenet_v1():
with open(tflite_model_file, "rb") as f: with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read() tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
tflite_output = run_tflite_graph(tflite_model_buf, data) tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5) rtol=1e-5, atol=1e-5)
...@@ -538,9 +460,8 @@ def test_forward_mobilenet_v2(): ...@@ -538,9 +460,8 @@ def test_forward_mobilenet_v2():
with open(tflite_model_file, "rb") as f: with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read() tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
tflite_output = run_tflite_graph(tflite_model_buf, data) tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5) rtol=1e-5, atol=1e-5)
...@@ -557,9 +478,8 @@ def test_forward_inception_v3_net(): ...@@ -557,9 +478,8 @@ def test_forward_inception_v3_net():
with open(tflite_model_file, "rb") as f: with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read() tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32') data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
tflite_output = run_tflite_graph(tflite_model_buf, data) tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5) rtol=1e-5, atol=1e-5)
...@@ -572,9 +492,8 @@ def test_forward_inception_v4_net(): ...@@ -572,9 +492,8 @@ def test_forward_inception_v4_net():
with open(tflite_model_file, "rb") as f: with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read() tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32') data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
tflite_output = run_tflite_graph(tflite_model_buf, data) tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5) rtol=1e-5, atol=1e-5)
......
...@@ -117,32 +117,23 @@ plt.imshow(resized_image) ...@@ -117,32 +117,23 @@ plt.imshow(resized_image)
plt.show() plt.show()
image_data = np.asarray(resized_image).astype("float32") image_data = np.asarray(resized_image).astype("float32")
# convert HWC to CHW # after expand_dims, we have format NHWC
image_data = image_data.transpose((2, 0, 1))
# after expand_dims, we have format NCHW
image_data = np.expand_dims(image_data, axis=0) image_data = np.expand_dims(image_data, axis=0)
# preprocess image as described here: # preprocess image as described here:
# https://github.com/tensorflow/models/blob/edb6ed22a801665946c63d650ab9a0b23d98e1b1/research/slim/preprocessing/inception_preprocessing.py#L243 # https://github.com/tensorflow/models/blob/edb6ed22a801665946c63d650ab9a0b23d98e1b1/research/slim/preprocessing/inception_preprocessing.py#L243
image_data[:, 0, :, :] = 2.0 / 255.0 * image_data[:, 0, :, :] - 1 image_data[:, :, :, 0] = 2.0 / 255.0 * image_data[:, :, :, 0] - 1
image_data[:, 1, :, :] = 2.0 / 255.0 * image_data[:, 1, :, :] - 1 image_data[:, :, :, 1] = 2.0 / 255.0 * image_data[:, :, :, 1] - 1
image_data[:, 2, :, :] = 2.0 / 255.0 * image_data[:, 2, :, :] - 1 image_data[:, :, :, 2] = 2.0 / 255.0 * image_data[:, :, :, 2] - 1
print('input', image_data.shape) print('input', image_data.shape)
####################################################################
#
# .. note:: Input layout:
#
# Currently, TVM TFLite frontend accepts ``NCHW`` as input layout.
###################################################################### ######################################################################
# Compile the model with relay # Compile the model with relay
# --------------------------------------------- # ---------------------------------------------
# TFLite input tensor name, shape and type # TFLite input tensor name, shape and type
input_tensor = "input" input_tensor = "input"
input_shape = (1, 3, 224, 224) input_shape = (1, 224, 224, 3)
input_dtype = "float32" input_dtype = "float32"
# parse TFLite model and convert into Relay computation graph # parse TFLite model and convert into Relay computation graph
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment