Commit 29ee8a23 by Haichen Shen, committed by Tianqi Chen

[Relay][Frontend] Fix MxNet RNN without providing state initialization as input (#3326)

parent d0c45648
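For context, a minimal sketch of the case this commit fixes (assuming MXNet Gluon and the TVM Relay MXNet frontend, and mirroring the test added below): a hybridized Gluon RNN invoked with data only, so MXNet inserts `_zeros` nodes for the initial states and the converter must infer their shapes. All shapes and sizes below are assumed example values.

```python
# Sketch of converting a Gluon LSTM that was called without explicit initial
# states. Shapes/sizes are assumed examples; API usage mirrors the test below.
import mxnet as mx
from mxnet import gluon
from tvm import relay

layer = gluon.rnn.LSTM(64, num_layers=2)
layer.initialize()
layer.hybridize()

data = mx.nd.random.uniform(shape=(10, 1, 64))   # (seq_len, batch, input_size)
layer(data)                                      # no initial states passed

mx_sym = layer._cached_graph[1]
mx_params = {name: param._reduce() for name, param in layer.collect_params().items()}

# Previously this conversion failed because the default-state _zeros nodes carry
# a 0 (unknown batch) in their shape; the frontend now fills it from the data shape.
new_sym, params = relay.frontend.from_mxnet(
    mx_sym, shape={"data": (10, 1, 64)}, dtype="float32", arg_params=mx_params)
```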
@@ -93,6 +93,15 @@ def _mx_compare(new_op, wrapper):
     return impl
 
 
+def _mx_zeros(inputs, attrs):
+    assert len(inputs) == 0
+    shape = attrs.get_int_tuple("shape")
+    dtype = attrs.get_str("dtype", "float32")
+    if 0 in shape:
+        return None
+    return _op.zeros(shape=shape, dtype=dtype)
+
+
 def _mx_conv2d(inputs, attrs):
     kernel_size = attrs.get_int_tuple("kernel")
     if len(kernel_size) != 2:
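`_mx_zeros` treats MXNet's `_zeros` op specially: when the requested shape contains a 0 (a placeholder left for a dimension that is only known at conversion time, such as the batch size of an RNN's default initial state), it returns `None` so conversion is deferred until `_mx_rnn_layer` can fill the shape in. A hypothetical sketch of such a deferred graph-JSON node follows; the attribute names are taken from the handling code, the concrete values are invented for illustration.

```python
# Hypothetical _zeros node as it appears in the MXNet graph JSON that the
# frontend walks. Only the keys read by the code above ("op", "attrs",
# "shape", "dtype", "__layout__") are grounded; the values are examples.
deferred_state_node = {
    "op": "_zeros",
    "attrs": {
        "shape": "(2, 0, 64)",     # (num_layers, batch, hidden); batch still unknown
        "dtype": "float32",
        "__layout__": "LNC",       # assumed state layout: layers, batch, hidden
    },
}
```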
@@ -754,9 +763,30 @@ def _mx_rnn_layer(inputs, attrs):
     seq_data = inputs[0]
     concat_weight = inputs[1]
-    concat_states = inputs[2:]
-    seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0])
+    init_states = inputs[2:]
+    data_shape = ir_pass.infer_type(seq_data).checked_type.shape
+    seq_len = int(data_shape[0])
     assert len(concat_weight) == num_layers * 4
+
+    output_states = True
+    for idx, state in enumerate(init_states[:]):
+        if isinstance(state, dict):
+            node = state
+            attrs = StrAttrsDict(node.get("attrs", {}))
+            op_name = node["op"]
+            # by default, RNN layer uses zeros to initialize states
+            assert op_name == "_zeros"
+            shape = attrs.get_int_tuple("shape")
+            dtype = attrs.get_str("dtype", "float32")
+            init_layout = attrs.get_str("__layout__")
+            new_shape = list(shape)
+            for i, dim in enumerate(shape):
+                if dim == 0:
+                    axis = layout.find(init_layout[i])
+                    assert axis >= 0
+                    new_shape[i] = int(data_shape[axis])
+            init_states[idx] = _op.zeros(new_shape, dtype)
+            output_states = False
+
     weights = []
     bias = []
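The shape fill above resolves each 0 dimension of a deferred state by matching its layout character against the sequence data layout and copying that axis of the inferred data shape. A minimal standalone sketch of that calculation, with example values assumed (TNC data, LNC state):

```python
# Standalone illustration of the shape fill; all values are assumed examples.
data_layout = "TNC"
data_shape = (10, 1, 64)       # seq_len=10, batch=1, input_size=64

state_layout = "LNC"           # layers, batch, hidden
state_shape = (2, 0, 64)       # the batch dimension is the unknown 0

filled = list(state_shape)
for i, dim in enumerate(state_shape):
    if dim == 0:
        axis = data_layout.find(state_layout[i])   # 'N' -> axis 1 of the data
        assert axis >= 0
        filled[i] = data_shape[axis]

print(filled)                  # [2, 1, 64]: shape used for the zero state
```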
@@ -768,7 +798,7 @@ def _mx_rnn_layer(inputs, attrs):
         for j in range(2):
             w.append(concat_weight[i*2 + j].args[0])
             b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
-        for state in concat_states:
+        for state in init_states:
             s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
         weights.append(w)
         bias.append(b)
@@ -789,8 +819,9 @@ def _mx_rnn_layer(inputs, attrs):
         seq_output.append(out)
 
     outputs = [_op.stack(seq_output, axis=0)]
-    for i in range(num_states):
-        outputs.append(_op.stack([s[i] for s in states], axis=0))
+    if output_states:
+        for i in range(num_states):
+            outputs.append(_op.stack([s[i] for s in states], axis=0))
     return outputs
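The new `output_states` flag mirrors Gluon's calling convention: when the initial states were synthesized rather than passed in, the layer returns only the output sequence, so the converter must not append stacked final states to the outputs. A small sketch of the two call forms, assuming MXNet is installed and with example sizes:

```python
# Gluon RNN layers return (output, states) only when states are passed in;
# with data alone they return just the output. Sizes are assumed examples.
import mxnet as mx
from mxnet import gluon

lstm = gluon.rnn.LSTM(16, num_layers=1)
lstm.initialize()
x = mx.nd.random.uniform(shape=(5, 1, 8))    # (seq_len, batch, input_size)

out_only = lstm(x)                           # zeros used internally -> output only
states = lstm.begin_state(batch_size=1)
out, final_states = lstm(x, states)          # explicit states -> output + states
```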
@@ -881,7 +912,6 @@ _convert_map = {
     "argmin"        : _arg_reduce(_op.argmin),
     # init ops
     "_ones"         : _init_op(_op.ones),
-    "_zeros"        : _init_op(_op.zeros),
     # softmax
     "softmax"       : _softmax_op(_op.nn.softmax),
     "log_softmax"   : _softmax_op(_op.nn.log_softmax),
@@ -895,6 +925,7 @@ _convert_map = {
     "UpSampling"    : _upsampling,
     "add_n"         : _elemwise_sum,
     # MXNet specific implementations
+    "_zeros"        : _mx_zeros,
     "FullyConnected": _mx_fully_connected,
     "Activation"    : _mx_activations,
     "Convolution"   : _mx_conv2d,
@@ -1002,7 +1033,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
             node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
         elif op_name in _convert_map:
             res = _convert_map[op_name](children, attrs)
-            if isinstance(res, (_expr.TupleWrapper, tuple, list)):
+            if res is None:
+                # defer conversion, used in RNN state initialization
+                res = [node]
+            elif isinstance(res, (_expr.TupleWrapper, tuple, list)):
                 pass
             elif isinstance(res, _expr.Expr):
                 res = [res]
...
@@ -536,7 +536,7 @@ def test_forward_bilinear_resize():
     verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
 
 def test_forward_rnn_layer():
-    def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
+    def verify(mode, input_size, seq_len, hidden_size, num_layers, init_states=True):
         if mode == "rnn":
             layer = gluon.rnn.RNN(hidden_size, num_layers)
         elif mode == "gru":
@@ -545,23 +545,31 @@ def test_forward_rnn_layer():
             layer = gluon.rnn.LSTM(hidden_size, num_layers)
         num_states = 2 if mode == "lstm" else 1
         layer.initialize()
+        layer.hybridize()
         dtype = "float32"
+        batch = 1
         data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
-        states_np = []
-        states_mx = []
-        shape_dict = {'data0': data_np.shape}
-        inputs = {'data0': data_np}
-        for i in range(num_states):
-            s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
-            states_np.append(s)
-            states_mx.append(mx.nd.array(s))
-            shape_dict['data%s' % (i+1)] = s.shape
-            inputs['data%s' % (i+1)] = s
-        layer.hybridize()
-        mx_out, mx_states = layer(mx.nd.array(data_np), states_mx)
-        mx_res = [mx_out] + mx_states
+        data_mx = mx.nd.array(data_np)
+
+        if init_states:
+            shape_dict = {'data0': data_np.shape}
+            inputs = {'data0': data_np}
+            states_np = []
+            states_mx = []
+            for i in range(num_states):
+                s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
+                states_np.append(s)
+                states_mx.append(mx.nd.array(s))
+                shape_dict['data%s' % (i+1)] = s.shape
+                inputs['data%s' % (i+1)] = s
+            mx_out, mx_states = layer(data_mx, states_mx)
+            mx_res = [mx_out] + mx_states
+        else:
+            shape_dict = {'data': data_np.shape}
+            inputs = {'data': data_np}
+            mx_res = layer(data_mx)
+
         mx_sym = layer._cached_graph[1]
         mx_params = {}
         for name, param in layer.collect_params().items():
@@ -574,14 +582,20 @@ def test_forward_rnn_layer():
         for kind in ["graph"]:
             intrp = relay.create_executor(kind, ctx=ctx, target=target)
             op_res = intrp.evaluate(new_sym)(**inputs, **params)
-            assert len(op_res) == len(mx_res)
-            for i, val in enumerate(op_res):
-                tvm.testing.assert_allclose(val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
+            if init_states:
+                assert len(op_res) == len(mx_res)
+                for i, val in enumerate(op_res):
+                    tvm.testing.assert_allclose(
+                        val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
+            else:
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), mx_res.asnumpy(), rtol=1e-3)
 
     for mode in ["rnn", "gru", "lstm"]:
         verify(mode, 64, 10, 64, 1)
         verify(mode, 64, 10, 64, 2)
         verify(mode, 64, 10, 32, 2)
+        verify(mode, 64, 10, 64, 2, init_states=False)
 
 def test_forward_Crop():
     def verify(xshape, yshape, offset=None):
...
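With `init_states=False` the hybridized graph has a single data input named `data` (instead of `data0`, `data1`, ...) and a single output, which is why the new branch builds a one-entry `shape_dict` and compares a single tensor. A short sketch of that difference, assuming MXNet is installed and using the same `_cached_graph` trick as the test; shapes and the printed names are assumed examples:

```python
# Input names of the cached symbol with and without explicit initial states.
import mxnet as mx
from mxnet import gluon

def cached_data_inputs(with_states):
    lstm = gluon.rnn.LSTM(64, num_layers=2)
    lstm.initialize()
    lstm.hybridize()
    x = mx.nd.random.uniform(shape=(10, 1, 64))
    if with_states:
        lstm(x, lstm.begin_state(batch_size=1))
    else:
        lstm(x)
    sym = lstm._cached_graph[1]
    return [n for n in sym.list_arguments() if n.startswith("data")]

print(cached_data_inputs(False))   # expected: ['data']
print(cached_data_inputs(True))    # expected: ['data0', 'data1', 'data2']
```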