Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
fa351045
Commit
fa351045
authored
Jun 17, 2019
by
Zhi
Committed by
Tianqi Chen
Jun 17, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[relay][frontend] Return module from frontend parsers (#3353)
parent
07fbe5c8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
34 changed files
with
207 additions
and
152 deletions
+207
-152
python/tvm/relay/backend/interpreter.py
+42
-37
python/tvm/relay/backend/vm.py
+5
-3
python/tvm/relay/build_module.py
+7
-4
python/tvm/relay/frontend/caffe2.py
+9
-5
python/tvm/relay/frontend/coreml.py
+4
-3
python/tvm/relay/frontend/darknet.py
+5
-3
python/tvm/relay/frontend/keras.py
+4
-3
python/tvm/relay/frontend/mxnet.py
+14
-6
python/tvm/relay/frontend/onnx.py
+9
-7
python/tvm/relay/frontend/tensorflow.py
+11
-8
python/tvm/relay/frontend/tflite.py
+4
-3
tests/python/frontend/caffe2/test_forward.py
+3
-2
tests/python/frontend/caffe2/test_graph.py
+3
-2
tests/python/frontend/coreml/test_forward.py
+4
-4
tests/python/frontend/darknet/test_forward.py
+4
-2
tests/python/frontend/keras/test_forward.py
+4
-2
tests/python/frontend/mxnet/test_forward.py
+0
-0
tests/python/frontend/mxnet/test_graph.py
+16
-16
tests/python/frontend/onnx/test_forward.py
+4
-2
tests/python/frontend/tensorflow/test_control_flow.py
+3
-3
tests/python/frontend/tensorflow/test_forward.py
+6
-5
tests/python/frontend/tflite/test_forward.py
+6
-4
tests/python/relay/test_vm.py
+2
-2
tutorials/frontend/deploy_model_on_android.py
+2
-2
tutorials/frontend/deploy_model_on_rasp.py
+2
-1
tutorials/frontend/deploy_ssd_gluoncv.py
+2
-2
tutorials/frontend/from_caffe2.py
+2
-2
tutorials/frontend/from_coreml.py
+4
-2
tutorials/frontend/from_darknet.py
+5
-2
tutorials/frontend/from_keras.py
+3
-3
tutorials/frontend/from_mxnet.py
+4
-3
tutorials/frontend/from_onnx.py
+3
-3
tutorials/frontend/from_tensorflow.py
+7
-2
tutorials/frontend/from_tflite.py
+4
-4
No files found.
python/tvm/relay/backend/interpreter.py
View file @
fa351045
...
...
@@ -21,7 +21,8 @@ from __future__ import absolute_import
import
numpy
as
np
from
.
import
_backend
from
..
import
_make
,
ir_pass
from
..
import
_make
,
ir_pass
,
transform
from
..
import
module
from
...
import
register_func
,
nd
from
..base
import
NodeBase
,
register_relay_node
from
..expr
import
Tuple
,
RefCreate
,
Call
,
Constant
,
GlobalVar
,
Function
,
const
...
...
@@ -191,14 +192,14 @@ class Executor(object):
return
tuple
(
cargs
)
def
_make_executor
(
self
,
_
):
def
_make_executor
(
self
,
expr
=
None
):
"""
Construct a Python function that implements the evaluation
of expression.
Parameters
----------
expr:
relay.Expr
expr:
Optional[relay.Expr]
The Relay expression to execute.
Returns
...
...
@@ -208,16 +209,16 @@ class Executor(object):
"""
raise
NotImplementedError
()
def
evaluate
(
self
,
expr
,
binds
=
None
):
def
evaluate
(
self
,
expr
=
None
,
binds
=
None
):
"""
Evaluate a Relay expression on the executor.
Parameters
----------
expr:
tvm.relay.Expr
expr:
Optional[tvm.relay.Expr]
The expression to evaluate.
binds:
Map[tvm.relay.Var, tvm.relay.Expr
]
binds:
Optional[Map[tvm.relay.Var, tvm.relay.Expr]
]
Additional binding of free variable.
Returns
...
...
@@ -232,6 +233,9 @@ class Executor(object):
scope_builder
.
ret
(
expr
)
expr
=
scope_builder
.
get
()
if
not
expr
:
return
self
.
_make_executor
()
if
isinstance
(
expr
,
Function
):
assert
not
ir_pass
.
free_vars
(
expr
)
...
...
@@ -264,46 +268,47 @@ class Interpreter(Executor):
self
.
target
=
target
self
.
_intrp
=
_backend
.
CreateInterpreter
(
mod
,
ctx
,
target
)
def
optimize
(
self
,
expr
):
"""Optimize an expr.
Parameters
----------
expr : Expr
The expression to be optimized.
def
optimize
(
self
):
"""Optimize functions in a module.
Returns
-------
opt_
expr : Expr
The optimized
expression
.
opt_
mod : tvm.relay.Module
The optimized
module
.
"""
# TODO: We need to move this optimization code into the optimizer/pass manager
wrapped_expr
=
expr
if
isinstance
(
expr
,
Function
)
else
Function
([],
expr
)
if
self
.
mod
:
self
.
mod
[
self
.
mod
.
entry_func
]
=
wrapped_expr
ck_expr
=
ir_pass
.
infer_type
(
wrapped_expr
,
mod
=
self
.
mod
)
simp_expr
=
ir_pass
.
simplify_inference
(
ck_expr
)
ck_simp
=
ir_pass
.
infer_type
(
simp_expr
,
mod
=
self
.
mod
)
fused_expr
=
ir_pass
.
fuse_ops
(
ck_simp
,
0
,
mod
=
self
.
mod
)
ck_fused
=
ir_pass
.
infer_type
(
fused_expr
,
mod
=
self
.
mod
)
return
ck_fused
if
isinstance
(
expr
,
Function
)
else
Call
(
ck_fused
,
[])
def
_make_executor
(
self
,
expr
):
seq
=
transform
.
Sequential
([
transform
.
SimplifyInference
(),
transform
.
FuseOps
(
0
),
transform
.
InferType
()])
return
seq
(
self
.
mod
)
def
_make_executor
(
self
,
expr
=
None
):
if
expr
is
None
or
isinstance
(
expr
,
GlobalVar
):
assert
self
.
mod
is
not
None
def
_interp_wrapper
(
*
args
,
**
kwargs
):
args
=
self
.
_convert_args
(
expr
,
args
,
kwargs
)
if
expr
is
None
:
args
=
self
.
_convert_args
(
self
.
mod
[
self
.
mod
.
entry_func
],
args
,
kwargs
)
else
:
args
=
self
.
_convert_args
(
expr
,
args
,
kwargs
)
relay_args
=
[]
for
arg
in
args
:
relay_args
.
append
(
_arg_to_ast
(
arg
))
if
isinstance
(
expr
,
GlobalVar
):
func
=
self
.
mod
[
expr
]
func
=
self
.
optimize
(
func
)
self
.
mod
.
_add
(
expr
,
func
,
True
)
opt_expr
=
Call
(
expr
,
relay_args
)
return
self
.
_intrp
(
opt_expr
)
# Set the entry function for the module.
if
expr
is
None
:
pass
elif
isinstance
(
expr
,
GlobalVar
):
self
.
mod
[
self
.
mod
.
entry_func
]
=
self
.
mod
[
expr
]
else
:
call
=
Call
(
expr
,
relay_args
)
opt_expr
=
self
.
optimize
(
call
)
return
self
.
_intrp
(
opt_expr
)
assert
isinstance
(
expr
,
Function
)
func
=
Function
([],
Call
(
expr
,
relay_args
))
relay_args
=
[]
if
self
.
mod
:
self
.
mod
[
self
.
mod
.
entry_func
]
=
func
else
:
self
.
mod
=
module
.
Module
.
from_expr
(
func
)
mod
=
self
.
optimize
()
opt_expr
=
Call
(
mod
[
self
.
mod
.
entry_func
.
name_hint
],
relay_args
)
return
self
.
_intrp
(
opt_expr
)
return
_interp_wrapper
python/tvm/relay/backend/vm.py
View file @
fa351045
...
...
@@ -130,9 +130,11 @@ class VMExecutor(Executor):
self
.
ctx
=
ctx
self
.
target
=
target
def
_make_executor
(
self
,
expr
):
assert
isinstance
(
expr
,
Expr
)
self
.
mod
[
self
.
mod
.
entry_func
]
=
expr
def
_make_executor
(
self
,
expr
=
None
):
expr
=
expr
if
expr
else
self
.
mod
assert
expr
,
"either expr or self.mod should be not null."
if
isinstance
(
expr
,
Expr
):
self
.
mod
[
self
.
mod
.
entry_func
]
=
expr
main
=
self
.
mod
[
self
.
mod
.
entry_func
]
def
_vm_wrapper
(
*
args
,
**
kwargs
):
...
...
python/tvm/relay/build_module.py
View file @
fa351045
...
...
@@ -219,16 +219,19 @@ class GraphExecutor(_interpreter.Executor):
self
.
ctx
=
ctx
self
.
target
=
target
def
_make_executor
(
self
,
func
):
ret_type
=
ir_pass
.
infer_type
(
func
)
.
ret_type
def
_make_executor
(
self
,
expr
=
None
):
if
not
expr
:
assert
self
.
mod
,
"either expr or self.mod should be not null."
expr
=
self
.
mod
[
self
.
mod
.
entry_func
]
ret_type
=
ir_pass
.
infer_type
(
expr
)
.
ret_type
num_outputs
=
len
(
ret_type
.
fields
)
if
isinstance
(
ret_type
,
_ty
.
TupleType
)
else
1
graph_json
,
mod
,
params
=
build
(
func
,
target
=
self
.
target
)
graph_json
,
mod
,
params
=
build
(
expr
,
target
=
self
.
target
)
gmodule
=
_graph_rt
.
create
(
graph_json
,
mod
,
self
.
ctx
)
if
params
:
gmodule
.
set_input
(
**
params
)
def
_graph_wrapper
(
*
args
,
**
kwargs
):
args
=
self
.
_convert_args
(
func
,
args
,
kwargs
)
args
=
self
.
_convert_args
(
expr
,
args
,
kwargs
)
# Create map of inputs.
for
i
,
arg
in
enumerate
(
args
):
gmodule
.
set_input
(
i
,
arg
)
...
...
python/tvm/relay/frontend/caffe2.py
View file @
fa351045
...
...
@@ -20,6 +20,7 @@ from __future__ import absolute_import as _abs
import
tvm
from
..
import
ir_pass
from
..
import
expr
as
_expr
from
..
import
module
as
_module
from
..
import
op
as
_op
from
...
import
nd
as
_nd
from
.common
import
AttrCvt
,
Renamer
...
...
@@ -382,6 +383,7 @@ class Caffe2NetDef(object):
self
.
_ops
=
{}
self
.
_shape
=
shape
self
.
_dtype
=
dtype
self
.
_mod
=
_module
.
Module
({})
def
from_caffe2
(
self
,
init_net
,
predict_net
):
"""Construct Relay expression from caffe2 graph.
...
...
@@ -393,8 +395,9 @@ class Caffe2NetDef(object):
Returns
-------
func : tvm.relay.expr.Function
Compatible relay function
mod : tvm.relay.Module
The module that optimizations will be performed on.
params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
...
...
@@ -448,8 +451,9 @@ class Caffe2NetDef(object):
outputs
=
out
[
0
]
func
=
_expr
.
Function
(
ir_pass
.
free_vars
(
outputs
),
outputs
)
self
.
_mod
[
self
.
_mod
.
entry_func
]
=
func
return
func
,
self
.
_params
return
self
.
_mod
,
self
.
_params
def
_get_node
(
self
,
blob
):
"""Get the Symbol of blob and detect cyclic dependency in the graph."""
...
...
@@ -560,8 +564,8 @@ def from_caffe2(init_net, predict_net, shape=None, dtype="float32"):
Returns
-------
sym : tvm.relay.expr.Function
Compatible relay function
mod : tvm.relay.Module
The module that optimizations will be performed on.
params : dict of str to tvm.ndarray
Dict of converted parameters stored in tvm.ndarray format
...
...
python/tvm/relay/frontend/coreml.py
View file @
fa351045
...
...
@@ -21,6 +21,7 @@ import numpy as np
import
tvm
from
..
import
ir_pass
from
..
import
expr
as
_expr
from
..
import
module
as
_module
from
..
import
op
as
_op
from
...
import
nd
as
_nd
from
..._ffi
import
base
as
_base
...
...
@@ -416,8 +417,8 @@ def from_coreml(model, shape=None):
Returns
-------
func : tvm.relay.Function
Compatible relay Func
tion.
mod : tvm.relay.Module
The relay module for compila
tion.
params : dict of str to tvm.NDArray
The parameter dict to be used by Relay.
...
...
@@ -463,4 +464,4 @@ def from_coreml(model, shape=None):
outexpr
=
outexpr
[
0
]
func
=
_expr
.
Function
(
ir_pass
.
free_vars
(
outexpr
),
outexpr
)
params
=
{
k
:
_nd
.
array
(
np
.
array
(
v
,
dtype
=
np
.
float32
))
for
k
,
v
in
etab
.
params
.
items
()}
return
func
,
params
return
_module
.
Module
.
from_expr
(
func
)
,
params
python/tvm/relay/frontend/darknet.py
View file @
fa351045
...
...
@@ -25,6 +25,7 @@ import numpy as np
import
tvm
from
..
import
ir_pass
from
..
import
expr
as
_expr
from
..
import
module
as
_module
from
.common
import
get_relay_op
,
new_var
__all__
=
[
'from_darknet'
]
...
...
@@ -820,7 +821,7 @@ class GraphProto(object):
outputs
=
_as_list
(
sym
)
+
self
.
_outs
outputs
=
outputs
[
0
]
if
len
(
outputs
)
==
1
else
_expr
.
Tuple
(
outputs
)
sym
=
_expr
.
Function
(
ir_pass
.
free_vars
(
outputs
),
outputs
)
return
sym
,
self
.
_tvmparams
return
_module
.
Module
.
from_expr
(
sym
)
,
self
.
_tvmparams
def
from_darknet
(
net
,
shape
=
None
,
...
...
@@ -838,8 +839,9 @@ def from_darknet(net,
Returns
-------
sym : tvm.relay.Function
Compatible relay Function
mod : tvm.relay.Module
The relay module for compilation.
params : dict of str to tvm.NDArray
The parameter dict to be used by relay
"""
...
...
python/tvm/relay/frontend/keras.py
View file @
fa351045
...
...
@@ -22,6 +22,7 @@ import numpy as np
import
tvm
from
..
import
ir_pass
from
..
import
expr
as
_expr
from
..
import
module
as
_module
from
..
import
op
as
_op
from
...
import
nd
as
_nd
from
.common
import
ExprTable
,
new_var
...
...
@@ -679,8 +680,8 @@ def from_keras(model, shape=None):
Returns
-------
func : tvm.relay.Function
Compatible relay Func
tion.
mod : tvm.relay.Module
The relay module for compila
tion.
params : dict of str to tvm.NDArray
The parameter dict to be used by Relay.
...
...
@@ -744,4 +745,4 @@ def from_keras(model, shape=None):
outexpr
=
outexpr
[
0
]
if
len
(
outexpr
)
==
1
else
_expr
.
Tuple
(
outexpr
)
func
=
_expr
.
Function
(
ir_pass
.
free_vars
(
outexpr
),
outexpr
)
params
=
{
k
:
_nd
.
array
(
np
.
array
(
v
,
dtype
=
np
.
float32
))
for
k
,
v
in
etab
.
params
.
items
()}
return
func
,
params
return
_module
.
Module
.
from_expr
(
func
)
,
params
python/tvm/relay/frontend/mxnet.py
View file @
fa351045
...
...
@@ -23,6 +23,7 @@ import tvm
from
..
import
ir_pass
from
..
import
expr
as
_expr
from
..
import
op
as
_op
from
..
import
module
as
_module
from
...
import
nd
as
_nd
from
.common
import
StrAttrsDict
...
...
@@ -992,7 +993,8 @@ _convert_map = {
_convert_map
.
update
({
k
:
_rename
(
k
)
for
k
in
_identity_list
})
def
_from_mxnet_impl
(
symbol
,
shape_dict
,
dtype_info
):
def
_from_mxnet_impl
(
symbol
,
shape_dict
,
dtype_info
,
mod
=
None
):
#pylint: disable=unused-argument
"""Convert mxnet symbol to compatible relay Function.
Reconstruct a relay Function by traversing the mxnet symbol.
...
...
@@ -1009,6 +1011,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
dtype_info : dict or str.
Known parameter dtypes
mod : tvm.relay.Module
The module that contains global information. It will be used for
converting ops that need global information, e.g. control-flow ops.
Returns:
-------
func : tvm.relay.Function
...
...
@@ -1097,8 +1103,8 @@ def from_mxnet(symbol,
Returns
-------
sym : tvm.relay.Function
Compatible relay Func
tion
mod : tvm.relay.Module
The relay module for compila
tion
params : dict of str to tvm.NDArray
The parameter dict to be used by nnvm
...
...
@@ -1108,6 +1114,7 @@ def from_mxnet(symbol,
except
ImportError
as
e
:
raise
ImportError
(
"{}. MXNet is required to parse symbols."
.
format
(
e
))
mod
=
_module
.
Module
()
if
isinstance
(
symbol
,
mx
.
sym
.
Symbol
):
params
=
{}
arg_params
=
arg_params
if
arg_params
else
{}
...
...
@@ -1117,7 +1124,7 @@ def from_mxnet(symbol,
for
k
,
v
in
aux_params
.
items
():
params
[
k
]
=
_nd
.
array
(
v
.
asnumpy
())
shape
,
dtype
=
_update_shape_dtype
(
shape
,
dtype
,
params
)
sym
=
_from_mxnet_impl
(
symbol
,
shape
,
dtype
)
func
=
_from_mxnet_impl
(
symbol
,
shape
,
dtype
,
mod
)
elif
isinstance
(
symbol
,
mx
.
gluon
.
HybridBlock
):
if
arg_params
is
not
None
or
aux_params
is
not
None
:
raise
ValueError
(
"arg_params and aux_params ae not used when importing HybridBlock"
)
...
...
@@ -1129,10 +1136,11 @@ def from_mxnet(symbol,
if
isinstance
(
sym
,
(
list
,
tuple
)):
sym
=
mx
.
sym
.
Group
(
sym
)
shape
,
dtype
=
_update_shape_dtype
(
shape
,
dtype
,
params
)
sym
=
_from_mxnet_impl
(
sym
,
shape
,
dtype
)
func
=
_from_mxnet_impl
(
sym
,
shape
,
dtype
,
mod
)
elif
isinstance
(
symbol
,
mx
.
gluon
.
Block
):
raise
NotImplementedError
(
"Only Hybrid Blocks are supported now."
)
else
:
msg
=
"mxnet.Symbol or gluon.HybridBlock expected, got {}"
.
format
(
type
(
symbol
))
raise
ValueError
(
msg
)
return
sym
,
params
mod
[
mod
.
entry_func
]
=
func
return
mod
,
params
python/tvm/relay/frontend/onnx.py
View file @
fa351045
...
...
@@ -24,6 +24,7 @@ import tvm
from
...
import
nd
as
_nd
from
..
import
ir_pass
from
..
import
expr
as
_expr
from
..
import
module
as
_module
from
..
import
op
as
_op
from
.common
import
AttrCvt
,
Renamer
from
.common
import
get_relay_op
,
new_var
,
infer_shape
,
infer_channels
,
get_name
...
...
@@ -999,8 +1000,9 @@ class GraphProto(object):
Returns
-------
sym : tvm.relay.expr.Function
The returned relay function
mod : tvm.relay.Module
The returned relay module
params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
...
...
@@ -1090,7 +1092,7 @@ class GraphProto(object):
outputs
=
[
self
.
_nodes
[
self
.
_parse_value_proto
(
i
)]
for
i
in
graph
.
output
]
outputs
=
outputs
[
0
]
if
len
(
outputs
)
==
1
else
_expr
.
Tuple
(
outputs
)
func
=
_expr
.
Function
(
ir_pass
.
free_vars
(
outputs
),
outputs
)
return
func
,
self
.
_params
return
_module
.
Module
.
from_expr
(
func
)
,
self
.
_params
def
_parse_value_proto
(
self
,
value_proto
):
"""Parse ValueProto or raw str."""
...
...
@@ -1219,8 +1221,8 @@ def from_onnx(model,
Returns
-------
sym : tvm.relay.expr.Function
Compatible relay func
tion
mod : tvm.relay.Module
The relay module for compila
tion
params : dict of str to tvm.NDArray
The parameter dict to be used by relay
...
...
@@ -1243,5 +1245,5 @@ def from_onnx(model,
opset
=
model
.
opset_import
[
0
]
.
version
if
model
.
opset_import
else
1
except
AttributeError
:
opset
=
1
sym
,
params
=
g
.
from_onnx
(
graph
,
opset
)
return
sym
,
params
mod
,
params
=
g
.
from_onnx
(
graph
,
opset
)
return
mod
,
params
python/tvm/relay/frontend/tensorflow.py
View file @
fa351045
...
...
@@ -31,6 +31,7 @@ from .. import ir_pass
from
..
import
expr
as
_expr
from
..
import
op
as
_op
from
..expr_functor
import
ExprMutator
from
..
import
module
as
_module
__all__
=
[
'from_tensorflow'
]
...
...
@@ -1823,6 +1824,7 @@ class GraphProto(object):
self
.
_input_shapes
=
{}
self
.
_loops
=
{}
self
.
_branches
=
{}
self
.
_mod
=
_module
.
Module
({})
def
from_tensorflow
(
self
,
graph
,
layout
=
"NHWC"
,
shape
=
None
,
outputs
=
None
):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
...
...
@@ -1856,8 +1858,9 @@ class GraphProto(object):
Returns
-------
sym : relay.op
The returned relay operator
mod : tvm.relay.Module
The module that optimizations will be performed on.
params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
...
...
@@ -2046,8 +2049,8 @@ class GraphProto(object):
out
=
out
[
0
]
if
len
(
out
)
==
1
else
_expr
.
Tuple
(
out
)
func
=
_expr
.
Function
(
ir_pass
.
free_vars
(
out
),
out
)
return
func
,
self
.
_params
self
.
_mod
[
self
.
_mod
.
entry_func
]
=
func
return
self
.
_mod
,
self
.
_params
def
_parse_import_prerequisites
(
self
,
graph
):
""" Calculate the named preconditions from TensorFlow `graph`.
...
...
@@ -2336,12 +2339,12 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
Returns
-------
sym : relay.op
Compatible relay operator
mod : tvm.relay.Module
The module that optimizations will be performed on.
params : dict of str to tvm.ndarray
Dict of converted parameters stored in tvm.ndarray format
"""
g
=
GraphProto
()
sym
,
params
=
g
.
from_tensorflow
(
graph
,
layout
,
shape
,
outputs
)
return
sym
,
params
mod
,
params
=
g
.
from_tensorflow
(
graph
,
layout
,
shape
,
outputs
)
return
mod
,
params
python/tvm/relay/frontend/tflite.py
View file @
fa351045
...
...
@@ -22,6 +22,7 @@ import numpy as np
import
tvm
from
..
import
ir_pass
from
..
import
expr
as
_expr
from
..
import
module
as
_module
from
..
import
op
as
_op
from
...
import
nd
as
_nd
from
.common
import
ExprTable
...
...
@@ -749,8 +750,8 @@ def from_tflite(model, shape_dict, dtype_dict):
Returns
-------
func : tvm.relay.Function
Compatible relay Function
mod : tvm.relay.Module
The relay module for compilation.
params : dict of str to tvm.NDArray
The parameter dict to be used by relay
...
...
@@ -788,4 +789,4 @@ def from_tflite(model, shape_dict, dtype_dict):
outputs
=
[
exp_tab
.
get_expr
(
get_tensor_name
(
subgraph
,
i
))
for
i
in
model_outputs
]
outputs
=
outputs
[
0
]
if
len
(
outputs
)
==
1
else
_expr
.
Tuple
(
outputs
)
func
=
_expr
.
Function
(
ir_pass
.
free_vars
(
outputs
),
outputs
)
return
func
,
params
return
_module
.
Module
.
from_expr
(
func
)
,
params
tests/python/frontend/caffe2/test_forward.py
View file @
fa351045
...
...
@@ -40,9 +40,10 @@ def get_tvm_output(model,
input_names
=
model
.
predict_net
.
op
[
0
]
.
input
[
0
]
shape_dict
=
{
input_names
:
input_data
.
shape
}
dtype_dict
=
{
input_names
:
input_data
.
dtype
}
func
,
params
=
relay
.
frontend
.
from_caffe2
(
model
.
init_net
,
model
.
predict_net
,
shape_dict
,
dtype_dict
)
mod
,
params
=
relay
.
frontend
.
from_caffe2
(
model
.
init_net
,
model
.
predict_net
,
shape_dict
,
dtype_dict
)
with
relay
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
]
,
target
,
params
=
params
)
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
...
...
tests/python/frontend/caffe2/test_graph.py
View file @
fa351045
...
...
@@ -28,9 +28,10 @@ def compare_graph(f1, f2):
def
test_squeeze_net
():
shape_dict
=
{
'data'
:
(
1
,
3
,
224
,
224
)}
dtype_dict
=
{
'data'
:
'float32'
}
from_c2_func
,
_
=
relay
.
frontend
.
from_caffe2
(
c2_squeezenet
.
init_net
,
c2_squeezenet
.
predict_net
,
shape_dict
,
dtype_dict
)
mod
,
_
,
=
relay
.
frontend
.
from_caffe2
(
c2_squeezenet
.
init_net
,
c2_squeezenet
.
predict_net
,
shape_dict
,
dtype_dict
)
relay_func
,
_
=
relay_squeezenet
()
compare_graph
(
from_c2_func
,
relay_func
)
compare_graph
(
mod
[
mod
.
entry_func
]
,
relay_func
)
if
__name__
==
'__main__'
:
...
...
tests/python/frontend/coreml/test_forward.py
View file @
fa351045
...
...
@@ -46,9 +46,9 @@ def run_model_checkonly(model_file, model_name='', input_name='image'):
model
=
cm
.
models
.
MLModel
(
model_file
)
x
=
model_zoo
.
get_cat_image
()
shape_dict
=
{
input_name
:
x
.
shape
}
func
,
params
=
relay
.
frontend
.
from_coreml
(
model
,
shape_dict
)
mod
,
params
=
relay
.
frontend
.
from_coreml
(
model
,
shape_dict
)
for
target
,
ctx
in
ctx_list
():
tvm_output
=
get_tvm_output
(
func
,
x
,
params
,
target
,
ctx
)
tvm_output
=
get_tvm_output
(
mod
[
mod
.
entry_func
]
,
x
,
params
,
target
,
ctx
)
print
(
target
,
ctx
,
model_name
,
'prediction id: '
,
np
.
argmax
(
tvm_output
.
flat
))
def
test_mobilenet_checkonly
():
...
...
@@ -71,9 +71,9 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap
shape_dict
=
{
input_name
:
input_data
.
shape
}
dtype_dict
=
{
input_name
:
input_data
.
dtype
}
func
,
params
=
relay
.
frontend
.
from_coreml
(
coreml_model
,
shape_dict
)
mod
,
params
=
relay
.
frontend
.
from_coreml
(
coreml_model
,
shape_dict
)
with
relay
.
transform
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
]
,
target
,
params
=
params
)
from
tvm.contrib
import
graph_runtime
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
...
...
tests/python/frontend/darknet/test_forward.py
View file @
fa351045
...
...
@@ -52,10 +52,12 @@ def _read_memory_buffer(shape, data, dtype='float32'):
def
_get_tvm_output
(
net
,
data
,
build_dtype
=
'float32'
,
states
=
None
):
'''Compute TVM output'''
dtype
=
'float32'
sym
,
params
=
relay
.
frontend
.
from_darknet
(
net
,
data
.
shape
,
dtype
)
mod
,
params
=
relay
.
frontend
.
from_darknet
(
net
,
data
.
shape
,
dtype
)
target
=
'llvm'
shape_dict
=
{
'data'
:
data
.
shape
}
graph
,
library
,
params
=
relay
.
build
(
sym
,
target
,
params
=
params
)
graph
,
library
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
],
target
,
params
=
params
)
# Execute on TVM
ctx
=
tvm
.
cpu
(
0
)
...
...
tests/python/frontend/keras/test_forward.py
View file @
fa351045
...
...
@@ -42,9 +42,11 @@ def verify_keras_frontend(keras_model, need_transpose=True):
def
get_tvm_output
(
xs
,
target
,
ctx
,
dtype
=
'float32'
):
shape_dict
=
{
name
:
x
.
shape
for
(
name
,
x
)
in
zip
(
keras_model
.
input_names
,
xs
)}
func
,
params
=
relay
.
frontend
.
from_keras
(
keras_model
,
shape_dict
)
mod
,
params
=
relay
.
frontend
.
from_keras
(
keras_model
,
shape_dict
)
with
relay
.
transform
.
build_config
(
opt_level
=
2
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
],
target
,
params
=
params
)
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
for
name
,
x
in
zip
(
keras_model
.
input_names
,
xs
):
m
.
set_input
(
name
,
tvm
.
nd
.
array
(
x
.
astype
(
dtype
)))
...
...
tests/python/frontend/mxnet/test_forward.py
View file @
fa351045
This diff is collapsed.
Click to expand it.
tests/python/frontend/mxnet/test_graph.py
View file @
fa351045
...
...
@@ -26,60 +26,60 @@ def compare_graph(f1, f2):
def
test_mlp
():
shape
=
{
"data"
:
(
1
,
1
,
28
,
28
)}
mx_fun
=
model_zoo
.
mx_mlp
()
from_mx_fun
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_fun
,
shape
=
shape
)
mod
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_fun
,
shape
=
shape
)
relay_fun
=
model_zoo
.
relay_mlp
()
compare_graph
(
from_mx_fun
,
relay_fun
)
compare_graph
(
mod
[
mod
.
entry_func
]
,
relay_fun
)
def
test_vgg
():
shape
=
{
"data"
:
(
1
,
3
,
224
,
224
)}
for
n
in
[
11
,
13
,
16
,
19
]:
mx_sym
=
model_zoo
.
mx_vgg
(
n
)
from_mx_sym
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
=
shape
)
mod
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
=
shape
)
relay_sym
=
model_zoo
.
relay_vgg
(
n
)
compare_graph
(
from_mx_sym
,
relay_sym
)
compare_graph
(
mod
[
mod
.
entry_func
]
,
relay_sym
)
def
test_resnet
():
shape
=
{
"data"
:
(
1
,
3
,
224
,
224
)}
for
n
in
[
18
,
34
,
50
,
101
]:
mx_sym
=
model_zoo
.
mx_resnet
(
n
)
from_mx_sym
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
=
shape
)
mod
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
=
shape
)
relay_sym
=
model_zoo
.
relay_resnet
(
n
)
compare_graph
(
from_mx_sym
,
relay_sym
)
compare_graph
(
mod
[
mod
.
entry_func
]
,
relay_sym
)
def
test_squeezenet
():
shape
=
{
"data"
:
(
1
,
3
,
224
,
224
)}
for
version
in
[
'1.0'
,
'1.1'
]:
mx_sym
=
model_zoo
.
mx_squeezenet
(
version
)
from_mx_sym
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
)
mod
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
)
relay_sym
=
model_zoo
.
relay_squeezenet
(
version
)
compare_graph
(
from_mx_sym
,
relay_sym
)
compare_graph
(
mod
[
mod
.
entry_func
]
,
relay_sym
)
def
test_inception_v3
():
shape
=
{
"data"
:
(
1
,
3
,
299
,
299
)}
mx_sym
=
model_zoo
.
mx_inception_v3
()
from_mx_sym
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
)
mod
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
)
relay_sym
=
model_zoo
.
relay_inception_v3
()
compare_graph
(
from_mx_sym
,
relay_sym
)
compare_graph
(
mod
[
mod
.
entry_func
]
,
relay_sym
)
def
test_dqn
():
shape
=
{
"data"
:
(
1
,
4
,
84
,
84
)}
mx_sym
=
model_zoo
.
mx_dqn
()
from_mx_sym
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
)
mod
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
)
relay_sym
=
model_zoo
.
relay_dqn
()
compare_graph
(
from_mx_sym
,
relay_sym
)
compare_graph
(
mod
[
mod
.
entry_func
]
,
relay_sym
)
def
test_dcgan
():
shape
=
{
"data"
:
(
2
,
100
)}
mx_sym
=
model_zoo
.
mx_dcgan
()
from_mx_sym
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
)
mod
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
)
relay_sym
=
model_zoo
.
relay_dcgan
(
batch_size
=
2
)
compare_graph
(
from_mx_sym
,
relay_sym
)
compare_graph
(
mod
[
mod
.
entry_func
]
,
relay_sym
)
def
test_multi_outputs
():
...
...
@@ -100,10 +100,10 @@ def test_multi_outputs():
return
relay
.
Function
(
relay
.
ir_pass
.
free_vars
(
z
),
z
)
mx_sym
=
mx_compose
(
mx
,
num_outputs
=
3
,
axis
=
1
)
from_mx_sym
,
_
=
relay
.
frontend
.
from_mxnet
(
mod
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape
=
{
"x"
:
xshape
,
"y"
:
yshape
})
relay_sym
=
relay_compose
(
relay
,
indices_or_sections
=
3
,
axis
=
1
)
compare_graph
(
from_mx_sym
,
relay_sym
)
compare_graph
(
mod
[
mod
.
entry_func
]
,
relay_sym
)
if
__name__
==
"__main__"
:
...
...
tests/python/frontend/onnx/test_forward.py
View file @
fa351045
...
...
@@ -42,9 +42,11 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
shape_dict
=
{
input_names
:
input_data
.
shape
}
dtype_dict
=
{
input_names
:
input_data
.
dtype
}
sym
,
params
=
relay
.
frontend
.
from_onnx
(
graph_def
,
shape_dict
)
mod
,
params
=
relay
.
frontend
.
from_onnx
(
graph_def
,
shape_dict
)
with
relay
.
build_config
(
opt_level
=
1
):
graph
,
lib
,
params
=
relay
.
build
(
sym
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
],
target
,
params
=
params
)
ctx
=
tvm
.
cpu
(
0
)
from
tvm.contrib
import
graph_runtime
...
...
tests/python/frontend/tensorflow/test_control_flow.py
View file @
fa351045
...
...
@@ -22,9 +22,9 @@ from tvm.relay.frontend.tensorflow import from_tensorflow
def
check_equal
(
graph
,
tf_out
):
expr
,
params
=
from_tensorflow
(
graph
.
as_graph_def
(
add_shapes
=
True
))
ex
=
relay
.
create_executor
(
'debug'
)
relay_out
=
ex
.
evaluate
(
expr
)(
**
params
)
mod
,
params
=
from_tensorflow
(
graph
.
as_graph_def
(
add_shapes
=
True
))
ex
=
relay
.
create_executor
(
'debug'
,
mod
=
mod
)
relay_out
=
ex
.
evaluate
()(
**
params
)
if
isinstance
(
relay_out
,
relay
.
backend
.
interpreter
.
TensorValue
):
np
.
testing
.
assert_allclose
(
tf_out
,
relay_out
.
asnumpy
())
else
:
...
...
tests/python/frontend/tensorflow/test_forward.py
View file @
fa351045
...
...
@@ -60,13 +60,12 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
shape_dict
=
{
e
:
i
.
shape
for
e
,
i
in
zip
(
input_node
,
input_data
)}
sym
,
params
=
relay
.
frontend
.
from_tensorflow
(
graph_def
,
mod
,
params
=
relay
.
frontend
.
from_tensorflow
(
graph_def
,
layout
=
layout
,
shape
=
shape_dict
,
outputs
=
out_names
)
with
relay
.
build_config
(
opt_level
=
opt_level
):
graph
,
lib
,
params
=
relay
.
build
(
sym
,
target
,
target_host
,
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
]
,
target
,
target_host
,
params
)
ctx
=
tvm
.
context
(
target
,
0
)
from
tvm.contrib
import
graph_runtime
...
...
@@ -1442,14 +1441,16 @@ def test_forward_ptb():
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c'
:(
num_layers
,
batch_size
,
num_hidden
),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h'
:(
num_layers
,
batch_size
,
num_hidden
)}
sym
,
params
=
relay
.
frontend
.
from_tensorflow
(
graph_def
,
shape
=
shape_dict
)
mod
,
params
=
relay
.
frontend
.
from_tensorflow
(
graph_def
,
shape
=
shape_dict
)
dtype_dict
=
{
'Model/Placeholder'
:
'int32'
,
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c'
:
'float32'
,
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h'
:
'float32'
}
target
=
'llvm'
with
relay
.
build_config
(
opt_level
=
0
):
graph
,
lib
,
params
=
relay
.
build
(
sym
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
],
target
,
params
=
params
)
from
tvm.contrib
import
graph_runtime
ctx
=
tvm
.
cpu
(
0
)
return
params
,
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
...
...
tests/python/frontend/tflite/test_forward.py
View file @
fa351045
...
...
@@ -63,11 +63,13 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
shape_dict
[
e
]
=
input_data
[
i
]
.
shape
dtype_dict
[
e
]
=
input_data
[
i
]
.
dtype
.
name
func
,
params
=
relay
.
frontend
.
from_tflite
(
tflite_model
,
shape_dict
=
shape_dict
,
dtype_dict
=
dtype_dict
)
mod
,
params
=
relay
.
frontend
.
from_tflite
(
tflite_model
,
shape_dict
=
shape_dict
,
dtype_dict
=
dtype_dict
)
with
relay
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
],
target
,
params
=
params
)
ctx
=
tvm
.
context
(
target
,
0
)
from
tvm.contrib
import
graph_runtime
...
...
tests/python/relay/test_vm.py
View file @
fa351045
...
...
@@ -35,9 +35,9 @@ def veval(f, *args, ctx=tvm.cpu()):
mod
=
f
ex
=
relay
.
create_executor
(
'vm'
,
mod
=
mod
,
ctx
=
ctx
)
if
len
(
args
)
==
0
:
return
ex
.
evaluate
(
mod
[
mod
.
entry_func
]
)
return
ex
.
evaluate
()
else
:
return
ex
.
evaluate
(
mod
[
mod
.
entry_func
]
)(
*
args
)
return
ex
.
evaluate
()(
*
args
)
def
test_split
():
x
=
relay
.
var
(
'x'
,
shape
=
(
12
,))
...
...
tutorials/frontend/deploy_model_on_android.py
View file @
fa351045
...
...
@@ -260,10 +260,10 @@ elif test_target == 'vulkan':
input_name
=
'input_1'
shape_dict
=
{
input_name
:
x
.
shape
}
func
,
params
=
relay
.
frontend
.
from_keras
(
keras_mobilenet_v2
,
shape_dict
)
mod
,
params
=
relay
.
frontend
.
from_keras
(
keras_mobilenet_v2
,
shape_dict
)
with
relay
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
=
target
,
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
]
,
target
=
target
,
target_host
=
target_host
,
params
=
params
)
# After `relay.build`, you will get three return values: graph,
...
...
tutorials/frontend/deploy_model_on_rasp.py
View file @
fa351045
...
...
@@ -140,8 +140,9 @@ with open(synset_path) as f:
# We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon
shape_dict
=
{
'data'
:
x
.
shape
}
func
,
params
=
relay
.
frontend
.
from_mxnet
(
block
,
shape_dict
)
mod
,
params
=
relay
.
frontend
.
from_mxnet
(
block
,
shape_dict
)
# we want a probability so add a softmax operator
func
=
mod
[
mod
.
entry_func
]
func
=
relay
.
Function
(
func
.
params
,
relay
.
nn
.
softmax
(
func
.
body
),
None
,
func
.
type_params
,
func
.
attrs
)
######################################################################
...
...
tutorials/frontend/deploy_ssd_gluoncv.py
View file @
fa351045
...
...
@@ -76,9 +76,9 @@ x, img = data.transforms.presets.ssd.load_test(im_fname, short=512)
block
=
model_zoo
.
get_model
(
model_name
,
pretrained
=
True
)
def
build
(
target
):
net
,
params
=
relay
.
frontend
.
from_mxnet
(
block
,
{
"data"
:
dshape
})
mod
,
params
=
relay
.
frontend
.
from_mxnet
(
block
,
{
"data"
:
dshape
})
with
relay
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
net
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
]
,
target
,
params
=
params
)
return
graph
,
lib
,
params
######################################################################
...
...
tutorials/frontend/from_caffe2.py
View file @
fa351045
...
...
@@ -83,13 +83,13 @@ dtype_dict = {input_name: data.dtype}
# parse Caffe2 model and convert into Relay computation graph
from
tvm
import
relay
func
,
params
=
relay
.
frontend
.
from_caffe2
(
resnet50
.
init_net
,
resnet50
.
predict_net
,
shape_dict
,
dtype_dict
)
mod
,
params
=
relay
.
frontend
.
from_caffe2
(
resnet50
.
init_net
,
resnet50
.
predict_net
,
shape_dict
,
dtype_dict
)
# compile the model
# target x86 CPU
target
=
'llvm'
with
relay
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
]
,
target
,
params
=
params
)
######################################################################
# Execute on TVM
...
...
tutorials/frontend/from_coreml.py
View file @
fa351045
...
...
@@ -68,10 +68,12 @@ target = 'cuda'
shape_dict
=
{
'image'
:
x
.
shape
}
# Parse CoreML model and convert into Relay computation graph
func
,
params
=
relay
.
frontend
.
from_coreml
(
mlmodel
,
shape_dict
)
mod
,
params
=
relay
.
frontend
.
from_coreml
(
mlmodel
,
shape_dict
)
with
relay
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
],
target
,
params
=
params
)
######################################################################
# Execute on TVM
...
...
tutorials/frontend/from_darknet.py
View file @
fa351045
...
...
@@ -82,7 +82,7 @@ batch_size = 1
data
=
np
.
empty
([
batch_size
,
net
.
c
,
net
.
h
,
net
.
w
],
dtype
)
shape_dict
=
{
'data'
:
data
.
shape
}
print
(
"Converting darknet to relay functions..."
)
sym
,
params
=
relay
.
frontend
.
from_darknet
(
net
,
dtype
=
dtype
,
shape
=
data
.
shape
)
mod
,
params
=
relay
.
frontend
.
from_darknet
(
net
,
dtype
=
dtype
,
shape
=
data
.
shape
)
######################################################################
# Import the graph to Relay
...
...
@@ -95,7 +95,10 @@ data = np.empty([batch_size, net.c, net.h, net.w], dtype)
shape
=
{
'data'
:
data
.
shape
}
print
(
"Compiling the model..."
)
with
relay
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
sym
,
target
=
target
,
target_host
=
target_host
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
],
target
=
target
,
target_host
=
target_host
,
params
=
params
)
[
neth
,
netw
]
=
shape
[
'data'
][
2
:]
# Current image shape is 608x608
######################################################################
...
...
tutorials/frontend/from_keras.py
View file @
fa351045
...
...
@@ -74,18 +74,18 @@ print('input_1', data.shape)
# ----------------------------
# convert the keras model(NHWC layout) to Relay format(NCHW layout).
shape_dict
=
{
'input_1'
:
data
.
shape
}
func
,
params
=
relay
.
frontend
.
from_keras
(
keras_resnet50
,
shape_dict
)
mod
,
params
=
relay
.
frontend
.
from_keras
(
keras_resnet50
,
shape_dict
)
# compile the model
target
=
'cuda'
ctx
=
tvm
.
gpu
(
0
)
with
relay
.
build_config
(
opt_level
=
3
):
executor
=
relay
.
build_module
.
create_executor
(
'graph'
,
func
,
ctx
,
target
)
executor
=
relay
.
build_module
.
create_executor
(
'graph'
,
mod
,
ctx
,
target
)
######################################################################
# Execute on TVM
# ---------------
dtype
=
'float32'
tvm_out
=
executor
.
evaluate
(
func
)(
tvm
.
nd
.
array
(
data
.
astype
(
dtype
)),
**
params
)
tvm_out
=
executor
.
evaluate
()(
tvm
.
nd
.
array
(
data
.
astype
(
dtype
)),
**
params
)
top1_tvm
=
np
.
argmax
(
tvm_out
.
asnumpy
()[
0
])
#####################################################################
...
...
tutorials/frontend/from_mxnet.py
View file @
fa351045
...
...
@@ -82,8 +82,9 @@ print('x', x.shape)
# It's as easy as several lines.
# We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon
shape_dict
=
{
'data'
:
x
.
shape
}
func
,
params
=
relay
.
frontend
.
from_mxnet
(
block
,
shape_dict
)
mod
,
params
=
relay
.
frontend
.
from_mxnet
(
block
,
shape_dict
)
## we want a probability so add a softmax operator
func
=
mod
[
mod
.
entry_func
]
func
=
relay
.
Function
(
func
.
params
,
relay
.
nn
.
softmax
(
func
.
body
),
None
,
func
.
type_params
,
func
.
attrs
)
######################################################################
...
...
@@ -132,6 +133,6 @@ mx.model.save_checkpoint('resnet18_v1', 0, mx_sym, args, auxs)
# for a normal mxnet model, we start from here
mx_sym
,
args
,
auxs
=
mx
.
model
.
load_checkpoint
(
'resnet18_v1'
,
0
)
# now we use the same API to get Relay computation graph
relay_func
,
relay_params
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape_dict
,
arg_params
=
args
,
aux_params
=
auxs
)
mod
,
relay_params
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape_dict
,
arg_params
=
args
,
aux_params
=
auxs
)
# repeat the same steps to run this model using TVM
tutorials/frontend/from_onnx.py
View file @
fa351045
...
...
@@ -71,16 +71,16 @@ target = 'llvm'
input_name
=
'1'
shape_dict
=
{
input_name
:
x
.
shape
}
sym
,
params
=
relay
.
frontend
.
from_onnx
(
onnx_model
,
shape_dict
)
mod
,
params
=
relay
.
frontend
.
from_onnx
(
onnx_model
,
shape_dict
)
with
relay
.
build_config
(
opt_level
=
1
):
intrp
=
relay
.
build_module
.
create_executor
(
'graph'
,
sym
,
tvm
.
cpu
(
0
),
target
)
intrp
=
relay
.
build_module
.
create_executor
(
'graph'
,
mod
,
tvm
.
cpu
(
0
),
target
)
######################################################################
# Execute on TVM
# ---------------------------------------------
dtype
=
'float32'
tvm_output
=
intrp
.
evaluate
(
sym
)(
tvm
.
nd
.
array
(
x
.
astype
(
dtype
)),
**
params
)
.
asnumpy
()
tvm_output
=
intrp
.
evaluate
()(
tvm
.
nd
.
array
(
x
.
astype
(
dtype
)),
**
params
)
.
asnumpy
()
######################################################################
# Display results
...
...
tutorials/frontend/from_tensorflow.py
View file @
fa351045
...
...
@@ -124,7 +124,9 @@ x = np.array(image)
# params: params converted from tensorflow params (tensor protobuf).
shape_dict
=
{
'DecodeJpeg/contents'
:
x
.
shape
}
dtype_dict
=
{
'DecodeJpeg/contents'
:
'uint8'
}
sym
,
params
=
relay
.
frontend
.
from_tensorflow
(
graph_def
,
layout
=
layout
,
shape
=
shape_dict
)
mod
,
params
=
relay
.
frontend
.
from_tensorflow
(
graph_def
,
layout
=
layout
,
shape
=
shape_dict
)
print
(
"Tensorflow protobuf imported to relay frontend."
)
######################################################################
...
...
@@ -138,7 +140,10 @@ print("Tensorflow protobuf imported to relay frontend.")
# lib: target library which can be deployed on target with TVM runtime.
with
relay
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
sym
,
target
=
target
,
target_host
=
target_host
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
],
target
=
target
,
target_host
=
target_host
,
params
=
params
)
######################################################################
# Execute the portable graph on TVM
...
...
tutorials/frontend/from_tflite.py
View file @
fa351045
...
...
@@ -138,14 +138,14 @@ input_dtype = "float32"
# parse TFLite model and convert into Relay computation graph
from
tvm
import
relay
func
,
params
=
relay
.
frontend
.
from_tflite
(
tflite_model
,
shape_dict
=
{
input_tensor
:
input_shape
},
dtype_dict
=
{
input_tensor
:
input_dtype
})
mod
,
params
=
relay
.
frontend
.
from_tflite
(
tflite_model
,
shape_dict
=
{
input_tensor
:
input_shape
},
dtype_dict
=
{
input_tensor
:
input_dtype
})
# target x86 CPU
target
=
"llvm"
with
relay
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
mod
[
mod
.
entry_func
]
,
target
,
params
=
params
)
######################################################################
# Execute on TVM
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment