Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
dddb0ed0
Commit
dddb0ed0
authored
Nov 12, 2019
by
Haichen Shen
Committed by
Leyuan Wang
Nov 12, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add (#4311)
parent
83bac2d1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
91 additions
and
3 deletions
+91
-3
python/tvm/relay/frontend/mxnet.py
+65
-3
tests/python/frontend/mxnet/test_forward.py
+26
-0
No files found.
python/tvm/relay/frontend/mxnet.py
View file @
dddb0ed0
...
...
@@ -20,10 +20,12 @@ from __future__ import absolute_import as _abs
import
json
import
tvm
from
topi.util
import
get_const_tuple
from
..
import
analysis
from
..
import
expr
as
_expr
from
..
import
op
as
_op
from
..
import
module
as
_module
from
..
import
scope_builder
as
_scope_builder
from
...
import
nd
as
_nd
from
.common
import
StrAttrsDict
...
...
@@ -1037,6 +1039,47 @@ def _mx_contrib_fifo_buffer(inputs, attrs):
new_attrs
[
'axis'
]
=
attrs
.
get_int
(
'axis'
)
return
_op
.
nn
.
fifo_buffer
(
*
inputs
,
**
new_attrs
)
def
_mx_cond
(
inputs
,
attrs
,
subgraphs
):
assert
len
(
subgraphs
)
==
3
cond_input_locs
=
json
.
loads
(
attrs
.
get_str
(
"cond_input_locs"
))
then_input_locs
=
json
.
loads
(
attrs
.
get_str
(
"then_input_locs"
))
else_input_locs
=
json
.
loads
(
attrs
.
get_str
(
"else_input_locs"
))
num_outputs
=
attrs
.
get_int
(
"num_outputs"
)
input_args
=
[]
for
i
,
arg
in
enumerate
(
inputs
):
var
=
_expr
.
var
(
"arg
%
s"
%
i
,
_infer_type
(
arg
)
.
checked_type
)
input_args
.
append
(
var
)
cond_args
=
[
input_args
[
i
]
for
i
in
cond_input_locs
]
then_args
=
[
input_args
[
i
]
for
i
in
then_input_locs
]
else_args
=
[
input_args
[
i
]
for
i
in
else_input_locs
]
cond_arg_shapes
=
[
arg
.
type_annotation
.
shape
for
arg
in
cond_args
]
cond_arg_dtype_info
=
[
arg
.
type_annotation
.
dtype
for
arg
in
cond_args
]
cond_func
=
_from_mxnet_impl
(
subgraphs
[
0
],
cond_arg_shapes
,
cond_arg_dtype_info
)
cond
=
_expr
.
Call
(
cond_func
,
cond_args
)
.
astype
(
"bool"
)
cond_shape
=
get_const_tuple
(
_infer_type
(
cond
)
.
checked_type
.
shape
)
if
len
(
cond_shape
)
>
0
:
assert
len
(
cond_shape
)
==
1
and
cond_shape
[
0
]
==
1
,
"Condition is not scalar"
cond
=
_op
.
take
(
cond
,
_expr
.
const
(
1
,
"int"
))
sb
=
_scope_builder
.
ScopeBuilder
()
with
sb
.
if_scope
(
cond
):
then_arg_shapes
=
[
arg
.
type_annotation
.
shape
for
arg
in
then_args
]
then_arg_dtype_info
=
[
arg
.
type_annotation
.
dtype
for
arg
in
then_args
]
then_func
=
_from_mxnet_impl
(
subgraphs
[
1
],
then_arg_shapes
,
then_arg_dtype_info
)
sb
.
ret
(
_expr
.
Call
(
then_func
,
then_args
))
with
sb
.
else_scope
():
else_arg_shapes
=
[
arg
.
type_annotation
.
shape
for
arg
in
else_args
]
else_arg_dtype_info
=
[
arg
.
type_annotation
.
dtype
for
arg
in
else_args
]
else_func
=
_from_mxnet_impl
(
subgraphs
[
2
],
else_arg_shapes
,
else_arg_dtype_info
)
sb
.
ret
(
_expr
.
Call
(
else_func
,
else_args
))
func
=
_expr
.
Function
(
input_args
,
sb
.
get
())
ret
=
_expr
.
Call
(
func
,
inputs
)
if
num_outputs
>
1
:
ret
=
_expr
.
TupleWrapper
(
ret
,
num_outputs
)
return
ret
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
...
...
@@ -1204,6 +1247,8 @@ _convert_map = {
# NLP
"RNN"
:
_mx_rnn_layer
,
"_rnn_param_concat"
:
_mx_rnn_param_concat
,
# control flow
"_cond"
:
_mx_cond
,
# Depricated:
"Crop"
:
_mx_crop_like
,
# List of missing operators that are present in NNVMv1
...
...
@@ -1245,9 +1290,13 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None):
Converted relay Function
"""
assert
symbol
is
not
None
jgraph
=
json
.
loads
(
symbol
.
tojson
())
if
isinstance
(
symbol
,
dict
):
jgraph
=
symbol
else
:
jgraph
=
json
.
loads
(
symbol
.
tojson
())
jnodes
=
jgraph
[
"nodes"
]
node_map
=
{}
shape_idx
=
0
for
nid
,
node
in
enumerate
(
jnodes
):
children
=
[
node_map
[
e
[
0
]][
e
[
1
]]
for
e
in
node
[
"inputs"
]]
...
...
@@ -1255,14 +1304,27 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None):
node_name
=
node
[
"name"
]
op_name
=
node
[
"op"
]
if
op_name
==
"null"
:
shape
=
shape_dict
[
node_name
]
if
node_name
in
shape_dict
else
None
if
isinstance
(
shape_dict
,
dict
):
shape
=
shape_dict
[
node_name
]
if
node_name
in
shape_dict
else
None
elif
isinstance
(
shape_dict
,
(
list
,
tuple
)):
shape
=
shape_dict
[
shape_idx
]
else
:
raise
ValueError
(
"Unknown type of shape_dict:
%
s"
+
type
(
shape_dict
))
if
isinstance
(
dtype_info
,
dict
):
dtype
=
dtype_info
[
node_name
]
if
node_name
in
dtype_info
else
"float32"
elif
isinstance
(
dtype_info
,
(
list
,
tuple
)):
dtype
=
dtype_info
[
shape_idx
]
else
:
dtype
=
dtype_info
if
isinstance
(
shape_dict
,
(
list
,
tuple
)):
shape_idx
+=
1
node_map
[
nid
]
=
[
_expr
.
var
(
node_name
,
shape
=
shape
,
dtype
=
dtype
)]
elif
op_name
in
_convert_map
:
res
=
_convert_map
[
op_name
](
children
,
attrs
)
if
op_name
in
[
'_cond'
,
'_foreach'
,
'_while_loop'
]:
subgraphs
=
node
[
'subgraphs'
]
res
=
_convert_map
[
op_name
](
children
,
attrs
,
subgraphs
)
else
:
res
=
_convert_map
[
op_name
](
children
,
attrs
)
if
res
is
None
:
# defer conversion, used in RNN state initialization
res
=
[
node
]
...
...
tests/python/frontend/mxnet/test_forward.py
View file @
dddb0ed0
...
...
@@ -909,6 +909,31 @@ def test_forward_deconvolution():
verify
(
data_shape
=
(
1
,
8
,
32
,
32
),
kernel_size
=
(
3
,
3
),
stride
=
(
1
,
1
),
pad
=
(
1
,
1
),
num_filter
=
2
)
verify
(
data_shape
=
(
20
,
8
,
32
,
32
),
kernel_size
=
(
3
,
3
),
stride
=
(
1
,
1
),
pad
=
(
1
,
1
),
num_filter
=
2
)
def
test_forward_cond
():
def
verify
(
a_np
,
b_np
):
a_nd
,
b_nd
=
mx
.
nd
.
array
(
a_np
),
mx
.
nd
.
array
(
b_np
)
pred
=
a_nd
*
b_nd
<
5
then_func
=
lambda
:
(
a_nd
+
5
)
*
(
b_nd
+
5
)
else_func
=
lambda
:
(
a_nd
-
5
)
*
(
b_nd
-
5
)
ref_res
=
mx
.
nd
.
contrib
.
cond
(
pred
,
then_func
,
else_func
)
a_sym
,
b_sym
=
mx
.
sym
.
var
(
"a"
),
mx
.
sym
.
var
(
"b"
)
pred
=
a_sym
*
b_sym
<
5
then_func
=
lambda
:
(
a_sym
+
5
)
*
(
b_sym
+
5
)
else_func
=
lambda
:
(
a_sym
-
5
)
*
(
b_sym
-
5
)
mx_sym
=
mx
.
sym
.
contrib
.
cond
(
pred
,
then_func
,
else_func
)
shape_dict
=
{
"a"
:
a_np
.
shape
,
"b"
:
b_np
.
shape
}
mod
,
_
=
relay
.
frontend
.
from_mxnet
(
mx_sym
,
shape_dict
)
for
target
,
ctx
in
ctx_list
():
for
kind
in
[
"debug"
,
"vm"
]:
intrp
=
relay
.
create_executor
(
kind
,
mod
=
mod
,
ctx
=
ctx
,
target
=
target
)
op_res
=
intrp
.
evaluate
()(
a_np
,
b_np
)
tvm
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
ref_res
.
asnumpy
(),
rtol
=
1e-3
)
verify
(
np
.
asarray
([
1.0
],
'float32'
),
np
.
asarray
([
2.0
],
'float32'
))
verify
(
np
.
asarray
([
4.0
],
'float32'
),
np
.
asarray
([
3.0
],
'float32'
))
if
__name__
==
'__main__'
:
test_forward_mlp
()
...
...
@@ -963,3 +988,4 @@ if __name__ == '__main__':
test_forward_one_hot
()
test_forward_convolution
()
test_forward_deconvolution
()
test_forward_cond
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment