Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
29ee8a23
Commit
29ee8a23
authored
Jun 12, 2019
by
Haichen Shen
Committed by
Tianqi Chen
Jun 12, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Frontend] Fix MxNet RNN without providing state initialization as input (#3326)
parent
d0c45648
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
24 deletions
+72
-24
python/tvm/relay/frontend/mxnet.py
+41
-7
tests/python/frontend/mxnet/test_forward.py
+31
-17
No files found.
python/tvm/relay/frontend/mxnet.py
View file @
29ee8a23
...
@@ -93,6 +93,15 @@ def _mx_compare(new_op, wrapper):
...
@@ -93,6 +93,15 @@ def _mx_compare(new_op, wrapper):
return
impl
return
impl
def
_mx_zeros
(
inputs
,
attrs
):
assert
len
(
inputs
)
==
0
shape
=
attrs
.
get_int_tuple
(
"shape"
)
dtype
=
attrs
.
get_str
(
"dtype"
,
"float32"
)
if
0
in
shape
:
return
None
return
_op
.
zeros
(
shape
=
shape
,
dtype
=
dtype
)
def
_mx_conv2d
(
inputs
,
attrs
):
def
_mx_conv2d
(
inputs
,
attrs
):
kernel_size
=
attrs
.
get_int_tuple
(
"kernel"
)
kernel_size
=
attrs
.
get_int_tuple
(
"kernel"
)
if
len
(
kernel_size
)
!=
2
:
if
len
(
kernel_size
)
!=
2
:
...
@@ -754,9 +763,30 @@ def _mx_rnn_layer(inputs, attrs):
...
@@ -754,9 +763,30 @@ def _mx_rnn_layer(inputs, attrs):
seq_data
=
inputs
[
0
]
seq_data
=
inputs
[
0
]
concat_weight
=
inputs
[
1
]
concat_weight
=
inputs
[
1
]
concat_states
=
inputs
[
2
:]
init_states
=
inputs
[
2
:]
seq_len
=
int
(
ir_pass
.
infer_type
(
seq_data
)
.
checked_type
.
shape
[
0
])
data_shape
=
ir_pass
.
infer_type
(
seq_data
)
.
checked_type
.
shape
seq_len
=
int
(
data_shape
[
0
])
assert
len
(
concat_weight
)
==
num_layers
*
4
assert
len
(
concat_weight
)
==
num_layers
*
4
output_states
=
True
for
idx
,
state
in
enumerate
(
init_states
[:]):
if
isinstance
(
state
,
dict
):
node
=
state
attrs
=
StrAttrsDict
(
node
.
get
(
"attrs"
,
{}))
op_name
=
node
[
"op"
]
# by default, RNN layer uses zeros to initialize states
assert
op_name
==
"_zeros"
shape
=
attrs
.
get_int_tuple
(
"shape"
)
dtype
=
attrs
.
get_str
(
"dtype"
,
"float32"
)
init_layout
=
attrs
.
get_str
(
"__layout__"
)
new_shape
=
list
(
shape
)
for
i
,
dim
in
enumerate
(
shape
):
if
dim
==
0
:
axis
=
layout
.
find
(
init_layout
[
i
])
assert
axis
>=
0
new_shape
[
i
]
=
int
(
data_shape
[
axis
])
init_states
[
idx
]
=
_op
.
zeros
(
new_shape
,
dtype
)
output_states
=
False
weights
=
[]
weights
=
[]
bias
=
[]
bias
=
[]
...
@@ -768,7 +798,7 @@ def _mx_rnn_layer(inputs, attrs):
...
@@ -768,7 +798,7 @@ def _mx_rnn_layer(inputs, attrs):
for
j
in
range
(
2
):
for
j
in
range
(
2
):
w
.
append
(
concat_weight
[
i
*
2
+
j
]
.
args
[
0
])
w
.
append
(
concat_weight
[
i
*
2
+
j
]
.
args
[
0
])
b
.
append
(
concat_weight
[
num_layers
*
2
+
i
*
2
+
j
]
.
args
[
0
])
b
.
append
(
concat_weight
[
num_layers
*
2
+
i
*
2
+
j
]
.
args
[
0
])
for
state
in
conca
t_states
:
for
state
in
ini
t_states
:
s
.
append
(
_op
.
take
(
state
,
_expr
.
const
(
i
,
"int32"
),
axis
=
0
))
s
.
append
(
_op
.
take
(
state
,
_expr
.
const
(
i
,
"int32"
),
axis
=
0
))
weights
.
append
(
w
)
weights
.
append
(
w
)
bias
.
append
(
b
)
bias
.
append
(
b
)
...
@@ -789,8 +819,9 @@ def _mx_rnn_layer(inputs, attrs):
...
@@ -789,8 +819,9 @@ def _mx_rnn_layer(inputs, attrs):
seq_output
.
append
(
out
)
seq_output
.
append
(
out
)
outputs
=
[
_op
.
stack
(
seq_output
,
axis
=
0
)]
outputs
=
[
_op
.
stack
(
seq_output
,
axis
=
0
)]
for
i
in
range
(
num_states
):
if
output_states
:
outputs
.
append
(
_op
.
stack
([
s
[
i
]
for
s
in
states
],
axis
=
0
))
for
i
in
range
(
num_states
):
outputs
.
append
(
_op
.
stack
([
s
[
i
]
for
s
in
states
],
axis
=
0
))
return
outputs
return
outputs
...
@@ -881,7 +912,6 @@ _convert_map = {
...
@@ -881,7 +912,6 @@ _convert_map = {
"argmin"
:
_arg_reduce
(
_op
.
argmin
),
"argmin"
:
_arg_reduce
(
_op
.
argmin
),
# init ops
# init ops
"_ones"
:
_init_op
(
_op
.
ones
),
"_ones"
:
_init_op
(
_op
.
ones
),
"_zeros"
:
_init_op
(
_op
.
zeros
),
# softmax
# softmax
"softmax"
:
_softmax_op
(
_op
.
nn
.
softmax
),
"softmax"
:
_softmax_op
(
_op
.
nn
.
softmax
),
"log_softmax"
:
_softmax_op
(
_op
.
nn
.
log_softmax
),
"log_softmax"
:
_softmax_op
(
_op
.
nn
.
log_softmax
),
...
@@ -895,6 +925,7 @@ _convert_map = {
...
@@ -895,6 +925,7 @@ _convert_map = {
"UpSampling"
:
_upsampling
,
"UpSampling"
:
_upsampling
,
"add_n"
:
_elemwise_sum
,
"add_n"
:
_elemwise_sum
,
# MXNet specific implementations
# MXNet specific implementations
"_zeros"
:
_mx_zeros
,
"FullyConnected"
:
_mx_fully_connected
,
"FullyConnected"
:
_mx_fully_connected
,
"Activation"
:
_mx_activations
,
"Activation"
:
_mx_activations
,
"Convolution"
:
_mx_conv2d
,
"Convolution"
:
_mx_conv2d
,
...
@@ -1002,7 +1033,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
...
@@ -1002,7 +1033,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
node_map
[
nid
]
=
[
_expr
.
var
(
node_name
,
shape
=
shape
,
dtype
=
dtype
)]
node_map
[
nid
]
=
[
_expr
.
var
(
node_name
,
shape
=
shape
,
dtype
=
dtype
)]
elif
op_name
in
_convert_map
:
elif
op_name
in
_convert_map
:
res
=
_convert_map
[
op_name
](
children
,
attrs
)
res
=
_convert_map
[
op_name
](
children
,
attrs
)
if
isinstance
(
res
,
(
_expr
.
TupleWrapper
,
tuple
,
list
)):
if
res
is
None
:
# defer conversion, used in RNN state initialization
res
=
[
node
]
elif
isinstance
(
res
,
(
_expr
.
TupleWrapper
,
tuple
,
list
)):
pass
pass
elif
isinstance
(
res
,
_expr
.
Expr
):
elif
isinstance
(
res
,
_expr
.
Expr
):
res
=
[
res
]
res
=
[
res
]
...
...
tests/python/frontend/mxnet/test_forward.py
View file @
29ee8a23
...
@@ -536,7 +536,7 @@ def test_forward_bilinear_resize():
...
@@ -536,7 +536,7 @@ def test_forward_bilinear_resize():
verify_mxnet_frontend_impl
(
mx_sym
,
(
1
,
2
,
3
,
4
),
(
1
,
2
,
5
,
10
))
verify_mxnet_frontend_impl
(
mx_sym
,
(
1
,
2
,
3
,
4
),
(
1
,
2
,
5
,
10
))
def
test_forward_rnn_layer
():
def
test_forward_rnn_layer
():
def
verify
(
mode
,
input_size
,
seq_len
,
hidden_size
,
num_layers
,
batch
=
1
):
def
verify
(
mode
,
input_size
,
seq_len
,
hidden_size
,
num_layers
,
init_states
=
True
):
if
mode
==
"rnn"
:
if
mode
==
"rnn"
:
layer
=
gluon
.
rnn
.
RNN
(
hidden_size
,
num_layers
)
layer
=
gluon
.
rnn
.
RNN
(
hidden_size
,
num_layers
)
elif
mode
==
"gru"
:
elif
mode
==
"gru"
:
...
@@ -545,23 +545,31 @@ def test_forward_rnn_layer():
...
@@ -545,23 +545,31 @@ def test_forward_rnn_layer():
layer
=
gluon
.
rnn
.
LSTM
(
hidden_size
,
num_layers
)
layer
=
gluon
.
rnn
.
LSTM
(
hidden_size
,
num_layers
)
num_states
=
2
if
mode
==
"lstm"
else
1
num_states
=
2
if
mode
==
"lstm"
else
1
layer
.
initialize
()
layer
.
initialize
()
layer
.
hybridize
()
dtype
=
"float32"
dtype
=
"float32"
batch
=
1
data_np
=
np
.
random
.
uniform
(
size
=
(
seq_len
,
batch
,
input_size
))
.
astype
(
dtype
)
data_np
=
np
.
random
.
uniform
(
size
=
(
seq_len
,
batch
,
input_size
))
.
astype
(
dtype
)
states_np
=
[]
data_mx
=
mx
.
nd
.
array
(
data_np
)
states_mx
=
[]
shape_dict
=
{
'data0'
:
data_np
.
shape
}
if
init_states
:
inputs
=
{
'data0'
:
data_np
}
shape_dict
=
{
'data0'
:
data_np
.
shape
}
for
i
in
range
(
num_states
):
inputs
=
{
'data0'
:
data_np
}
s
=
np
.
random
.
uniform
(
size
=
(
num_layers
,
batch
,
hidden_size
))
.
astype
(
dtype
)
states_np
=
[]
states_np
.
append
(
s
)
states_mx
=
[]
states_mx
.
append
(
mx
.
nd
.
array
(
s
))
for
i
in
range
(
num_states
):
shape_dict
[
'data
%
s'
%
(
i
+
1
)]
=
s
.
shape
s
=
np
.
random
.
uniform
(
size
=
(
num_layers
,
batch
,
hidden_size
))
.
astype
(
dtype
)
inputs
[
'data
%
s'
%
(
i
+
1
)]
=
s
states_np
.
append
(
s
)
states_mx
.
append
(
mx
.
nd
.
array
(
s
))
shape_dict
[
'data
%
s'
%
(
i
+
1
)]
=
s
.
shape
inputs
[
'data
%
s'
%
(
i
+
1
)]
=
s
mx_out
,
mx_states
=
layer
(
data_mx
,
states_mx
)
mx_res
=
[
mx_out
]
+
mx_states
else
:
shape_dict
=
{
'data'
:
data_np
.
shape
}
inputs
=
{
'data'
:
data_np
}
mx_res
=
layer
(
data_mx
)
layer
.
hybridize
()
mx_out
,
mx_states
=
layer
(
mx
.
nd
.
array
(
data_np
),
states_mx
)
mx_res
=
[
mx_out
]
+
mx_states
mx_sym
=
layer
.
_cached_graph
[
1
]
mx_sym
=
layer
.
_cached_graph
[
1
]
mx_params
=
{}
mx_params
=
{}
for
name
,
param
in
layer
.
collect_params
()
.
items
():
for
name
,
param
in
layer
.
collect_params
()
.
items
():
...
@@ -574,14 +582,20 @@ def test_forward_rnn_layer():
...
@@ -574,14 +582,20 @@ def test_forward_rnn_layer():
for
kind
in
[
"graph"
]:
for
kind
in
[
"graph"
]:
intrp
=
relay
.
create_executor
(
kind
,
ctx
=
ctx
,
target
=
target
)
intrp
=
relay
.
create_executor
(
kind
,
ctx
=
ctx
,
target
=
target
)
op_res
=
intrp
.
evaluate
(
new_sym
)(
**
inputs
,
**
params
)
op_res
=
intrp
.
evaluate
(
new_sym
)(
**
inputs
,
**
params
)
assert
len
(
op_res
)
==
len
(
mx_res
)
if
init_states
:
for
i
,
val
in
enumerate
(
op_res
):
assert
len
(
op_res
)
==
len
(
mx_res
)
tvm
.
testing
.
assert_allclose
(
val
.
asnumpy
(),
mx_res
[
i
]
.
asnumpy
(),
rtol
=
1e-3
)
for
i
,
val
in
enumerate
(
op_res
):
tvm
.
testing
.
assert_allclose
(
val
.
asnumpy
(),
mx_res
[
i
]
.
asnumpy
(),
rtol
=
1e-3
)
else
:
tvm
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
mx_res
.
asnumpy
(),
rtol
=
1e-3
)
for
mode
in
[
"rnn"
,
"gru"
,
"lstm"
]:
for
mode
in
[
"rnn"
,
"gru"
,
"lstm"
]:
verify
(
mode
,
64
,
10
,
64
,
1
)
verify
(
mode
,
64
,
10
,
64
,
1
)
verify
(
mode
,
64
,
10
,
64
,
2
)
verify
(
mode
,
64
,
10
,
64
,
2
)
verify
(
mode
,
64
,
10
,
32
,
2
)
verify
(
mode
,
64
,
10
,
32
,
2
)
verify
(
mode
,
64
,
10
,
64
,
2
,
init_states
=
False
)
def
test_forward_Crop
():
def
test_forward_Crop
():
def
verify
(
xshape
,
yshape
,
offset
=
None
):
def
verify
(
xshape
,
yshape
,
offset
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment