Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
56397826
Commit
56397826
authored
Jun 17, 2019
by
Zhi
Committed by
Tianqi Chen
Jun 17, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
hotfix for onnx (#3387)
parent
1119c40b
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
79 additions
and
16 deletions
+79
-16
python/tvm/relay/frontend/onnx.py
+31
-16
tests/python/frontend/onnx/test_forward.py
+48
-0
No files found.
python/tvm/relay/frontend/onnx.py
View file @
56397826
...
...
@@ -23,6 +23,7 @@ import numpy as np
import
tvm
from
...
import
nd
as
_nd
from
..
import
ir_pass
from
..
import
transform
as
_transform
from
..
import
expr
as
_expr
from
..
import
module
as
_module
from
..
import
op
as
_op
...
...
@@ -409,21 +410,27 @@ class Reshape(OnnxOpConverter):
shape
=
tuple
(
params
[
inputs
[
1
]
.
name_hint
]
.
asnumpy
())
out
=
_op
.
reshape
(
inputs
[
0
],
shape
)
else
:
# Try to infer shape by precompute prune if possible.
# TODO: good to check inputs to be in params.
# to be enhanced when relay support list_input_names API of NNVM
logging
.
warning
(
"Infering Reshape argument by precompute"
)
func
=
_expr
.
Function
(
ir_pass
.
free_vars
(
inputs
[
1
]),
inputs
[
1
])
data
,
shape
=
inputs
logging
.
warning
(
"Constant evaluating Reshape's shape argument, may reduce performance"
)
shape_params
=
ir_pass
.
free_vars
(
shape
)
func
=
_expr
.
Function
(
shape_params
,
shape
)
mod
=
_module
.
Module
.
from_expr
(
func
)
seq
=
_transform
.
Sequential
([
_transform
.
InferType
(),
_transform
.
FoldConstant
(),
_transform
.
FuseOps
(
0
),
_transform
.
InferType
()])
with
tvm
.
relay
.
PassContext
(
opt_level
=
2
):
mod
=
seq
(
mod
)
with
tvm
.
relay
.
build_config
(
opt_level
=
0
):
graph
,
lib
,
params
=
tvm
.
relay
.
build
(
func
,
target
=
"llvm"
,
params
=
params
)
ctx
=
tvm
.
context
(
"llvm"
,
0
)
from
tvm.contrib
import
graph_runtime
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
m
.
set_input
(
**
params
)
m
.
run
()
params_new
=
m
.
get_output
(
0
)
inputs
.
pop
(
1
)
out
=
_op
.
reshape
(
inputs
[
0
],
tuple
(
params_new
.
asnumpy
()
.
astype
(
'int32'
)
.
flatten
()))
ex
=
tvm
.
relay
.
create_executor
(
"debug"
,
mod
=
mod
)
inputs
=
[]
for
sp
in
shape_params
:
if
not
sp
.
name_hint
in
params
:
sh
=
[
int
(
i
)
for
i
in
sp
.
type_annotation
.
shape
]
inputs
.
append
(
tvm
.
nd
.
array
(
np
.
random
.
rand
(
*
sh
)
.
astype
(
'float32'
))
)
static_shape
=
ex
.
evaluate
()(
*
inputs
,
**
params
)
out
=
_op
.
reshape
(
data
,
newshape
=
tuple
(
static_shape
.
asnumpy
()))
return
out
...
...
@@ -568,6 +575,7 @@ class Shape(OnnxOpConverter):
@classmethod
def
_impl_v1
(
cls
,
inputs
,
attr
,
params
):
# TODO(@jroesch): use shape_of once it has been fixed)
return
_op
.
shape_of
(
inputs
[
0
])
class
Cast
(
OnnxOpConverter
):
...
...
@@ -1058,8 +1066,15 @@ class GraphProto(object):
if
op_name
==
"Constant"
:
t_proto
=
self
.
_parse_attr
(
node
.
attribute
)[
"value"
]
self
.
_num_param
+=
1
self
.
_params
[
node
.
output
[
0
]]
=
self
.
_parse_array
(
t_proto
)
self
.
_nodes
[
node
.
output
[
0
]]
=
new_var
(
node
.
output
[
0
],
shape
=
list
(
t_proto
.
dims
))
# We should convert scalar integers to int32, to normalize.
array
=
self
.
_parse_array
(
t_proto
)
if
len
(
array
.
shape
)
==
0
and
array
.
dtype
==
'int64'
:
array
=
_nd
.
array
(
array
.
asnumpy
()
.
astype
(
'int32'
))
self
.
_params
[
node
.
output
[
0
]]
=
array
self
.
_nodes
[
node
.
output
[
0
]]
=
new_var
(
node
.
output
[
0
],
shape
=
list
(
t_proto
.
dims
),
dtype
=
array
.
dtype
)
else
:
if
op_name
==
"ConstantFill"
:
fill_value
=
attr
.
get
(
'value'
,
0.0
)
...
...
tests/python/frontend/onnx/test_forward.py
View file @
56397826
...
...
@@ -14,8 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import
attr
import
numpy
as
np
import
math
import
torch
import
torchvision
import
topi
import
topi.testing
import
tvm
...
...
@@ -1072,6 +1075,48 @@ def test_LogSoftmax():
'LogSoftmax'
,
{
'axis'
:
1
})
def
check_torch_conversion
(
model
,
input_size
):
dummy_input
=
torch
.
randn
(
*
input_size
)
file_name
=
'{}.onnx'
.
format
(
model
.
__name__
)
# Set verbose=True for more output
torch
.
onnx
.
export
(
model
(),
dummy_input
,
file_name
,
export_params
=
True
,
verbose
=
False
)
onnx_model
=
onnx
.
load
(
file_name
)
shapes
=
{
'0'
:
input_size
}
expr
,
params
=
relay
.
frontend
.
from_onnx
(
onnx_model
,
shape
=
shapes
)
def
test_resnet
():
check_torch_conversion
(
torchvision
.
models
.
resnet18
,
(
1
,
3
,
224
,
224
))
# check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))
# def test_alexnet():
# Torch's ONNX export does not support the adaptive pooling used by AlexNet?
# check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
# Torch's ONNX export does not support the adaptive pooling used by vgg16?
# def test_vgg16():
# check_torch_conversion(torchvision.models.vgg16, (1,3,224,224))
# TODO(@jroesch): Update Torch + ONNX to support this import.
# def test_squeezenet():
# # Torch's ONNX export does not support the max pooling used by Squezenet
# check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))
def
test_densenet
():
check_torch_conversion
(
torchvision
.
models
.
densenet161
,
(
1
,
3
,
224
,
224
))
def
test_inception
():
check_torch_conversion
(
torchvision
.
models
.
inception_v3
,
(
1
,
3
,
224
,
224
))
# TODO(@jroesch): Update Torch + ONNX to support this import.
# def test_googlenet():
# check_torch_conversion(torchvision.models.googlenet, (1,3,224,224))
# TODO(@jroesch): Update Torch + ONNX to support this import.
# def test_shufflenetv2():
# check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))
if
__name__
==
'__main__'
:
test_flatten
()
test_reshape
()
...
...
@@ -1111,3 +1156,6 @@ if __name__ == '__main__':
test_ParametricSoftplus
()
test_Scale
()
test_LogSoftmax
()
test_resnet
()
test_inception
()
test_densenet
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment