Commit 29ee8a23 by Haichen Shen Committed by Tianqi Chen

[Relay][Frontend] Fix MxNet RNN without providing state initialization as input (#3326)

parent d0c45648
......@@ -93,6 +93,15 @@ def _mx_compare(new_op, wrapper):
return impl
def _mx_zeros(inputs, attrs):
assert len(inputs) == 0
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_str("dtype", "float32")
if 0 in shape:
return None
return _op.zeros(shape=shape, dtype=dtype)
def _mx_conv2d(inputs, attrs):
kernel_size = attrs.get_int_tuple("kernel")
if len(kernel_size) != 2:
......@@ -754,9 +763,30 @@ def _mx_rnn_layer(inputs, attrs):
seq_data = inputs[0]
concat_weight = inputs[1]
concat_states = inputs[2:]
seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0])
init_states = inputs[2:]
data_shape = ir_pass.infer_type(seq_data).checked_type.shape
seq_len = int(data_shape[0])
assert len(concat_weight) == num_layers * 4
output_states = True
for idx, state in enumerate(init_states[:]):
if isinstance(state, dict):
node = state
attrs = StrAttrsDict(node.get("attrs", {}))
op_name = node["op"]
# by default, RNN layer uses zeros to initialize states
assert op_name == "_zeros"
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_str("dtype", "float32")
init_layout = attrs.get_str("__layout__")
new_shape = list(shape)
for i, dim in enumerate(shape):
if dim == 0:
axis = layout.find(init_layout[i])
assert axis >= 0
new_shape[i] = int(data_shape[axis])
init_states[idx] = _op.zeros(new_shape, dtype)
output_states = False
weights = []
bias = []
......@@ -768,7 +798,7 @@ def _mx_rnn_layer(inputs, attrs):
for j in range(2):
w.append(concat_weight[i*2 + j].args[0])
b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
for state in concat_states:
for state in init_states:
s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
weights.append(w)
bias.append(b)
......@@ -789,6 +819,7 @@ def _mx_rnn_layer(inputs, attrs):
seq_output.append(out)
outputs = [_op.stack(seq_output, axis=0)]
if output_states:
for i in range(num_states):
outputs.append(_op.stack([s[i] for s in states], axis=0))
return outputs
......@@ -881,7 +912,6 @@ _convert_map = {
"argmin" : _arg_reduce(_op.argmin),
# init ops
"_ones" : _init_op(_op.ones),
"_zeros" : _init_op(_op.zeros),
# softmax
"softmax" : _softmax_op(_op.nn.softmax),
"log_softmax" : _softmax_op(_op.nn.log_softmax),
......@@ -895,6 +925,7 @@ _convert_map = {
"UpSampling" : _upsampling,
"add_n" : _elemwise_sum,
# MXNet specific implementations
"_zeros" : _mx_zeros,
"FullyConnected": _mx_fully_connected,
"Activation" : _mx_activations,
"Convolution" : _mx_conv2d,
......@@ -1002,7 +1033,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
elif op_name in _convert_map:
res = _convert_map[op_name](children, attrs)
if isinstance(res, (_expr.TupleWrapper, tuple, list)):
if res is None:
# defer conversion, used in RNN state initialization
res = [node]
elif isinstance(res, (_expr.TupleWrapper, tuple, list)):
pass
elif isinstance(res, _expr.Expr):
res = [res]
......
......@@ -536,7 +536,7 @@ def test_forward_bilinear_resize():
verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
def test_forward_rnn_layer():
def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
def verify(mode, input_size, seq_len, hidden_size, num_layers, init_states=True):
if mode == "rnn":
layer = gluon.rnn.RNN(hidden_size, num_layers)
elif mode == "gru":
......@@ -545,23 +545,31 @@ def test_forward_rnn_layer():
layer = gluon.rnn.LSTM(hidden_size, num_layers)
num_states = 2 if mode == "lstm" else 1
layer.initialize()
layer.hybridize()
dtype = "float32"
batch = 1
data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
states_np = []
states_mx = []
data_mx = mx.nd.array(data_np)
if init_states:
shape_dict = {'data0': data_np.shape}
inputs = {'data0': data_np}
states_np = []
states_mx = []
for i in range(num_states):
s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
states_np.append(s)
states_mx.append(mx.nd.array(s))
shape_dict['data%s' % (i+1)] = s.shape
inputs['data%s' % (i+1)] = s
layer.hybridize()
mx_out, mx_states = layer(mx.nd.array(data_np), states_mx)
mx_out, mx_states = layer(data_mx, states_mx)
mx_res = [mx_out] + mx_states
else:
shape_dict = {'data': data_np.shape}
inputs = {'data': data_np}
mx_res = layer(data_mx)
mx_sym = layer._cached_graph[1]
mx_params = {}
for name, param in layer.collect_params().items():
......@@ -574,14 +582,20 @@ def test_forward_rnn_layer():
for kind in ["graph"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(**inputs, **params)
if init_states:
assert len(op_res) == len(mx_res)
for i, val in enumerate(op_res):
tvm.testing.assert_allclose(val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
tvm.testing.assert_allclose(
val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
else:
tvm.testing.assert_allclose(
op_res.asnumpy(), mx_res.asnumpy(), rtol=1e-3)
for mode in ["rnn", "gru", "lstm"]:
verify(mode, 64, 10, 64, 1)
verify(mode, 64, 10, 64, 2)
verify(mode, 64, 10, 32, 2)
verify(mode, 64, 10, 64, 2, init_states=False)
def test_forward_Crop():
def verify(xshape, yshape, offset=None):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment