wenyuanbo / tic · Commits · 9a6feca6

Commit 9a6feca6, authored Oct 15, 2017 by Joshua Z. Zhang; committed by Tianqi Chen on May 29, 2018
[FRONTEND] Composed operators (#175)

* fix for composed symbol
* fix
* clean up
* fix exception type
Parent: 9fb13a69
Showing 2 changed files with 109 additions and 59 deletions:

    nnvm/python/nnvm/frontend/mxnet.py                  +75  -58
    nnvm/tests/python/frontend/mxnet/test_forward.py    +34  -1
nnvm/python/nnvm/frontend/mxnet.py (view file @ 9a6feca6)
@@ -7,6 +7,12 @@ from .. import symbol as _sym
 __all__ = ['from_mxnet']
 
+def _get_nnvm_op(op_name):
+    op = getattr(_sym, op_name)
+    if not op:
+        raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
+    return op
+
 def _get_mxnet_version():
     try:
         import mxnet as mx
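The new _get_nnvm_op helper centralizes the name-to-operator lookup that every converter needs. One subtlety worth noting: getattr(_sym, op_name) with no default raises AttributeError for an unknown name before the RuntimeError check is reached. A minimal, self-contained sketch of the same dispatch pattern (the Ops namespace below is a stand-in for nnvm.symbol, not the real module):

    import math

    class Ops:
        # Stand-in for nnvm.symbol: operators are attributes looked up by name.
        exp = staticmethod(math.exp)
        log = staticmethod(math.log)

    def get_op(op_name):
        op = getattr(Ops, op_name, None)  # default=None keeps the error path reachable
        if not op:
            raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
        return op

    print(get_op('exp')(1.0))  # 2.718...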
@@ -39,14 +45,11 @@ def _parse_bool_str(attr, key, default='False'):
     return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes']
 
 def _rename(new_name):
-    def impl(attr):
-        return new_name, attr
+    def impl(inputs, attrs):
+        return _get_nnvm_op(new_name)(*inputs, **attrs)
     return impl
 
-def _variable(attrs):
-    return "Variable", attrs
-
-def _pooling(attrs):
+def _pooling(inputs, attrs):
     kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
     if len(kernel) != 2:
         _raise_not_supported('non-2d kernel', 'pool_2d')
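_rename is a converter factory: it captures the target nnvm op name in a closure and returns an impl(inputs, attrs) callable with the same signature as the hand-written converters, so renamed ops and specially-converted ops dispatch uniformly. A hedged sketch of the pattern, with a plain dict standing in for the nnvm op registry (names here are illustrative):

    def make_renamer(get_op, new_name):
        # Converter that forwards inputs/attrs to the op registered under new_name.
        def impl(inputs, attrs):
            return get_op(new_name)(*inputs, **attrs)
        return impl

    registry = {'concatenate': lambda *xs, **kw: ('concatenate', xs, kw)}
    convert = make_renamer(registry.__getitem__, 'concatenate')
    print(convert(['a', 'b'], {'axis': 1}))
    # ('concatenate', ('a', 'b'), {'axis': 1})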
@@ -61,9 +64,9 @@ def _pooling(attrs):
     new_attrs['strides'] = attrs.get('stride', (1, 1))
     new_attrs['padding'] = attrs.get('pad', (0, 0))
     new_attrs['ceil_mode'] = (attrs.get('pooling_convention', 'valid') == 'full')
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _batch_norm(attrs):
+def _batch_norm(inputs, attrs):
     if _parse_bool_str(attrs, 'output_mean_var'):
         _raise_not_supported('output_mean_var', 'batch_norm')
     # if _parse_bool_str(attrs, 'fix_gamma'):
@@ -77,14 +80,14 @@ def _batch_norm(attrs):
     new_attrs['epsilon'] = attrs.get('eps', 0.001)
     new_attrs['center'] = True
     new_attrs['scale'] = True
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _concat(attrs):
+def _concat(inputs, attrs):
     op_name = 'concatenate'
     new_attrs = {'axis': attrs.get('dim', 1)}
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _conv2d(attrs):
+def _conv2d(inputs, attrs):
     kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
     if len(kernel) != 2:
         _raise_not_supported('non 2d kernel', 'conv2d')
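The recurring change above is the point of the refactor: converters now take (inputs, attrs) and return a finished nnvm symbol instead of a translated (op_name, attrs) pair, which lets a single MXNet op expand into several nnvm ops. A hedged standalone sketch of that shape (fake registry, illustrative names; cf. _dense below, which inserts a flatten before the dense op):

    def dense_like_converter(get_op, inputs, attrs):
        # Because converters return symbols rather than names, one MXNet op
        # can expand into a small subgraph: flatten feeding dense.
        inputs = list(inputs)
        inputs[0] = get_op('flatten')(inputs[0])
        return get_op('dense')(*inputs, units=attrs['num_hidden'])

    registry = {
        'flatten': lambda x: ('flatten', x),
        'dense': lambda x, **kw: ('dense', x, kw),
    }
    print(dense_like_converter(registry.__getitem__, ['x'], {'num_hidden': 100}))
    # ('dense', ('flatten', 'x'), {'units': 100})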
@@ -100,9 +103,9 @@ def _conv2d(attrs):
     new_attrs['groups'] = attrs.get('num_group', 1)
     new_attrs['layout'] = layout
     new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False'
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _conv2d_transpose(attrs):
+def _conv2d_transpose(inputs, attrs):
     if 'target_shape' in attrs:
         _raise_not_supported('target_shape', 'conv2d_transpose')
     kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
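Aside: _conv2d compares the no_bias string literally (== 'False'), while _conv2d_transpose below uses _parse_bool_str, which strips and lower-cases before testing. The two disagree on inputs such as a lower-case 'false' (hypothetical attribute value, shown only to illustrate the difference):

    attrs = {'no_bias': 'false'}  # lower-case, as some exporters emit
    # _conv2d-style literal comparison: anything but exactly 'False' drops the bias
    use_bias_literal = attrs.get('no_bias', 'False').strip() == 'False'      # False
    # _parse_bool_str-style: normalizes case, so 'false' is falsy and bias is kept
    use_bias_parsed = attrs.get('no_bias', 'False').strip().lower() not in \
        ['true', '1', 't', 'y', 'yes']                                       # True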
@@ -121,51 +124,68 @@ def _conv2d_transpose(attrs):
     new_attrs['groups'] = attrs.get('num_group', 1)
     new_attrs['layout'] = layout
     new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _dense(attrs):
+def _dense(inputs, attrs):
     op_name, new_attrs = 'dense', {}
     new_attrs['units'] = _required_attr(attrs, 'num_hidden')
     new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
-    major, minor, micro = _get_mxnet_version()
-    if major >= 0 and minor >= 11 and micro >= 1:
-        new_attrs['flatten'] = _parse_bool_str(attrs, 'flatten', 'True')
-    return op_name, new_attrs
+    use_flatten = _parse_bool_str(attrs, 'flatten', 'True')
+    if use_flatten:
+        inputs[0] = _sym.flatten(inputs[0])
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _dropout(attrs):
+def _dropout(inputs, attrs):
     op_name, new_attrs = 'dropout', {}
     new_attrs['rate'] = attrs.get('p', 0.5)
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _leaky_relu(attrs):
+def _leaky_relu(inputs, attrs):
     act_type = _required_attr(attrs, 'act_type')
-    if act_type not in ['leaky']:
+    if act_type in ['leaky']:
+        op_name, new_attrs = 'leaky_relu', {}
+        new_attrs['alpha'] = attrs.get('slope', 0.25)
+        sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
+    elif act_type == 'elu':
+        slope = attrs.get('slope', 0.25)
+        sym = -slope * _sym.relu(1 - _sym.exp(*inputs)) + _sym.relu(*inputs)
+    elif act_type == 'rrelu':
+        lower_bound = float(_required_attr(attrs, 'lower_bound'))
+        upper_bound = float(_required_attr(attrs, 'upper_bound'))
+        slope = (lower_bound + upper_bound) / 2.0
+        op_name, new_attrs = 'leaky_relu', {'alpha': str(slope)}
+        sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
+    else:
         _raise_not_supported('act_type: ' + act_type)
-    op_name, new_attrs = 'leaky_relu', {}
-    new_attrs['alpha'] = attrs.get('slope', 0.25)
-    return op_name, new_attrs
+    return sym
 
-def _activations(attrs):
+def _activations(inputs, attrs):
     act_type = _required_attr(attrs, 'act_type')
-    if act_type not in ['relu', 'sigmoid', 'tanh']:
+    if act_type in ['relu', 'sigmoid', 'tanh']:
+        op_name, new_attrs = act_type, {}
+        sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
+    elif act_type == 'softrelu':
+        sym = _sym.log((1 + _sym.exp(*inputs)))
+    else:
         _raise_not_supported('act_type: ' + act_type)
-    op_name, new_attrs = act_type, {}
-    return op_name, new_attrs
+    return sym
 
-def _reshape(attrs):
+def _reshape(inputs, attrs):
     if _parse_bool_str(attrs, 'reverse'):
         _raise_not_supported('reverse', 'reshape')
     op_name, new_attrs = 'reshape', {}
     new_attrs['shape'] = _required_attr(attrs, 'shape')
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _split(attrs):
+def _split(inputs, attrs):
     if _parse_bool_str(attrs, 'squeeze_axis'):
         _raise_not_supported('squeeze_axis', 'split')
     op_name, new_attrs = 'split', {}
     new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs')
     new_attrs['axis'] = attrs.get('axis', 1)
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
 _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
                   '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
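This hunk is the heart of the commit: elu and softrelu have no single nnvm op, so they are composed from relu, exp, and log, and rrelu at inference time collapses to leaky_relu with the midpoint slope. A quick NumPy check, under the standard definitions (elu(x) = x for x > 0, slope * (exp(x) - 1) otherwise; softrelu(x) = log(1 + exp(x))), that the composition matches:

    import numpy as np

    x = np.linspace(-3.0, 3.0, 13)
    slope = 0.25

    # Reference ELU vs. the composed form used above:
    #   -slope * relu(1 - exp(x)) + relu(x)
    elu_ref = np.where(x > 0, x, slope * (np.exp(x) - 1))
    elu_composed = -slope * np.maximum(1 - np.exp(x), 0) + np.maximum(x, 0)
    assert np.allclose(elu_ref, elu_composed)

    # rrelu at inference: leaky_relu with the mean of the sampled slope range
    lower_bound, upper_bound = 0.3, 0.7
    rrelu_ref = np.where(x > 0, x, (lower_bound + upper_bound) / 2.0 * x)
    leaky = np.where(x > 0, x, 0.5 * x)
    assert np.allclose(rrelu_ref, leaky)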
@@ -178,7 +198,12 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
                   'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
 
 _convert_map = {
-    'null'           : _variable,
+    '_div_scalar'    : _rename('__div_scalar__'),
+    '_minus_scalar'  : _rename('__sub_scalar__'),
+    '_mul_scalar'    : _rename('__mul_scalar__'),
+    '_plus_scalar'   : _rename('__add_scalar__'),
+    '_rdiv_scalar'   : _rename('__rdiv_scalar__'),
+    '_rminus_scalar' : _rename('__rsub_scalar__'),
     'Activation'     : _activations,
     'BatchNorm'      : _batch_norm,
     'BatchNorm_v1'   : _batch_norm,
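These _rename entries cover MXNet's scalar arithmetic nodes: an expression such as x + 3 shows up in the MXNet graph as a _plus_scalar op, which now forwards to nnvm's __add_scalar__. A hedged illustration of the lookup (prev_sym and the attribute value are hypothetical):

    converter = _convert_map['_plus_scalar']   # the closure from _rename('__add_scalar__')
    # converter([prev_sym], {'scalar': '3'}) is then equivalent to
    # _sym.__add_scalar__(prev_sym, scalar='3')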
@@ -202,7 +227,7 @@ _convert_map = {
     'sum_axis'       : _rename('sum'),
 }
 
-def _convert_symbol(op_name, attrs,
+def _convert_symbol(op_name, inputs, attrs,
                     identity_list=None,
                     convert_map=None):
     """Convert from mxnet op to nnvm op.
@@ -213,6 +238,8 @@ def _convert_symbol(op_name, attrs,
     ----------
     op_name : str
         Operator name, such as Convolution, FullyConnected
+    inputs : list of nnvm.Symbol
+        List of input symbols.
     attrs : dict
         Dict of operator attributes
     identity_list : list
@@ -224,21 +251,19 @@ def _convert_symbol(op_name, attrs,
     Returns
     -------
-    (op_name, attrs)
-        Converted (op_name, attrs) for nnvm.
+    sym : nnvm.Symbol
+        Converted nnvm Symbol
     """
     identity_list = identity_list if identity_list else _identity_list
     convert_map = convert_map if convert_map else _convert_map
     if op_name in identity_list:
-        pass
+        op = _get_nnvm_op(op_name)
+        sym = op(*inputs, **attrs)
     elif op_name in convert_map:
-        op_name, attrs = convert_map[op_name](attrs)
+        sym = convert_map[op_name](inputs, attrs)
     else:
         _raise_not_supported('Operator: ' + op_name)
-    op = getattr(_sym, op_name, None)
-    if not op:
-        raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
-    return op, attrs
+    return sym
 
 def _is_mxnet_group_symbol(symbol):
     """Internal check for mxnet group symbol."""
@@ -274,28 +299,20 @@ def _from_mxnet_impl(symbol, graph):
     node = graph.get(name, None)
     if node:
         return node
-    attr = symbol.list_attr()
-    # op_name = symbol.attr('op_name')
-    if symbol.get_children():
+    childs = symbol.get_children()
+    if childs:
         op_name = symbol.attr('op_name')
-    else:
-        op_name = json.loads(symbol.tojson())['nodes'][0]['op']
-    new_op, new_attr = _convert_symbol(op_name, attr)
-    if new_op == _sym.Variable:
-        node = new_op(name=name, **new_attr)
-    else:
-        childs = symbol.get_children()
+        attr = symbol.list_attr()
         childs = [_from_mxnet_impl(c, graph) for c in _as_list(childs)]
         childs = [x for y in childs for x in _as_list(y)]  # expand group symbol
-        if new_op == _sym.dense and 'flatten' in new_attr:
-            if new_attr['flatten']:
-                childs[0] = _sym.flatten(childs[0])
-            new_attr.pop('flatten')
-        node = new_op(name=name, *childs, **new_attr)
+        node = _convert_symbol(op_name, childs, attr)
+    else:
+        op_name = json.loads(symbol.tojson())['nodes'][0]['op']
+        node = _sym.Variable(name=name, **attr)
     graph[name] = node
     return node
 
 def from_mxnet(symbol, arg_params=None, aux_params=None):
     """Convert from MXNet's model into compatible NNVM format.
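For context, the public entry point from_mxnet drives _from_mxnet_impl over a whole graph. A hedged usage sketch; the checkpoint name is hypothetical, but mx.model.load_checkpoint and nnvm.frontend.from_mxnet are the documented entry points:

    import mxnet as mx
    import nnvm

    # Load a pretrained MXNet model (prefix and epoch are placeholders).
    mx_sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-18', 0)

    # Convert: returns an nnvm.Symbol plus a dict of parameter arrays
    # ready for nnvm.compiler.build.
    net, params = nnvm.frontend.from_mxnet(mx_sym, arg_params, aux_params)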
nnvm/tests/python/frontend/mxnet/test_forward.py (view file @ 9a6feca6)
@@ -46,7 +46,7 @@ def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape
     assert "data" not in args
     for target, ctx in ctx_list():
         tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
-        np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5)
+        np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 def test_forward_mlp():
     mlp = model_zoo.mx_mlp
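Adding atol here matters because assert_allclose with rtol alone scales the permitted error by the magnitude of the reference value, and that allowance collapses to zero exactly where these activation tests produce zeros (the relu'd negative half). A small illustration with hypothetical values:

    import numpy as np

    mx_out = np.array([1e-7, 1.0 + 1e-7])
    tvm_out = np.array([0.0, 1.0])
    # With rtol alone the allowed error is rtol * |tvm_out|, which is zero
    # for the first element, so even a 1e-7 discrepancy fails:
    # np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5)  # would raise
    np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)  # passes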
@@ -62,7 +62,40 @@ def test_forward_resnet():
     mx_sym = model_zoo.mx_resnet[n]
     verify_mxnet_frontend_impl(mx_sym)
 
+def test_forward_elu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    mx_sym = mx.sym.LeakyReLU(data, act_type='elu')
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
+def test_forward_rrelu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
+def test_forward_softrelu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    mx_sym = mx.sym.Activation(data, act_type='softrelu')
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
+def test_forward_fc_flatten():
+    # test flatten=True option in mxnet 0.11.1
+    data = mx.sym.var('data')
+    try:
+        mx_sym = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
+        verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
+        mx_sym = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100, flatten=False)
+        verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
+    except:
+        pass
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
     test_forward_resnet()
+    test_forward_elu()
+    test_forward_rrelu()
+    test_forward_softrelu()
+    test_forward_fc_flatten()