Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
084e338e
Commit
084e338e
authored
Jun 09, 2019
by
Alexander Pivovarov
Committed by
Yao Wang
Jun 09, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add MUL operator to relay tflite frontend (#3304)
parent
98a91af9
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
12 deletions
+46
-12
python/tvm/relay/frontend/tflite.py
+14
-6
tests/python/frontend/tflite/test_forward.py
+32
-6
No files found.
python/tvm/relay/frontend/tflite.py
View file @
084e338e
...
@@ -64,6 +64,7 @@ class OperatorConverter(object):
...
@@ -64,6 +64,7 @@ class OperatorConverter(object):
'MAX_POOL_2D'
:
self
.
convert_max_pool2d
,
'MAX_POOL_2D'
:
self
.
convert_max_pool2d
,
'CONCATENATION'
:
self
.
convert_concatenation
,
'CONCATENATION'
:
self
.
convert_concatenation
,
'ADD'
:
self
.
convert_add
,
'ADD'
:
self
.
convert_add
,
'MUL'
:
self
.
convert_mul
,
'FULLY_CONNECTED'
:
self
.
convert_fully_connected
,
'FULLY_CONNECTED'
:
self
.
convert_fully_connected
,
}
}
...
@@ -267,8 +268,8 @@ class OperatorConverter(object):
...
@@ -267,8 +268,8 @@ class OperatorConverter(object):
out
=
self
.
convert_fused_activation_function
(
out
,
fused_activation_fn
)
out
=
self
.
convert_fused_activation_function
(
out
,
fused_activation_fn
)
return
out
return
out
def
convert_add
(
self
,
op
):
def
_convert_elemwise
(
self
,
relay_op
,
op
):
"""
Convert TFLite add
"""
"""
Generic method to Convert TFLite elemwise
"""
try
:
try
:
from
tflite.Operator
import
Operator
from
tflite.Operator
import
Operator
except
ImportError
:
except
ImportError
:
...
@@ -283,19 +284,26 @@ class OperatorConverter(object):
...
@@ -283,19 +284,26 @@ class OperatorConverter(object):
rhs_tensor
=
input_tensors
[
1
]
rhs_tensor
=
input_tensors
[
1
]
if
self
.
has_expr
(
rhs_tensor
.
tensor_idx
):
if
self
.
has_expr
(
rhs_tensor
.
tensor_idx
):
# In most cases, we can assume that TOCO fuses
ADD
operators
# In most cases, we can assume that TOCO fuses
elemwise
operators
# with constants - it means both will be tensors.
# with constants - it means both will be tensors.
rhs_expr
=
self
.
get_expr
(
rhs_tensor
.
tensor_idx
)
rhs_expr
=
self
.
get_expr
(
rhs_tensor
.
tensor_idx
)
else
:
else
:
# However, in some corner cases, the
ADD
operator is not fused,
# However, in some corner cases, the
elemwise
operator is not fused,
# we can receive as constant.
# we can receive as constant.
rhs_type_str
=
self
.
get_tensor_type_str
(
rhs_tensor
.
tensor
.
Type
())
rhs_type_str
=
self
.
get_tensor_type_str
(
rhs_tensor
.
tensor
.
Type
())
rhs_expr
=
self
.
exp_tab
.
new_const
(
self
.
get_tensor_value
(
rhs_tensor
),
rhs_expr
=
self
.
exp_tab
.
new_const
(
self
.
get_tensor_value
(
rhs_tensor
),
dtype
=
rhs_type_str
)
dtype
=
rhs_type_str
)
out
=
relay_op
(
lhs_expr
,
rhs_expr
)
out
=
_op
.
add
(
lhs_expr
,
rhs_expr
)
return
out
return
out
def
convert_add
(
self
,
op
):
"""Convert TFLite ADD"""
return
self
.
_convert_elemwise
(
_op
.
add
,
op
)
def
convert_mul
(
self
,
op
):
"""Convert TFLite MUL"""
return
self
.
_convert_elemwise
(
_op
.
multiply
,
op
)
def
convert_fully_connected
(
self
,
op
):
def
convert_fully_connected
(
self
,
op
):
"""Convert TFLite fully connected"""
"""Convert TFLite fully connected"""
try
:
try
:
...
...
tests/python/frontend/tflite/test_forward.py
View file @
084e338e
...
@@ -24,7 +24,6 @@ from __future__ import print_function
...
@@ -24,7 +24,6 @@ from __future__ import print_function
import
numpy
as
np
import
numpy
as
np
import
tvm
import
tvm
from
tvm
import
relay
from
tvm
import
relay
from
tvm.contrib
import
util
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework
import
ops
...
@@ -144,8 +143,6 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
...
@@ -144,8 +143,6 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
for
i
in
range
(
len
(
tflite_output
)):
for
i
in
range
(
len
(
tflite_output
)):
tvm
.
testing
.
assert_allclose
(
tflite_output
[
i
],
tvm_output
[
i
],
atol
=
1e-5
,
rtol
=
1e-5
)
tvm
.
testing
.
assert_allclose
(
tflite_output
[
i
],
tvm_output
[
i
],
atol
=
1e-5
,
rtol
=
1e-5
)
sess
.
close
()
#######################################################################
#######################################################################
# Pooling
# Pooling
...
@@ -311,10 +308,10 @@ def test_forward_concatenation():
...
@@ -311,10 +308,10 @@ def test_forward_concatenation():
#######################################################################
#######################################################################
#
Add
#
Element-wise
# ---
# ---
def
_test_
add
(
data
):
def
_test_
elemwise
(
math_op
,
data
):
""" One iteration of add """
""" One iteration of add """
assert
len
(
data
)
==
2
assert
len
(
data
)
==
2
...
@@ -329,10 +326,19 @@ def _test_add(data):
...
@@ -329,10 +326,19 @@ def _test_add(data):
# Test with tensor and constant
# Test with tensor and constant
with
tf
.
Graph
()
.
as_default
():
with
tf
.
Graph
()
.
as_default
():
in_data
=
[
array_ops
.
placeholder
(
shape
=
data
[
0
]
.
shape
,
dtype
=
data
[
0
]
.
dtype
,
name
=
'in'
)]
in_data
=
[
array_ops
.
placeholder
(
shape
=
data
[
0
]
.
shape
,
dtype
=
data
[
0
]
.
dtype
,
name
=
'in'
)]
out
=
math_op
s
.
add
(
in_data
[
0
],
ops
.
convert_to_tensor
(
data
[
1
],
dtype
=
data
[
1
]
.
dtype
))
out
=
math_op
(
in_data
[
0
],
ops
.
convert_to_tensor
(
data
[
1
],
dtype
=
data
[
1
]
.
dtype
))
compare_tflite_with_tvm
([
data
[
0
]],
[
'in:0'
],
in_data
,
[
out
])
compare_tflite_with_tvm
([
data
[
0
]],
[
'in:0'
],
in_data
,
[
out
])
#######################################################################
# Add
# ---
def
_test_add
(
data
):
""" One iteration of add """
return
_test_elemwise
(
math_ops
.
add
,
data
)
def
test_forward_add
():
def
test_forward_add
():
""" Add """
""" Add """
_test_add
([
np
.
arange
(
6.0
,
dtype
=
np
.
float32
)
.
reshape
((
2
,
1
,
1
,
3
)),
_test_add
([
np
.
arange
(
6.0
,
dtype
=
np
.
float32
)
.
reshape
((
2
,
1
,
1
,
3
)),
...
@@ -344,6 +350,25 @@ def test_forward_add():
...
@@ -344,6 +350,25 @@ def test_forward_add():
#######################################################################
#######################################################################
# Mul
# ---
def
_test_mul
(
data
):
""" One iteration of mul """
return
_test_elemwise
(
math_ops
.
multiply
,
data
)
def
test_forward_mul
():
""" Mul """
_test_mul
([
np
.
arange
(
6.0
,
dtype
=
np
.
float32
)
.
reshape
((
2
,
1
,
1
,
3
)),
np
.
arange
(
6.0
,
dtype
=
np
.
float32
)
.
reshape
((
2
,
1
,
1
,
3
))])
_test_mul
([
np
.
arange
(
6.0
,
dtype
=
np
.
float32
)
.
reshape
((
2
,
1
,
3
)),
np
.
arange
(
6.0
,
dtype
=
np
.
float32
)
.
reshape
((
2
,
1
,
3
))])
_test_mul
([
np
.
arange
(
3.0
,
dtype
=
np
.
float32
)
.
reshape
((
1
,
3
)),
np
.
arange
(
3.0
,
dtype
=
np
.
float32
)
.
reshape
((
1
,
3
))])
#######################################################################
# Squeeze
# Squeeze
# -------
# -------
...
@@ -514,6 +539,7 @@ if __name__ == '__main__':
...
@@ -514,6 +539,7 @@ if __name__ == '__main__':
# Math
# Math
test_forward_add
()
test_forward_add
()
test_forward_mul
()
# End to End
# End to End
test_forward_mobilenet_v1
()
test_forward_mobilenet_v1
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment