Commit 56299010 authored by Haichen Shen, committed by Yao Wang

[Frontend][MxNet] Support bidirectional RNN layer (#3397)

* Support bidirectional RNN layer

* tweak

* tweak
parent b98e2c76
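As a usage sketch (not part of the commit), the feature added here can be exercised by converting a bidirectional Gluon layer through the MXNet frontend, mirroring the updated test further down. It assumes MXNet, NumPy, and a TVM build containing this patch; the variable names and the single-input shape key "data" are illustrative, and the first value returned by relay.frontend.from_mxnet is a Relay function or module depending on the TVM version.

# Illustrative sketch only -- mirrors the bidirectional cases added to
# test_forward_rnn_layer below.
import numpy as np
import mxnet as mx
from mxnet import gluon
from tvm import relay

seq_len, batch, input_size, hidden_size, num_layers = 10, 2, 32, 64, 1
layer = gluon.rnn.LSTM(hidden_size, num_layers, bidirectional=True)
layer.initialize()
layer.hybridize()

data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype("float32")
layer(mx.nd.array(data_np))  # one forward pass so deferred parameter shapes are known

# With no explicit initial states the converter sees _zeros state nodes,
# i.e. the init_states=False path exercised by the test below.
net, params = relay.frontend.from_mxnet(layer, shape={"data": data_np.shape})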
@@ -748,13 +748,12 @@ def _mx_rnn_layer(inputs, attrs):
     num_layers = attrs.get_int("num_layers", 1)
     mode = attrs.get_str("mode")
+    output_states = attrs.get_bool("state_outputs", False)
     if mode.startswith("rnn"):
         mode, activation = mode.split('_')
     assert mode in ["rnn", "gru", "lstm"]
     bidirectional = attrs.get_bool("bidirectional", False)
-    if bidirectional:
-        raise tvm.error.OpAttributeUnimplemented(
-            "Bidirectional RNN op is not supported yet")
+    direct = 2 if bidirectional else 1
     layout = attrs.get_str("layout", "TNC")
     if layout != "TNC":
         raise tvm.error.OpAttributeUnimplemented(
             "RNN with layout other than TNC is not supported yet")
@@ -765,11 +764,10 @@ def _mx_rnn_layer(inputs, attrs):
     seq_data = inputs[0]
     concat_weight = inputs[1]
     init_states = inputs[2:]
     data_shape = ir_pass.infer_type(seq_data).checked_type.shape
     seq_len = int(data_shape[0])
-    assert len(concat_weight) == num_layers * 4
-    output_states = True
+    assert len(concat_weight) == num_layers * 4 * direct
 
     for idx, state in enumerate(init_states[:]):
         if isinstance(state, dict):
             node = state
@@ -787,43 +785,76 @@ def _mx_rnn_layer(inputs, attrs):
                     assert axis >= 0
                     new_shape[i] = int(data_shape[axis])
             init_states[idx] = _op.zeros(new_shape, dtype)
-            output_states = False
 
     weights = []
     bias = []
     states = []
+    back_weights = []
+    back_bias = []
+    back_states = []
     for i in range(num_layers):
-        w = []
-        b = []
+        weights.append([concat_weight[i*2*direct].args[0],
+                        concat_weight[i*2*direct + 1].args[0]])
+        bias.append([concat_weight[(num_layers+i)*2*direct].args[0],
+                     concat_weight[(num_layers+i)*2*direct + 1].args[0]])
         s = []
-        for j in range(2):
-            w.append(concat_weight[i*2 + j].args[0])
-            b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
         for state in init_states:
-            s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
-        weights.append(w)
-        bias.append(b)
+            s.append(_op.take(state, _expr.const(i*direct, "int32"), axis=0))
         states.append(s)
-    seq_output = []
-    for t in range(seq_len):
-        data = _op.take(seq_data, _expr.const(t, "int32"), axis=0)
-        for l in range(num_layers):
+        if bidirectional:
+            back_weights.append([concat_weight[i*2*direct + 2].args[0],
+                                 concat_weight[i*2*direct + 3].args[0]])
+            back_bias.append([concat_weight[(num_layers+i)*2*direct + 2].args[0],
+                              concat_weight[(num_layers+i)*2*direct + 3].args[0]])
+            s = []
+            for state in init_states:
+                s.append(_op.take(state, _expr.const(i*direct+1, "int32"), axis=0))
+            back_states.append(s)
+
+    xs = [_op.take(seq_data, _expr.const(t, "int32"), axis=0) for t in range(seq_len)]
+    for l in range(num_layers):
+        outputs = []
+        back_outputs = []
+        for x in xs:
             if mode == "rnn":
-                out, new_states = _rnn_cell(data, states[l], *weights[l], *bias[l], activation)
+                out, new_states = _rnn_cell(x, states[l], *weights[l], *bias[l], activation)
             elif mode == "gru":
-                out, new_states = _gru_cell(data, states[l], *weights[l], *bias[l])
+                out, new_states = _gru_cell(x, states[l], *weights[l], *bias[l])
             else: # mode == "lstm"
-                out, new_states = _lstm_cell(data, states[l], *weights[l], *bias[l])
+                out, new_states = _lstm_cell(x, states[l], *weights[l], *bias[l])
             states[l] = new_states
-            data = out
-        seq_output.append(out)
-    outputs = [_op.stack(seq_output, axis=0)]
+            outputs.append(out)
+        if bidirectional:
+            for x in reversed(xs):
+                if mode == "rnn":
+                    out, new_states = _rnn_cell(
+                        x, back_states[l], *back_weights[l], *back_bias[l], activation)
+                elif mode == "gru":
+                    out, new_states = _gru_cell(
+                        x, back_states[l], *back_weights[l], *back_bias[l])
+                else: # mode == "lstm"
+                    out, new_states = _lstm_cell(
+                        x, back_states[l], *back_weights[l], *back_bias[l])
+                back_states[l] = new_states
+                back_outputs.append(out)
+            back_outputs.reverse()
+            concat_outputs = []
+            for t, out in enumerate(outputs):
+                new_out = _op.concatenate([out, back_outputs[t]], axis=-1)
+                concat_outputs.append(new_out)
+            outputs = concat_outputs
+        xs = outputs
+
+    ret = [_op.stack(outputs, axis=0)]
     if output_states:
         for i in range(num_states):
-            outputs.append(_op.stack([s[i] for s in states], axis=0))
-    return outputs
+            inputs = []
+            for l, s in enumerate(states):
+                inputs.append(s[i])
+                if bidirectional:
+                    inputs.append(back_states[l][i])
+            ret.append(_op.stack(inputs, axis=0))
+    return ret
 
 
 # Note: due to attribute conversion constraint
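The hunks above rework the `_mx_rnn_layer` converter; the hunks below update the MXNet frontend tests to match. For reference, here is a small sketch (not part of the commit) of how the flattened `concat_weight` list is indexed by the new code, assuming the cell helpers take the i2h weight before the h2h weight: all weights come first, then all biases, and within a layer the forward pair precedes the backward pair.

# Illustrative only: enumerate the concat_weight indices used above for
# num_layers=2, direct=2 (bidirectional).
num_layers, direct = 2, 2
for i in range(num_layers):
    fwd_w = (i*2*direct, i*2*direct + 1)                                 # forward i2h, h2h weights
    bwd_w = (i*2*direct + 2, i*2*direct + 3)                             # backward i2h, h2h weights
    fwd_b = ((num_layers+i)*2*direct, (num_layers+i)*2*direct + 1)       # forward i2h, h2h biases
    bwd_b = ((num_layers+i)*2*direct + 2, (num_layers+i)*2*direct + 3)   # backward i2h, h2h biases
    print("layer", i, fwd_w, bwd_w, fwd_b, bwd_b)
# Indices 0..15 are each used exactly once, matching the new check
# assert len(concat_weight) == num_layers * 4 * direct.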
@@ -536,29 +536,31 @@ def test_forward_bilinear_resize():
     verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
 
 
 def test_forward_rnn_layer():
-    def verify(mode, input_size, seq_len, hidden_size, num_layers, init_states=True):
+    def verify(mode, seq_len, input_size, hidden_size, num_layers,
+               batch=1, init_states=True, bidirectional=False):
         if mode == "rnn":
-            layer = gluon.rnn.RNN(hidden_size, num_layers)
+            layer = gluon.rnn.RNN(hidden_size, num_layers, bidirectional=bidirectional)
         elif mode == "gru":
-            layer = gluon.rnn.GRU(hidden_size, num_layers)
+            layer = gluon.rnn.GRU(hidden_size, num_layers, bidirectional=bidirectional)
         else: # mode == "lstm"
-            layer = gluon.rnn.LSTM(hidden_size, num_layers)
+            layer = gluon.rnn.LSTM(hidden_size, num_layers, bidirectional=bidirectional)
         num_states = 2 if mode == "lstm" else 1
         layer.initialize()
         layer.hybridize()
         dtype = "float32"
-        batch = 1
+        directions = 2 if bidirectional else 1
         data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
         data_mx = mx.nd.array(data_np)
 
         if init_states:
             shape_dict = {'data0': data_np.shape}
             inputs = {'data0': data_np}
+            state_shape = (num_layers*directions, batch, hidden_size)
             states_np = []
             states_mx = []
             for i in range(num_states):
-                s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
+                s = np.random.uniform(size=state_shape).astype(dtype)
                 states_np.append(s)
                 states_mx.append(mx.nd.array(s))
                 shape_dict['data%s' % (i+1)] = s.shape
@@ -592,10 +594,13 @@ def test_forward_rnn_layer():
                     op_res.asnumpy(), mx_res.asnumpy(), rtol=1e-3)
     for mode in ["rnn", "gru", "lstm"]:
-        verify(mode, 64, 10, 64, 1)
-        verify(mode, 64, 10, 64, 2)
-        verify(mode, 64, 10, 32, 2)
-        verify(mode, 64, 10, 64, 2, init_states=False)
+        verify(mode, 1, 64, 64, 1)
+        verify(mode, 10, 64, 64, 2)
+        verify(mode, 10, 64, 32, 2)
+        verify(mode, 10, 64, 32, 2, batch=2)
+        verify(mode, 10, 64, 64, 3, init_states=False)
+        verify(mode, 10, 32, 64, 1, bidirectional=True)
+        verify(mode, 10, 64, 64, 3, batch=2, bidirectional=True, init_states=False)
 
 
 def test_forward_Crop():
     def verify(xshape, yshape, offset=None):
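The new bidirectional test cases compare against mx.gluon, whose layout the converter now reproduces: forward and backward outputs are concatenated along the feature axis, and states are stacked per direction. A quick illustrative check (not part of the commit; assumes MXNet is installed):

# Illustrative only: shapes produced by a bidirectional Gluon layer, which the
# converted Relay graph is expected to match.
import mxnet as mx
from mxnet import gluon

layer = gluon.rnn.GRU(hidden_size=64, num_layers=1, bidirectional=True)
layer.initialize()
out, states = layer(mx.nd.ones((10, 2, 32)), layer.begin_state(batch_size=2))
print(out.shape)        # (10, 2, 128): seq_len, batch, 2 * hidden_size
print(states[0].shape)  # (2, 2, 64): num_layers * directions, batch, hidden_size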