Commit b63267b9 by Zhao Wu, committed by Siva

[TFLite] Convert TFLite NCHW to NHWC (#3141)

* Convert TFLite NCHW to NHWC

* Minor comment fix
parent 4b1d3d87
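In short, the frontend previously transposed every tensor from TFLite's native NHWC layout into NCHW at each operator boundary; after this change it keeps NHWC end to end and tells Relay the layout instead. The two layouts differ only by a fixed permutation; a toy NumPy illustration (the shapes are made up for the example, not taken from the diff):

```python
import numpy as np

nhwc = np.random.rand(1, 224, 224, 3)   # TFLite's native layout: N H W C
nchw = nhwc.transpose(0, 3, 1, 2)       # the layout the old frontend converted to
back = nchw.transpose(0, 2, 3, 1)       # the inverse permutation
assert np.array_equal(nhwc, back)
```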
@@ -209,44 +209,10 @@ class OperatorConverter(object):
         reshape_options = ReshapeOptions()
         reshape_options.Init(op_options.Bytes, op_options.Pos)
         target_shape = reshape_options.NewShapeAsNumpy()
-        input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())

         in_expr = self.get_expr(input_tensor_idx)
-        if input_shape_length in (1, 2):
-            # The rule is channel first (after N but before H, W).
-            # length of 1 means N*H*W*C, do nothing.
-            # length of 2 means N*H*W, C, do nothing.
-            pass
-        elif input_shape_length == 3:
-            # convert N C H*W to N H*W C
-            in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
-        elif input_shape_length == 4:
-            # convert input to N H W C, then reshape to target shape,
-            # finally convert back if necessary
-            in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
-        else:
-            msg = 'Input shape length {} for operator Reshape is not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
-
         out = _op.reshape(in_expr, newshape=tuple(target_shape))

-        # The rule is channel first.
-        # 1: N*H*W*C
-        # 2: N*H*W, C
-        # 3: N H W C, reshape to N H*W C, transpose to N C H*W
-        # 4: N H W C, transpose to N C H W
-        # add more if we need target shapes in future
-        if len(target_shape) == 1 or len(target_shape) == 2:
-            pass
-        elif len(target_shape) == 3:
-            out = _op.transpose(out, axes=(0, 2, 1))
-        elif len(target_shape) == 4:
-            out = _op.transpose(out, axes=(0, 3, 1, 2))
-        else:
-            raise tvm.error.OpAttributeInvalid(
-                'Length of target shape must be between 1 and 5 for operator Reshape.')
-
         return out
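The deleted branches existed because a TFLite Reshape's target shape is expressed in NHWC order, so NCHW data had to be transposed back to NHWC before reshaping (and re-transposed afterwards). A toy NumPy illustration of why that round trip was needed, with made-up shapes:

```python
import numpy as np

nhwc = np.arange(24).reshape(1, 2, 3, 4)                # N H W C
nchw = nhwc.transpose(0, 3, 1, 2)                       # N C H W

target = (1, 6, 4)                                      # NHWC-style target shape
direct = nhwc.reshape(target)                           # new code path: no transpose
roundtrip = nchw.transpose(0, 2, 3, 1).reshape(target)  # old code path
assert np.array_equal(direct, roundtrip)                # nchw.reshape(target) alone would scramble the data
```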
     def convert_softmax(self, op):
@@ -269,7 +235,7 @@ class OperatorConverter(object):
         return out

     def convert_concatenation(self, op):
-        """ convert TFLite concatenation"""
+        """Convert TFLite concatenation"""
         try:
             from tflite.Operator import Operator
             from tflite.ConcatenationOptions import ConcatenationOptions
@@ -292,15 +258,6 @@ class OperatorConverter(object):
         concatenation_options.Init(op_options.Bytes, op_options.Pos)
         concatenation_axis = concatenation_options.Axis()
         fused_activation_fn = concatenation_options.FusedActivationFunction()
-        input_shape_length = len(input_tensors[0].tensor.ShapeAsNumpy())
-
-        # TFLite is N H W C, our layout is N C H W
-        if input_shape_length <= 4:
-            axis_convert_map = [0] + list(range(2, input_shape_length)) + [1]
-            concatenation_axis = axis_convert_map[concatenation_axis]
-        else:
-            raise NotImplementedError("Not support input shape length {} of concatenatio : "
-                                      .format(str(input_shape_length)))

         # with axis in N H W C
         out = _op.concatenate(in_exprs, axis=concatenation_axis)
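The removed axis_convert_map translated a TFLite concat axis (counted in NHWC) into its NCHW equivalent; with the frontend staying in NHWC, the TFLite axis now applies as-is. A quick check of what the deleted mapping did for 4-D inputs:

```python
input_shape_length = 4
axis_convert_map = [0] + list(range(2, input_shape_length)) + [1]
assert axis_convert_map == [0, 2, 3, 1]
# e.g. a TFLite concat over channels (NHWC axis 3) became NCHW axis 1
assert axis_convert_map[3] == 1
```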
@@ -336,20 +293,6 @@ class OperatorConverter(object):
             rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
                                               dtype=rhs_type_str)
-            # In this case, we have to be careful about formatting.
-            input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy())
-            if input_shape_length in (1, 2):
-                pass
-            elif input_shape_length == 3:
-                # N H*W C to N C H*W
-                rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1))
-            elif input_shape_length == 4:
-                # N H W C to N C H W
-                rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2))
-            else:
-                msg = 'Input shape length {} for operator ADD is not valid.'
-                raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
-
         out = _op.add(lhs_expr, rhs_expr)
         return out
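The constant operand of ADD no longer needs a transpose either: in NHWC the channel dimension is trailing, so a per-channel constant broadcasts naturally. A NumPy sketch with illustrative shapes:

```python
import numpy as np

data = np.zeros((1, 224, 224, 3))        # N H W C
per_channel = np.array([0.5, 1.0, 2.0])  # shape (C,)
out = data + per_channel                 # broadcasts over N, H, W with no transpose
assert out.shape == (1, 224, 224, 3)
```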
@@ -440,46 +383,10 @@ class OperatorConverter(object):
         squeeze_options = SqueezeOptions()
         squeeze_options.Init(op_options.Bytes, op_options.Pos)
         squeeze_axis = squeeze_options.SqueezeDimsAsNumpy()
-        input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())
-        output_shape_length = len(output_tensors[0].tensor.ShapeAsNumpy())

         in_expr = self.get_expr(input_tensor_idx)
-
-        # TFLite is N H W C, our layout is N C H W
-        if input_shape_length in (1, 2):
-            # The rule is channel first (after N but before H, W).
-            # length of 1 means N*H*W*C, do nothing.
-            # length of 2 means N*H*W, C, do nothing.
-            pass
-        elif input_shape_length == 3:
-            # convert N C H*W to N H*W C
-            in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
-        elif input_shape_length == 4:
-            # convert input to N H W C, then reshape to target shape,
-            # finally convert back if necessary
-            in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
-        else:
-            msg = 'Input shape length {} for operator Squeeze is not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
-
         out = _op.squeeze(in_expr, axis=tuple(squeeze_axis))

-        # The rule is channel first.
-        # 1: N*H*W*C
-        # 2: N*H*W, C
-        # 3: N H W C, reshape to N H*W C, transpose to N C H*W
-        # 4: N H W C, transpose to N C H W
-        # add more if we need target shapes in future
-        if output_shape_length in (1, 2):
-            pass
-        elif output_shape_length == 3:
-            out = _op.transpose(out, axes=(0, 2, 1))
-        elif output_shape_length == 4:
-            out = _op.transpose(out, axes=(0, 3, 1, 2))
-        else:
-            msg = 'Output shape length {} for operator Squeeze is not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(output_shape_length))
-
         return out
     def convert_fused_activation_function(self, in_expr, fused_activation_fn):
@@ -562,13 +469,16 @@ class OperatorConverter(object):
         params = {'kernel_size': [kernel_h, kernel_w],
                   'strides': [stride_h, stride_w],
                   'dilation': [dilation_h, dilation_w],
-                  'padding': [0, 0]}
+                  'padding': [0, 0],
+                  'data_layout': 'NHWC'}

         if is_depthwise_conv:
             params['channels'] = int(in_channels * multiplier)
             params['groups'] = int(in_channels)
+            params['kernel_layout'] = 'HWOI'
         else:
             params['channels'] = int(output_channels)
+            params['kernel_layout'] = 'HWIO'

         # weight tensor type should be UINT8 (quantization) or FLOAT32
         weight_tensor_type = weight_tensor.tensor.Type()
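With data_layout and kernel_layout set, Relay's conv2d consumes NHWC activations and HWIO (or HWOI for depthwise) kernels directly instead of requiring transposes. A minimal standalone sketch of the resulting call, with made-up shapes and variable names that are not from the diff:

```python
from tvm import relay

data = relay.var("data", shape=(1, 224, 224, 3), dtype="float32")    # NHWC
weight = relay.var("weight", shape=(3, 3, 3, 16), dtype="float32")   # HWIO
conv = relay.nn.conv2d(data, weight,
                       channels=16,
                       kernel_size=(3, 3),
                       padding=(0, 0),
                       data_layout="NHWC",
                       kernel_layout="HWIO")
```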
@@ -578,12 +488,9 @@ class OperatorConverter(object):
         in_expr = self.get_expr(input_tensor_idx)
         weight_value = self.get_tensor_value(weight_tensor)
-        if is_depthwise_conv:
-            # TFLite is M KH KW IC, we require IC M KH KW
-            weight_value = weight_value.transpose((3, 0, 1, 2))
-        else:
-            # TFLite is OC KH KW IC, we require OC IC KH kW
-            weight_value = weight_value.transpose((0, 3, 1, 2))
+        # TFLite is OC/M KH KW IC, we require KH KW IC OC/M
+        # M means multiplier in depthwise convolution
+        weight_value = weight_value.transpose((1, 2, 3, 0))

         weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
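The single transpose replaces the old two-branch logic: TFLite stores kernels with the output-channel (or multiplier) dimension first, and moving it last yields HWIO/HWOI. In NumPy terms, with an illustrative regular-convolution kernel:

```python
import numpy as np

w_tflite = np.zeros((16, 3, 3, 8))         # OC KH KW IC (TFLite)
w_hwio = w_tflite.transpose((1, 2, 3, 0))  # KH KW IC OC (Relay HWIO)
assert w_hwio.shape == (3, 3, 8, 16)
```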
@@ -592,9 +499,10 @@ class OperatorConverter(object):
         elif padding == Padding.SAME:
             pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
             pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
-            in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), (0, 0),
-                                                          (pad_top, pad_bottom),
-                                                          (pad_left, pad_right)))
+            in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
+                                                          (pad_top, pad_bottom),
+                                                          (pad_left, pad_right),
+                                                          (0, 0)))
         else:
             raise tvm.error.OpAttributeUnimplemented(
                 'Padding format {} is not supported for operator Conv.'.format(padding))
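The pad_width change mirrors the layout switch: H and W move from dims (2, 3) to dims (1, 2), and batch and channel stay unpadded. For reference, a SAME-padding helper in the style of get_pad_value; this is a sketch following the standard TensorFlow SAME rule, since the actual helper lives elsewhere in the file and is not shown in this diff:

```python
import math

def get_pad_value(data, kernel, stride):
    """TF-style SAME padding, split between the two sides of a dimension."""
    out = int(math.ceil(float(data) / float(stride)))
    pad = max(0, (out - 1) * stride + kernel - data)
    pad_before = pad // 2
    pad_after = pad - pad_before
    return pad_before, pad_after
```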
@@ -610,7 +518,8 @@ class OperatorConverter(object):
             bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
             bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor),
                                                dtype=bias_tensor_type_str)
-            out = _op.nn.bias_add(out, bias_expr)
+            channel_axis = 3
+            out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)

         # If we have fused activations
         if fused_activation_fn != ActivationFunctionType.NONE:
@@ -648,7 +557,8 @@ class OperatorConverter(object):
         params = {'pool_size': (filter_h, filter_w),
                   'strides': (stride_h, stride_w),
-                  'padding': [0, 0]}
+                  'padding': [0, 0],
+                  'layout': 'NHWC'}

         in_expr = self.get_expr(input_tensor_idx)
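Bias addition now targets the trailing channel axis (3 in NHWC instead of 1 in NCHW), and pooling takes its layout the same way. A minimal sketch of an NHWC pooling call, with assumed shapes that are not from the diff:

```python
from tvm import relay

data = relay.var("data", shape=(1, 112, 112, 32), dtype="float32")  # NHWC
pool = relay.nn.max_pool2d(data,
                           pool_size=(2, 2),
                           strides=(2, 2),
                           padding=(0, 0),
                           layout="NHWC")
```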
...
@@ -117,32 +117,23 @@ plt.imshow(resized_image)
 plt.show()

 image_data = np.asarray(resized_image).astype("float32")

-# convert HWC to CHW
-image_data = image_data.transpose((2, 0, 1))
-
-# after expand_dims, we have format NCHW
+# after expand_dims, we have format NHWC
 image_data = np.expand_dims(image_data, axis=0)

 # preprocess image as described here:
 # https://github.com/tensorflow/models/blob/edb6ed22a801665946c63d650ab9a0b23d98e1b1/research/slim/preprocessing/inception_preprocessing.py#L243
-image_data[:, 0, :, :] = 2.0 / 255.0 * image_data[:, 0, :, :] - 1
-image_data[:, 1, :, :] = 2.0 / 255.0 * image_data[:, 1, :, :] - 1
-image_data[:, 2, :, :] = 2.0 / 255.0 * image_data[:, 2, :, :] - 1
+image_data[:, :, :, 0] = 2.0 / 255.0 * image_data[:, :, :, 0] - 1
+image_data[:, :, :, 1] = 2.0 / 255.0 * image_data[:, :, :, 1] - 1
+image_data[:, :, :, 2] = 2.0 / 255.0 * image_data[:, :, :, 2] - 1

 print('input', image_data.shape)
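Since 2.0 / 255.0 * x - 1 equals x / 127.5 - 1, the three per-channel statements are equivalent to one vectorized expression over the whole NHWC array; the tutorial keeps them separate only to make the channel indexing explicit:

```python
image_data = image_data / 127.5 - 1.0
```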
-####################################################################
-#
-# .. note:: Input layout:
-#
-#   Currently, TVM TFLite frontend accepts ``NCHW`` as input layout.
-
 ######################################################################
 # Compile the model with relay
 # ---------------------------------------------

 # TFLite input tensor name, shape and type
 input_tensor = "input"
-input_shape = (1, 3, 224, 224)
+input_shape = (1, 224, 224, 3)
 input_dtype = "float32"

 # parse TFLite model and convert into Relay computation graph
...
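Downstream of this change, the import step picks up the NHWC shape directly. A sketch of the call that follows in the tutorial, assuming the usual relay.frontend.from_tflite entry point and a tflite_model object loaded earlier (return values vary slightly across TVM versions):

```python
from tvm import relay

# shape_dict/dtype_dict carry the NHWC values defined above
mod, params = relay.frontend.from_tflite(tflite_model,
                                         shape_dict={input_tensor: input_shape},
                                         dtype_dict={input_tensor: input_dtype})
```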