Commit 968ffef6 by Zhao Wu Committed by Thierry Moreau

[TFLite] Support depthwise convolution multiplier greater than 1 (#3922)

parent 54dbcc28
...@@ -623,8 +623,6 @@ class OperatorConverter(object): ...@@ -623,8 +623,6 @@ class OperatorConverter(object):
conv_options = DepthwiseConv2DOptions() conv_options = DepthwiseConv2DOptions()
conv_options.Init(op_options.Bytes, op_options.Pos) conv_options.Init(op_options.Bytes, op_options.Pos)
depth_multiplier = conv_options.DepthMultiplier() depth_multiplier = conv_options.DepthMultiplier()
assert depth_multiplier == 1, "TF frontend transforms it to be 1 regardless of what " \
"original value is set to 0.25, 0.5 or anything else"
else: else:
raise tvm.error.OpNotImplemented( raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend TFLite.'.format(conv_type)) 'Operator {} is not supported for frontend TFLite.'.format(conv_type))
...@@ -636,11 +634,13 @@ class OperatorConverter(object): ...@@ -636,11 +634,13 @@ class OperatorConverter(object):
padding = conv_options.Padding() padding = conv_options.Padding()
fused_activation_fn = conv_options.FusedActivationFunction() fused_activation_fn = conv_options.FusedActivationFunction()
_, input_h, input_w, _ = input_tensor.tensor.ShapeAsNumpy() _, input_h, input_w, input_c = input_tensor.tensor.ShapeAsNumpy()
if is_depthwise_conv: if is_depthwise_conv:
multiplier, kernel_h, kernel_w, in_channels = weight_tensor.tensor.ShapeAsNumpy() # TFLite depthwise convolution kernel layout is:
assert multiplier == depth_multiplier # 1 KH KW C(input_c * depth_multiplier)
_, kernel_h, kernel_w, in_channels = weight_tensor.tensor.ShapeAsNumpy()
assert in_channels == input_c * depth_multiplier
else: else:
output_channels, kernel_h, kernel_w, _ = weight_tensor.tensor.ShapeAsNumpy() output_channels, kernel_h, kernel_w, _ = weight_tensor.tensor.ShapeAsNumpy()
...@@ -654,7 +654,7 @@ class OperatorConverter(object): ...@@ -654,7 +654,7 @@ class OperatorConverter(object):
'data_layout': 'NHWC'} 'data_layout': 'NHWC'}
if is_depthwise_conv: if is_depthwise_conv:
params['channels'] = int(in_channels * multiplier) params['channels'] = int(in_channels)
params['groups'] = int(in_channels) params['groups'] = int(in_channels)
params['kernel_layout'] = 'HWOI' params['kernel_layout'] = 'HWOI'
else: else:
...@@ -669,9 +669,16 @@ class OperatorConverter(object): ...@@ -669,9 +669,16 @@ class OperatorConverter(object):
in_expr = self.get_expr(input_tensor_idx) in_expr = self.get_expr(input_tensor_idx)
weight_value = self.get_tensor_value(weight_tensor) weight_value = self.get_tensor_value(weight_tensor)
# TFLite is OC/M KH KW IC, we require KH KW IC OC/M # TFLite kernel layout:
# M means multiplier in depthwise convolution # convolution:
weight_value = weight_value.transpose((1, 2, 3, 0)) # OC KH KW IC, we require KH KW IC OC (HWIO)
# depthwise convolution:
# 1 KH KW C(input_c * depth_multiplier), we require
# KH KW IC M (depth_multiplier) (HWOI)
if is_depthwise_conv:
weight_value = weight_value.reshape(kernel_h, kernel_w, input_c, depth_multiplier)
else:
weight_value = weight_value.transpose((1, 2, 3, 0))
weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
......
...@@ -356,6 +356,7 @@ def test_forward_convolution(): ...@@ -356,6 +356,7 @@ def test_forward_convolution():
_test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True) _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
_test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True) _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True)
_test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True) _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
_test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True)
####################################################################### #######################################################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment