Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9a6feca6
Commit
9a6feca6
authored
Oct 15, 2017
by
Joshua Z. Zhang
Committed by
Tianqi Chen
May 29, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[FRONTEND] Composed operators (#175)
* fix for composed symbol * fix * clean up * fix exception type
parent
9fb13a69
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
109 additions
and
59 deletions
+109
-59
nnvm/python/nnvm/frontend/mxnet.py
+75
-58
nnvm/tests/python/frontend/mxnet/test_forward.py
+34
-1
No files found.
nnvm/python/nnvm/frontend/mxnet.py
View file @
9a6feca6
...
@@ -7,6 +7,12 @@ from .. import symbol as _sym
...
@@ -7,6 +7,12 @@ from .. import symbol as _sym
__all__
=
[
'from_mxnet'
]
__all__
=
[
'from_mxnet'
]
def
_get_nnvm_op
(
op_name
):
op
=
getattr
(
_sym
,
op_name
)
if
not
op
:
raise
RuntimeError
(
"Unable to map op_name {} to nnvm.sym"
.
format
(
op_name
))
return
op
def
_get_mxnet_version
():
def
_get_mxnet_version
():
try
:
try
:
import
mxnet
as
mx
import
mxnet
as
mx
...
@@ -39,14 +45,11 @@ def _parse_bool_str(attr, key, default='False'):
...
@@ -39,14 +45,11 @@ def _parse_bool_str(attr, key, default='False'):
return
attr
.
get
(
key
,
default
)
.
strip
()
.
lower
()
in
[
'true'
,
'1'
,
't'
,
'y'
,
'yes'
]
return
attr
.
get
(
key
,
default
)
.
strip
()
.
lower
()
in
[
'true'
,
'1'
,
't'
,
'y'
,
'yes'
]
def
_rename
(
new_name
):
def
_rename
(
new_name
):
def
impl
(
attr
):
def
impl
(
inputs
,
attrs
):
return
new_name
,
attr
return
_get_nnvm_op
(
new_name
)(
*
inputs
,
**
attrs
)
return
impl
return
impl
def
_variable
(
attrs
):
def
_pooling
(
inputs
,
attrs
):
return
"Variable"
,
attrs
def
_pooling
(
attrs
):
kernel
=
_parse_tshape
(
_required_attr
(
attrs
,
'kernel'
))
kernel
=
_parse_tshape
(
_required_attr
(
attrs
,
'kernel'
))
if
len
(
kernel
)
!=
2
:
if
len
(
kernel
)
!=
2
:
_raise_not_supported
(
'non-2d kernel'
,
'pool_2d'
)
_raise_not_supported
(
'non-2d kernel'
,
'pool_2d'
)
...
@@ -61,9 +64,9 @@ def _pooling(attrs):
...
@@ -61,9 +64,9 @@ def _pooling(attrs):
new_attrs
[
'strides'
]
=
attrs
.
get
(
'stride'
,
(
1
,
1
))
new_attrs
[
'strides'
]
=
attrs
.
get
(
'stride'
,
(
1
,
1
))
new_attrs
[
'padding'
]
=
attrs
.
get
(
'pad'
,
(
0
,
0
))
new_attrs
[
'padding'
]
=
attrs
.
get
(
'pad'
,
(
0
,
0
))
new_attrs
[
'ceil_mode'
]
=
(
attrs
.
get
(
'pooling_convention'
,
'valid'
)
==
'full'
)
new_attrs
[
'ceil_mode'
]
=
(
attrs
.
get
(
'pooling_convention'
,
'valid'
)
==
'full'
)
return
op_name
,
new_attrs
return
_get_nnvm_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
def
_batch_norm
(
attrs
):
def
_batch_norm
(
inputs
,
attrs
):
if
_parse_bool_str
(
attrs
,
'output_mean_var'
):
if
_parse_bool_str
(
attrs
,
'output_mean_var'
):
_raise_not_supported
(
'output_mean_var'
,
'batch_norm'
)
_raise_not_supported
(
'output_mean_var'
,
'batch_norm'
)
# if _parse_bool_str(attrs, 'fix_gamma'):
# if _parse_bool_str(attrs, 'fix_gamma'):
...
@@ -77,14 +80,14 @@ def _batch_norm(attrs):
...
@@ -77,14 +80,14 @@ def _batch_norm(attrs):
new_attrs
[
'epsilon'
]
=
attrs
.
get
(
'eps'
,
0.001
)
new_attrs
[
'epsilon'
]
=
attrs
.
get
(
'eps'
,
0.001
)
new_attrs
[
'center'
]
=
True
new_attrs
[
'center'
]
=
True
new_attrs
[
'scale'
]
=
True
new_attrs
[
'scale'
]
=
True
return
op_name
,
new_attrs
return
_get_nnvm_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
def
_concat
(
attrs
):
def
_concat
(
inputs
,
attrs
):
op_name
=
'concatenate'
op_name
=
'concatenate'
new_attrs
=
{
'axis'
:
attrs
.
get
(
'dim'
,
1
)}
new_attrs
=
{
'axis'
:
attrs
.
get
(
'dim'
,
1
)}
return
op_name
,
new_attrs
return
_get_nnvm_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
def
_conv2d
(
attrs
):
def
_conv2d
(
inputs
,
attrs
):
kernel
=
_parse_tshape
(
_required_attr
(
attrs
,
'kernel'
))
kernel
=
_parse_tshape
(
_required_attr
(
attrs
,
'kernel'
))
if
len
(
kernel
)
!=
2
:
if
len
(
kernel
)
!=
2
:
_raise_not_supported
(
'non 2d kernel'
,
'conv2d'
)
_raise_not_supported
(
'non 2d kernel'
,
'conv2d'
)
...
@@ -100,9 +103,9 @@ def _conv2d(attrs):
...
@@ -100,9 +103,9 @@ def _conv2d(attrs):
new_attrs
[
'groups'
]
=
attrs
.
get
(
'num_group'
,
1
)
new_attrs
[
'groups'
]
=
attrs
.
get
(
'num_group'
,
1
)
new_attrs
[
'layout'
]
=
layout
new_attrs
[
'layout'
]
=
layout
new_attrs
[
'use_bias'
]
=
attrs
.
get
(
'no_bias'
,
'False'
)
.
strip
()
==
'False'
new_attrs
[
'use_bias'
]
=
attrs
.
get
(
'no_bias'
,
'False'
)
.
strip
()
==
'False'
return
op_name
,
new_attrs
return
_get_nnvm_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
def
_conv2d_transpose
(
attrs
):
def
_conv2d_transpose
(
inputs
,
attrs
):
if
'target_shape'
in
attrs
:
if
'target_shape'
in
attrs
:
_raise_not_supported
(
'target_shape'
,
'conv2d_transpose'
)
_raise_not_supported
(
'target_shape'
,
'conv2d_transpose'
)
kernel
=
_parse_tshape
(
_required_attr
(
attrs
,
'kernel'
))
kernel
=
_parse_tshape
(
_required_attr
(
attrs
,
'kernel'
))
...
@@ -121,51 +124,68 @@ def _conv2d_transpose(attrs):
...
@@ -121,51 +124,68 @@ def _conv2d_transpose(attrs):
new_attrs
[
'groups'
]
=
attrs
.
get
(
'num_group'
,
1
)
new_attrs
[
'groups'
]
=
attrs
.
get
(
'num_group'
,
1
)
new_attrs
[
'layout'
]
=
layout
new_attrs
[
'layout'
]
=
layout
new_attrs
[
'use_bias'
]
=
not
_parse_bool_str
(
attrs
,
'no_bias'
)
new_attrs
[
'use_bias'
]
=
not
_parse_bool_str
(
attrs
,
'no_bias'
)
return
op_name
,
new_attrs
return
_get_nnvm_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
def
_dense
(
attrs
):
def
_dense
(
inputs
,
attrs
):
op_name
,
new_attrs
=
'dense'
,
{}
op_name
,
new_attrs
=
'dense'
,
{}
new_attrs
[
'units'
]
=
_required_attr
(
attrs
,
'num_hidden'
)
new_attrs
[
'units'
]
=
_required_attr
(
attrs
,
'num_hidden'
)
new_attrs
[
'use_bias'
]
=
not
_parse_bool_str
(
attrs
,
'no_bias'
)
new_attrs
[
'use_bias'
]
=
not
_parse_bool_str
(
attrs
,
'no_bias'
)
major
,
minor
,
micro
=
_get_mxnet_version
()
major
,
minor
,
micro
=
_get_mxnet_version
()
if
major
>=
0
and
minor
>=
11
and
micro
>=
1
:
if
major
>=
0
and
minor
>=
11
and
micro
>=
1
:
new_attrs
[
'flatten'
]
=
_parse_bool_str
(
attrs
,
'flatten'
,
'True'
)
use_flatten
=
_parse_bool_str
(
attrs
,
'flatten'
,
'True'
)
return
op_name
,
new_attrs
if
use_flatten
:
inputs
[
0
]
=
_sym
.
flatten
(
inputs
[
0
])
return
_get_nnvm_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
def
_dropout
(
attrs
):
def
_dropout
(
inputs
,
attrs
):
op_name
,
new_attrs
=
'dropout'
,
{}
op_name
,
new_attrs
=
'dropout'
,
{}
new_attrs
[
'rate'
]
=
attrs
.
get
(
'p'
,
0.5
)
new_attrs
[
'rate'
]
=
attrs
.
get
(
'p'
,
0.5
)
return
op_name
,
new_attrs
return
_get_nnvm_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
def
_leaky_relu
(
attrs
):
def
_leaky_relu
(
inputs
,
attrs
):
act_type
=
_required_attr
(
attrs
,
'act_type'
)
act_type
=
_required_attr
(
attrs
,
'act_type'
)
if
act_type
not
in
[
'leaky'
]:
if
act_type
in
[
'leaky'
]:
op_name
,
new_attrs
=
'leaky_relu'
,
{}
new_attrs
[
'alpha'
]
=
attrs
.
get
(
'slope'
,
0.25
)
sym
=
_get_nnvm_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
elif
act_type
==
'elu'
:
slope
=
attrs
.
get
(
'slope'
,
0.25
)
sym
=
-
slope
*
_sym
.
relu
(
1
-
_sym
.
exp
(
*
inputs
))
+
_sym
.
relu
(
*
inputs
)
elif
act_type
==
'rrelu'
:
lower_bound
=
float
(
_required_attr
(
attrs
,
'lower_bound'
))
upper_bound
=
float
(
_required_attr
(
attrs
,
'upper_bound'
))
slope
=
(
lower_bound
+
upper_bound
)
/
2.0
op_name
,
new_attrs
=
'leaky_relu'
,
{
'alpha'
:
str
(
slope
)}
sym
=
_get_nnvm_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
else
:
_raise_not_supported
(
'act_type: '
+
act_type
)
_raise_not_supported
(
'act_type: '
+
act_type
)
op_name
,
new_attrs
=
'leaky_relu'
,
{}
return
sym
new_attrs
[
'alpha'
]
=
attrs
.
get
(
'slope'
,
0.25
)
return
op_name
,
new_attrs
def
_activations
(
attrs
):
def
_activations
(
inputs
,
attrs
):
act_type
=
_required_attr
(
attrs
,
'act_type'
)
act_type
=
_required_attr
(
attrs
,
'act_type'
)
if
act_type
not
in
[
'relu'
,
'sigmoid'
,
'tanh'
]:
if
act_type
in
[
'relu'
,
'sigmoid'
,
'tanh'
]:
op_name
,
new_attrs
=
act_type
,
{}
sym
=
_get_nnvm_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
elif
act_type
==
'softrelu'
:
sym
=
_sym
.
log
((
1
+
_sym
.
exp
(
*
inputs
)))
else
:
_raise_not_supported
(
'act_type: '
+
act_type
)
_raise_not_supported
(
'act_type: '
+
act_type
)
op_name
,
new_attrs
=
act_type
,
{}
return
sym
return
op_name
,
new_attrs
def
_reshape
(
attrs
):
def
_reshape
(
inputs
,
attrs
):
if
_parse_bool_str
(
attrs
,
'reverse'
):
if
_parse_bool_str
(
attrs
,
'reverse'
):
_raise_not_supported
(
'reverse'
,
'reshape'
)
_raise_not_supported
(
'reverse'
,
'reshape'
)
op_name
,
new_attrs
=
'reshape'
,
{}
op_name
,
new_attrs
=
'reshape'
,
{}
new_attrs
[
'shape'
]
=
_required_attr
(
attrs
,
'shape'
)
new_attrs
[
'shape'
]
=
_required_attr
(
attrs
,
'shape'
)
return
op_name
,
new_attrs
return
_get_nnvm_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
def
_split
(
attrs
):
def
_split
(
inputs
,
attrs
):
if
_parse_bool_str
(
attrs
,
'squeeze_axis'
):
if
_parse_bool_str
(
attrs
,
'squeeze_axis'
):
_raise_not_supported
(
'squeeze_axis'
,
'split'
)
_raise_not_supported
(
'squeeze_axis'
,
'split'
)
op_name
,
new_attrs
=
'split'
,
{}
op_name
,
new_attrs
=
'split'
,
{}
new_attrs
[
'indices_or_sections'
]
=
_required_attr
(
attrs
,
'num_outputs'
)
new_attrs
[
'indices_or_sections'
]
=
_required_attr
(
attrs
,
'num_outputs'
)
new_attrs
[
'axis'
]
=
attrs
.
get
(
'axis'
,
1
)
new_attrs
[
'axis'
]
=
attrs
.
get
(
'axis'
,
1
)
return
op_name
,
new_attrs
return
_get_nnvm_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
_identity_list
=
[
'__add_scalar__'
,
'__add_symbol__'
,
'__div_scalar__'
,
_identity_list
=
[
'__add_scalar__'
,
'__add_symbol__'
,
'__div_scalar__'
,
'__div_symbol__'
,
'__mul_scalar__'
,
'__mul_symbol__'
,
'__div_symbol__'
,
'__mul_scalar__'
,
'__mul_symbol__'
,
...
@@ -178,7 +198,12 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
...
@@ -178,7 +198,12 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'relu'
,
'sigmoid'
,
'softmax'
,
'sum'
,
'tanh'
,
'transpose'
]
'relu'
,
'sigmoid'
,
'softmax'
,
'sum'
,
'tanh'
,
'transpose'
]
_convert_map
=
{
_convert_map
=
{
'null'
:
_variable
,
'_div_scalar'
:
_rename
(
'__div_scalar__'
),
'_minus_scalar'
:
_rename
(
'__sub_scalar__'
),
'_mul_scalar'
:
_rename
(
'__mul_scalar__'
),
'_plus_scalar'
:
_rename
(
'__add_scalar__'
),
'_rdiv_scalar'
:
_rename
(
'__rdiv_scalar__'
),
'_rminus_scalar'
:
_rename
(
'__rsub_scalar__'
),
'Activation'
:
_activations
,
'Activation'
:
_activations
,
'BatchNorm'
:
_batch_norm
,
'BatchNorm'
:
_batch_norm
,
'BatchNorm_v1'
:
_batch_norm
,
'BatchNorm_v1'
:
_batch_norm
,
...
@@ -202,7 +227,7 @@ _convert_map = {
...
@@ -202,7 +227,7 @@ _convert_map = {
'sum_axis'
:
_rename
(
'sum'
),
'sum_axis'
:
_rename
(
'sum'
),
}
}
def
_convert_symbol
(
op_name
,
attrs
,
def
_convert_symbol
(
op_name
,
inputs
,
attrs
,
identity_list
=
None
,
identity_list
=
None
,
convert_map
=
None
):
convert_map
=
None
):
"""Convert from mxnet op to nnvm op.
"""Convert from mxnet op to nnvm op.
...
@@ -213,6 +238,8 @@ def _convert_symbol(op_name, attrs,
...
@@ -213,6 +238,8 @@ def _convert_symbol(op_name, attrs,
----------
----------
op_name : str
op_name : str
Operator name, such as Convolution, FullyConnected
Operator name, such as Convolution, FullyConnected
inputs : list of nnvm.Symbol
List of input symbols.
attrs : dict
attrs : dict
Dict of operator attributes
Dict of operator attributes
identity_list : list
identity_list : list
...
@@ -224,21 +251,19 @@ def _convert_symbol(op_name, attrs,
...
@@ -224,21 +251,19 @@ def _convert_symbol(op_name, attrs,
Returns
Returns
-------
-------
(op_name, attrs)
sym : nnvm.Symbol
Converted
(op_name, attrs) for nnvm.
Converted
nnvm Symbol
"""
"""
identity_list
=
identity_list
if
identity_list
else
_identity_list
identity_list
=
identity_list
if
identity_list
else
_identity_list
convert_map
=
convert_map
if
convert_map
else
_convert_map
convert_map
=
convert_map
if
convert_map
else
_convert_map
if
op_name
in
identity_list
:
if
op_name
in
identity_list
:
pass
op
=
_get_nnvm_op
(
op_name
)
sym
=
op
(
*
inputs
,
**
attrs
)
elif
op_name
in
convert_map
:
elif
op_name
in
convert_map
:
op_name
,
attrs
=
convert_map
[
op_name
](
attrs
)
sym
=
convert_map
[
op_name
](
inputs
,
attrs
)
else
:
else
:
_raise_not_supported
(
'Operator: '
+
op_name
)
_raise_not_supported
(
'Operator: '
+
op_name
)
op
=
getattr
(
_sym
,
op_name
,
None
)
return
sym
if
not
op
:
raise
RuntimeError
(
"Unable to map op_name {} to nnvm.sym"
.
format
(
op_name
))
return
op
,
attrs
def
_is_mxnet_group_symbol
(
symbol
):
def
_is_mxnet_group_symbol
(
symbol
):
"""Internal check for mxnet group symbol."""
"""Internal check for mxnet group symbol."""
...
@@ -274,28 +299,20 @@ def _from_mxnet_impl(symbol, graph):
...
@@ -274,28 +299,20 @@ def _from_mxnet_impl(symbol, graph):
node
=
graph
.
get
(
name
,
None
)
node
=
graph
.
get
(
name
,
None
)
if
node
:
if
node
:
return
node
return
node
attr
=
symbol
.
list_attr
()
# op_name = symbol.attr('op_name')
# op_name = symbol.attr('op_name')
if
symbol
.
get_children
():
childs
=
symbol
.
get_children
()
if
childs
:
op_name
=
symbol
.
attr
(
'op_name'
)
op_name
=
symbol
.
attr
(
'op_name'
)
else
:
op_name
=
json
.
loads
(
symbol
.
tojson
())[
'nodes'
][
0
][
'op'
]
attr
=
symbol
.
list_attr
()
new_op
,
new_attr
=
_convert_symbol
(
op_name
,
attr
)
if
new_op
==
_sym
.
Variable
:
node
=
new_op
(
name
=
name
,
**
new_attr
)
else
:
childs
=
symbol
.
get_children
()
childs
=
[
_from_mxnet_impl
(
c
,
graph
)
for
c
in
_as_list
(
childs
)]
childs
=
[
_from_mxnet_impl
(
c
,
graph
)
for
c
in
_as_list
(
childs
)]
childs
=
[
x
for
y
in
childs
for
x
in
_as_list
(
y
)]
# expand group symbol
childs
=
[
x
for
y
in
childs
for
x
in
_as_list
(
y
)]
# expand group symbol
if
new_op
==
_sym
.
dense
and
'flatten'
in
new_attr
:
node
=
_convert_symbol
(
op_name
,
childs
,
attr
)
if
new_attr
[
'flatten'
]:
else
:
childs
[
0
]
=
_sym
.
flatten
(
childs
[
0
])
op_name
=
json
.
loads
(
symbol
.
tojson
())[
'nodes'
][
0
][
'op'
]
new_attr
.
pop
(
'flatten'
)
node
=
_sym
.
Variable
(
name
=
name
,
**
attr
)
node
=
new_op
(
name
=
name
,
*
childs
,
**
new_attr
)
graph
[
name
]
=
node
graph
[
name
]
=
node
return
node
return
node
def
from_mxnet
(
symbol
,
arg_params
=
None
,
aux_params
=
None
):
def
from_mxnet
(
symbol
,
arg_params
=
None
,
aux_params
=
None
):
"""Convert from MXNet's model into compatible NNVM format.
"""Convert from MXNet's model into compatible NNVM format.
...
...
nnvm/tests/python/frontend/mxnet/test_forward.py
View file @
9a6feca6
...
@@ -46,7 +46,7 @@ def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape
...
@@ -46,7 +46,7 @@ def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape
assert
"data"
not
in
args
assert
"data"
not
in
args
for
target
,
ctx
in
ctx_list
():
for
target
,
ctx
in
ctx_list
():
tvm_out
=
get_tvm_output
(
mx_symbol
,
x
,
args
,
auxs
,
target
,
ctx
,
dtype
)
tvm_out
=
get_tvm_output
(
mx_symbol
,
x
,
args
,
auxs
,
target
,
ctx
,
dtype
)
np
.
testing
.
assert_allclose
(
mx_out
,
tvm_out
,
rtol
=
1e-5
)
np
.
testing
.
assert_allclose
(
mx_out
,
tvm_out
,
rtol
=
1e-5
,
atol
=
1e-5
)
def
test_forward_mlp
():
def
test_forward_mlp
():
mlp
=
model_zoo
.
mx_mlp
mlp
=
model_zoo
.
mx_mlp
...
@@ -62,7 +62,40 @@ def test_forward_resnet():
...
@@ -62,7 +62,40 @@ def test_forward_resnet():
mx_sym
=
model_zoo
.
mx_resnet
[
n
]
mx_sym
=
model_zoo
.
mx_resnet
[
n
]
verify_mxnet_frontend_impl
(
mx_sym
)
verify_mxnet_frontend_impl
(
mx_sym
)
def
test_forward_elu
():
data
=
mx
.
sym
.
var
(
'data'
)
data
=
mx
.
sym
.
concat
(
data
,
-
data
,
dim
=
1
)
# negative part explicitly
mx_sym
=
mx
.
sym
.
LeakyReLU
(
data
,
act_type
=
'elu'
)
verify_mxnet_frontend_impl
(
mx_sym
,
(
1
,
3
,
100
,
100
),
(
1
,
6
,
100
,
100
))
def
test_forward_rrelu
():
data
=
mx
.
sym
.
var
(
'data'
)
data
=
mx
.
sym
.
concat
(
data
,
-
data
,
dim
=
1
)
# negative part explicitly
mx_sym
=
mx
.
sym
.
LeakyReLU
(
data
,
act_type
=
'rrelu'
,
lower_bound
=
0.3
,
upper_bound
=
0.7
)
verify_mxnet_frontend_impl
(
mx_sym
,
(
1
,
3
,
100
,
100
),
(
1
,
6
,
100
,
100
))
def
test_forward_softrelu
():
data
=
mx
.
sym
.
var
(
'data'
)
data
=
mx
.
sym
.
concat
(
data
,
-
data
,
dim
=
1
)
# negative part explicitly
mx_sym
=
mx
.
sym
.
Activation
(
data
,
act_type
=
'softrelu'
)
verify_mxnet_frontend_impl
(
mx_sym
,
(
1
,
3
,
100
,
100
),
(
1
,
6
,
100
,
100
))
def
test_forward_fc_flatten
():
# test flatten=True option in mxnet 0.11.1
data
=
mx
.
sym
.
var
(
'data'
)
try
:
mx_sym
=
mx
.
sym
.
FullyConnected
(
data
,
num_hidden
=
100
,
flatten
=
True
)
verify_mxnet_frontend_impl
(
mx_sym
,
(
1
,
3
,
100
,
100
),
(
1
,
100
))
mx_sym
=
mx
.
sym
.
FullyConnected
(
mx
.
sym
.
Flatten
(
data
),
num_hidden
=
100
,
flatten
=
False
)
verify_mxnet_frontend_impl
(
mx_sym
,
(
1
,
3
,
100
,
100
),
(
1
,
100
))
except
:
pass
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_forward_mlp
()
test_forward_mlp
()
test_forward_vgg
()
test_forward_vgg
()
test_forward_resnet
()
test_forward_resnet
()
test_forward_elu
()
test_forward_rrelu
()
test_forward_softrelu
()
test_forward_fc_flatten
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment