Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e68874d6
Commit
e68874d6
authored
Apr 04, 2019
by
Sunwoong Joo
Committed by
Tianqi Chen
Apr 03, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Frontend] Adding ADD operator to tflite frontend for compiling the MobileNetV2 (#2919)
parent
eb82e7b7
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
136 additions
and
8 deletions
+136
-8
python/tvm/relay/frontend/common.py
+3
-0
python/tvm/relay/frontend/tflite.py
+48
-1
tests/python/frontend/tflite/test_forward.py
+85
-7
No files found.
python/tvm/relay/frontend/common.py
View file @
e68874d6
...
@@ -258,6 +258,9 @@ class ExprTable(object):
...
@@ -258,6 +258,9 @@ class ExprTable(object):
if
name
not
in
self
.
exprs
:
if
name
not
in
self
.
exprs
:
self
.
exprs
[
name
]
=
expr
self
.
exprs
[
name
]
=
expr
def
has_expr
(
self
,
name
):
return
True
if
name
in
self
.
exprs
else
False
def
set_padding
(
self
,
paddings
):
def
set_padding
(
self
,
paddings
):
self
.
paddings
=
paddings
self
.
paddings
=
paddings
self
.
in_padding
=
True
self
.
in_padding
=
True
...
...
python/tvm/relay/frontend/tflite.py
View file @
e68874d6
...
@@ -46,7 +46,8 @@ class OperatorConverter(object):
...
@@ -46,7 +46,8 @@ class OperatorConverter(object):
'SOFTMAX'
:
self
.
convert_softmax
,
'SOFTMAX'
:
self
.
convert_softmax
,
'SQUEEZE'
:
self
.
convert_squeeze
,
'SQUEEZE'
:
self
.
convert_squeeze
,
'MAX_POOL_2D'
:
self
.
convert_max_pool2d
,
'MAX_POOL_2D'
:
self
.
convert_max_pool2d
,
"CONCATENATION"
:
self
.
convert_concatenation
'CONCATENATION'
:
self
.
convert_concatenation
,
'ADD'
:
self
.
convert_add
}
}
def
check_unsupported_ops
(
self
):
def
check_unsupported_ops
(
self
):
...
@@ -292,6 +293,49 @@ class OperatorConverter(object):
...
@@ -292,6 +293,49 @@ class OperatorConverter(object):
out
=
self
.
convert_fused_activation_function
(
out
,
fused_activation_fn
)
out
=
self
.
convert_fused_activation_function
(
out
,
fused_activation_fn
)
return
out
return
out
def
convert_add
(
self
,
op
):
"""Convert TFLite add"""
try
:
from
tflite.Operator
import
Operator
except
ImportError
:
raise
ImportError
(
"The tflite package must be installed"
)
assert
isinstance
(
op
,
Operator
)
input_tensors
=
self
.
get_input_tensors
(
op
)
assert
len
(
input_tensors
)
==
2
,
"input tensors length should be 2"
lhs_tensor
=
input_tensors
[
0
]
lhs_expr
=
self
.
get_expr
(
lhs_tensor
.
tensor_idx
)
rhs_tensor
=
input_tensors
[
1
]
if
self
.
has_expr
(
rhs_tensor
.
tensor_idx
):
# In most cases, we can assume that TOCO fuses ADD operators
# with constants - it means both will be tensors.
rhs_expr
=
self
.
get_expr
(
rhs_tensor
.
tensor_idx
)
else
:
# However, in some corner cases, the ADD operator is not fused,
# we can receive as constant.
rhs_type_str
=
self
.
get_tensor_type_str
(
rhs_tensor
.
tensor
.
Type
())
rhs_expr
=
self
.
exp_tab
.
new_const
(
self
.
get_tensor_value
(
rhs_tensor
),
dtype
=
rhs_type_str
)
# In this case, we have to be careful about formatting.
input_shape_length
=
len
(
rhs_tensor
.
tensor
.
ShapeAsNumpy
())
if
input_shape_length
in
(
1
,
2
):
pass
elif
input_shape_length
==
3
:
# N H*W C to N C H*W
rhs_expr
=
_op
.
transpose
(
rhs_expr
,
axes
=
(
0
,
2
,
1
))
elif
input_shape_length
==
4
:
# N H W C to N C H W
rhs_expr
=
_op
.
transpose
(
rhs_expr
,
axes
=
(
0
,
3
,
1
,
2
))
else
:
msg
=
'Input shape length {} for operator ADD is not valid.'
raise
tvm
.
error
.
OpAttributeInvalid
(
msg
.
format
(
input_shape_length
))
out
=
_op
.
add
(
lhs_expr
,
rhs_expr
)
return
out
def
convert_squeeze
(
self
,
op
):
def
convert_squeeze
(
self
,
op
):
"""Convert TFLite squeeze"""
"""Convert TFLite squeeze"""
try
:
try
:
...
@@ -554,6 +598,9 @@ class OperatorConverter(object):
...
@@ -554,6 +598,9 @@ class OperatorConverter(object):
def
get_expr
(
self
,
input_tensor_idx
):
def
get_expr
(
self
,
input_tensor_idx
):
return
self
.
exp_tab
.
get_expr
(
get_tensor_name
(
self
.
subgraph
,
input_tensor_idx
))
return
self
.
exp_tab
.
get_expr
(
get_tensor_name
(
self
.
subgraph
,
input_tensor_idx
))
def
has_expr
(
self
,
input_tensor_idx
):
return
self
.
exp_tab
.
has_expr
(
get_tensor_name
(
self
.
subgraph
,
input_tensor_idx
))
def
build_str_map
(
obj
):
def
build_str_map
(
obj
):
"""Build string map of TFLite enum int value
"""Build string map of TFLite enum int value
...
...
tests/python/frontend/tflite/test_forward.py
View file @
e68874d6
...
@@ -11,6 +11,8 @@ from tvm import relay
...
@@ -11,6 +11,8 @@ from tvm import relay
from
tvm.contrib
import
util
from
tvm.contrib
import
util
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
ops
from
tensorflow.python.ops
import
math_ops
from
tensorflow.python.ops
import
nn_ops
from
tensorflow.python.ops
import
nn_ops
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.ops
import
variables
from
tensorflow.python.ops
import
variables
...
@@ -99,7 +101,7 @@ def run_tflite_graph(tflite_model_buf, input_data):
...
@@ -99,7 +101,7 @@ def run_tflite_graph(tflite_model_buf, input_data):
def
compare_tflite_with_tvm
(
tflite_in_data
,
tvm_in_data
,
in_name
,
input_tensors
,
def
compare_tflite_with_tvm
(
tflite_in_data
,
tvm_in_data
,
in_name
,
input_tensors
,
output_tensors
,
output_need_transpose
_nchw
=
False
,
output_tensors
,
output_need_transpose
=
False
,
init_global_variables
=
False
):
init_global_variables
=
False
):
"""Generic function to generate and compare TFLite and TVM output"""
"""Generic function to generate and compare TFLite and TVM output"""
tflite_in_data
=
convert_to_list
(
tflite_in_data
)
tflite_in_data
=
convert_to_list
(
tflite_in_data
)
...
@@ -126,9 +128,19 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors,
...
@@ -126,9 +128,19 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors,
tvm_output
=
run_tvm_graph
(
tflite_model_buffer
,
tvm_in_data
,
in_node
,
target
=
device
)
tvm_output
=
run_tvm_graph
(
tflite_model_buffer
,
tvm_in_data
,
in_node
,
target
=
device
)
for
i
in
range
(
len
(
tflite_output
)):
for
i
in
range
(
len
(
tflite_output
)):
if
output_need_transpose_nchw
:
if
output_need_transpose
:
dim
=
len
(
tvm_output
[
i
]
.
shape
)
if
dim
==
3
:
# N C H*W to N H*W C
axes
=
(
0
,
2
,
1
)
elif
dim
==
4
:
# N C H W to N H W C
axes
=
(
0
,
2
,
3
,
1
)
else
:
raise
NotImplementedError
(
"Not support input shape {} of transpose : "
.
format
(
str
(
dim
)))
tvm
.
testing
.
assert_allclose
(
tflite_output
[
i
],
tvm
.
testing
.
assert_allclose
(
tflite_output
[
i
],
np
.
transpose
(
tvm_output
[
i
],
axes
=
(
0
,
2
,
3
,
1
)
),
np
.
transpose
(
tvm_output
[
i
],
axes
=
axes
),
atol
=
1e-5
,
rtol
=
1e-5
)
atol
=
1e-5
,
rtol
=
1e-5
)
else
:
else
:
tvm
.
testing
.
assert_allclose
(
tflite_output
[
i
],
tvm_output
[
i
],
tvm
.
testing
.
assert_allclose
(
tflite_output
[
i
],
tvm_output
[
i
],
...
@@ -152,7 +164,7 @@ def _test_pooling_iteration(input_shape, **kwargs):
...
@@ -152,7 +164,7 @@ def _test_pooling_iteration(input_shape, **kwargs):
out
=
nn_ops
.
pool
(
in_data
,
**
kwargs
)
out
=
nn_ops
.
pool
(
in_data
,
**
kwargs
)
compare_tflite_with_tvm
(
x
,
tvm_data
,
'Placeholder:0'
,
[
in_data
],
[
out
],
compare_tflite_with_tvm
(
x
,
tvm_data
,
'Placeholder:0'
,
[
in_data
],
[
out
],
output_need_transpose
_nchw
=
True
)
output_need_transpose
=
True
)
def
_test_pooling
(
input_shape
,
**
kwargs
):
def
_test_pooling
(
input_shape
,
**
kwargs
):
...
@@ -236,7 +248,7 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
...
@@ -236,7 +248,7 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
# TFLite output is NHWC, TVM is NCHW, we need transpose
# TFLite output is NHWC, TVM is NCHW, we need transpose
compare_tflite_with_tvm
(
tflite_data_array
,
tvm_data_array
,
compare_tflite_with_tvm
(
tflite_data_array
,
tvm_data_array
,
'Placeholder:0'
,
[
in_data
],
[
out
],
'Placeholder:0'
,
[
in_data
],
[
out
],
output_need_transpose
_nchw
=
True
)
output_need_transpose
=
True
)
def
test_forward_convolution
():
def
test_forward_convolution
():
...
@@ -331,6 +343,53 @@ def test_forward_concatenation():
...
@@ -331,6 +343,53 @@ def test_forward_concatenation():
#######################################################################
#######################################################################
# Add
# ---
def
_test_add
(
data
):
""" One iteration of add """
assert
len
(
data
)
==
2
need_transpose
=
False
if
len
(
data
[
0
]
.
shape
)
==
1
or
len
(
data
[
0
]
.
shape
)
==
2
:
tvm_data
=
data
elif
len
(
data
[
0
]
.
shape
)
==
3
:
need_transpose
=
True
tvm_data
=
[
np
.
transpose
(
d
,
axes
=
(
0
,
2
,
1
))
for
d
in
data
]
elif
len
(
data
[
0
]
.
shape
)
==
4
:
need_transpose
=
True
tvm_data
=
[
np
.
transpose
(
d
,
axes
=
(
0
,
3
,
1
,
2
))
for
d
in
data
]
else
:
raise
NotImplementedError
(
"Not support input shape {} of add : "
.
format
(
str
(
len
(
data
.
shape
))))
# Test with two tensors
with
tf
.
Graph
()
.
as_default
():
in_data
=
[
array_ops
.
placeholder
(
shape
=
data
[
0
]
.
shape
,
dtype
=
data
[
0
]
.
dtype
,
name
=
'in_0'
),
array_ops
.
placeholder
(
shape
=
data
[
1
]
.
shape
,
dtype
=
data
[
1
]
.
dtype
,
name
=
'in_1'
)]
out
=
math_ops
.
add
(
in_data
[
0
],
in_data
[
1
])
compare_tflite_with_tvm
(
data
,
tvm_data
,
[
'in_0:0'
,
'in_1:0'
],
in_data
,
[
out
],
need_transpose
)
# Test with tensor and constant
with
tf
.
Graph
()
.
as_default
():
in_data
=
[
array_ops
.
placeholder
(
shape
=
data
[
0
]
.
shape
,
dtype
=
data
[
0
]
.
dtype
,
name
=
'in'
)]
out
=
math_ops
.
add
(
in_data
[
0
],
ops
.
convert_to_tensor
(
data
[
1
],
dtype
=
data
[
1
]
.
dtype
))
compare_tflite_with_tvm
([
data
[
0
]],
[
tvm_data
[
0
]],
[
'in:0'
],
in_data
,
[
out
],
need_transpose
)
def
test_forward_add
():
""" Add """
_test_add
([
np
.
arange
(
6.0
,
dtype
=
np
.
float32
)
.
reshape
((
2
,
1
,
1
,
3
)),
np
.
arange
(
6.0
,
dtype
=
np
.
float32
)
.
reshape
((
2
,
1
,
1
,
3
))])
_test_add
([
np
.
arange
(
6.0
,
dtype
=
np
.
float32
)
.
reshape
((
2
,
1
,
3
)),
np
.
arange
(
6.0
,
dtype
=
np
.
float32
)
.
reshape
((
2
,
1
,
3
))])
_test_add
([
np
.
arange
(
3.0
,
dtype
=
np
.
float32
)
.
reshape
((
1
,
3
)),
np
.
arange
(
3.0
,
dtype
=
np
.
float32
)
.
reshape
((
1
,
3
))])
#######################################################################
# Squeeze
# Squeeze
# -------
# -------
...
@@ -388,7 +447,7 @@ def test_forward_softmax():
...
@@ -388,7 +447,7 @@ def test_forward_softmax():
# Mobilenet
# Mobilenet
# ---------
# ---------
def
test_forward_mobilenet
():
def
test_forward_mobilenet
_v1
():
'''test mobilenet v1 tflite model'''
'''test mobilenet v1 tflite model'''
# MobilenetV1
# MobilenetV1
tflite_model_file
=
tf_testing
.
get_workload_official
(
tflite_model_file
=
tf_testing
.
get_workload_official
(
...
@@ -403,6 +462,21 @@ def test_forward_mobilenet():
...
@@ -403,6 +462,21 @@ def test_forward_mobilenet():
tvm
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
[
0
]),
np
.
squeeze
(
tflite_output
[
0
]),
tvm
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
[
0
]),
np
.
squeeze
(
tflite_output
[
0
]),
rtol
=
1e-5
,
atol
=
1e-5
)
rtol
=
1e-5
,
atol
=
1e-5
)
def
test_forward_mobilenet_v2
():
'''test mobilenet v2 tflite model'''
# MobilenetV2
tflite_model_file
=
tf_testing
.
get_workload_official
(
"http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz"
,
"mobilenet_v2_1.0_224.tflite"
)
with
open
(
tflite_model_file
,
"rb"
)
as
f
:
tflite_model_buf
=
f
.
read
()
data
=
np
.
random
.
uniform
(
size
=
(
1
,
224
,
224
,
3
))
.
astype
(
'float32'
)
tvm_data
=
np
.
transpose
(
data
,
axes
=
(
0
,
3
,
1
,
2
))
tflite_output
=
run_tflite_graph
(
tflite_model_buf
,
data
)
tvm_output
=
run_tvm_graph
(
tflite_model_buf
,
tvm_data
,
'input'
)
tvm
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
[
0
]),
np
.
squeeze
(
tflite_output
[
0
]),
rtol
=
1e-5
,
atol
=
1e-5
)
#######################################################################
#######################################################################
# Inception V3
# Inception V3
# ------------
# ------------
...
@@ -436,6 +510,10 @@ if __name__ == '__main__':
...
@@ -436,6 +510,10 @@ if __name__ == '__main__':
test_forward_pooling
()
test_forward_pooling
()
test_forward_softmax
()
test_forward_softmax
()
# Math
test_forward_add
()
# End to End
# End to End
test_forward_mobilenet
()
test_forward_mobilenet_v1
()
test_forward_mobilenet_v2
()
test_forward_inception_v3_net
()
test_forward_inception_v3_net
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment