Commit 56299010 by Haichen Shen, committed by Yao Wang

[Frontend][MxNet] Support bidirectional RNN layer (#3397)

* Support bidirectional RNN layer

* tweak

* tweak
parent b98e2c76
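A note for reviewers (my own sketch, not part of the patch): the converter below relies on MXNet's fused RNN parameter packing, where each layer contributes its forward i2h/h2h weight pair first, then the backward pair when bidirectional, and all biases follow all weights in the same order. That is four tensors per layer per direction, which is what the new `num_layers * 4 * direct` assertion checks. A hypothetical helper reproducing the indices used in the loop:

    def fused_param_indices(num_layers, bidirectional):
        # Position of each (layer, direction) tensor in the flat parameter
        # list: all weights first, then all biases, in the same order.
        direct = 2 if bidirectional else 1
        idx = {}
        for i in range(num_layers):
            for d in range(direct):  # d == 0: forward, d == 1: backward
                w = i * 2 * direct + 2 * d
                b = (num_layers + i) * 2 * direct + 2 * d
                idx[(i, d)] = {"i2h_weight": w, "h2h_weight": w + 1,
                               "i2h_bias": b, "h2h_bias": b + 1}
        return idx

For example, `fused_param_indices(2, True)` yields 16 slots: weights in positions 0-7, biases in 8-15, matching the `concat_weight[i*2*direct ...]` indexing in the diff.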
@@ -748,13 +748,12 @@ def _mx_rnn_layer(inputs, attrs):
     num_layers = attrs.get_int("num_layers", 1)
     mode = attrs.get_str("mode")
+    output_states = attrs.get_bool("state_outputs", False)
     if mode.startswith("rnn"):
         mode, activation = mode.split('_')
     assert mode in ["rnn", "gru", "lstm"]
     bidirectional = attrs.get_bool("bidirectional", False)
-    if bidirectional:
-        raise tvm.error.OpAttributeUnimplemented(
-            "Bidirectional RNN op is not supported yet")
+    direct = 2 if bidirectional else 1
     layout = attrs.get_str("layout", "TNC")
     if layout != "TNC":
         raise tvm.error.OpAttributeUnimplemented(
@@ -765,11 +764,10 @@ def _mx_rnn_layer(inputs, attrs):
     seq_data = inputs[0]
     concat_weight = inputs[1]
     init_states = inputs[2:]
     data_shape = ir_pass.infer_type(seq_data).checked_type.shape
     seq_len = int(data_shape[0])
-    assert len(concat_weight) == num_layers * 4
-    output_states = True
+    assert len(concat_weight) == num_layers * 4 * direct
     for idx, state in enumerate(init_states[:]):
         if isinstance(state, dict):
             node = state
@@ -787,43 +785,76 @@ def _mx_rnn_layer(inputs, attrs):
                     assert axis >= 0
                     new_shape[i] = int(data_shape[axis])
             init_states[idx] = _op.zeros(new_shape, dtype)
-            output_states = False
     weights = []
     bias = []
     states = []
+    back_weights = []
+    back_bias = []
+    back_states = []
     for i in range(num_layers):
-        w = []
-        b = []
+        weights.append([concat_weight[i*2*direct].args[0],
+                        concat_weight[i*2*direct + 1].args[0]])
+        bias.append([concat_weight[(num_layers+i)*2*direct].args[0],
+                     concat_weight[(num_layers+i)*2*direct + 1].args[0]])
         s = []
-        for j in range(2):
-            w.append(concat_weight[i*2 + j].args[0])
-            b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
         for state in init_states:
-            s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
-        weights.append(w)
-        bias.append(b)
+            s.append(_op.take(state, _expr.const(i*direct, "int32"), axis=0))
         states.append(s)
+        if bidirectional:
+            back_weights.append([concat_weight[i*2*direct + 2].args[0],
+                                 concat_weight[i*2*direct + 3].args[0]])
+            back_bias.append([concat_weight[(num_layers+i)*2*direct + 2].args[0],
+                              concat_weight[(num_layers+i)*2*direct + 3].args[0]])
+            s = []
+            for state in init_states:
+                s.append(_op.take(state, _expr.const(i*direct+1, "int32"), axis=0))
+            back_states.append(s)
 
-    seq_output = []
-    for t in range(seq_len):
-        data = _op.take(seq_data, _expr.const(t, "int32"), axis=0)
-        for l in range(num_layers):
+    xs = [_op.take(seq_data, _expr.const(t, "int32"), axis=0) for t in range(seq_len)]
+    for l in range(num_layers):
+        outputs = []
+        back_outputs = []
+        for x in xs:
             if mode == "rnn":
-                out, new_states = _rnn_cell(data, states[l], *weights[l], *bias[l], activation)
+                out, new_states = _rnn_cell(x, states[l], *weights[l], *bias[l], activation)
             elif mode == "gru":
-                out, new_states = _gru_cell(data, states[l], *weights[l], *bias[l])
+                out, new_states = _gru_cell(x, states[l], *weights[l], *bias[l])
             else: # mode == "lstm"
-                out, new_states = _lstm_cell(data, states[l], *weights[l], *bias[l])
+                out, new_states = _lstm_cell(x, states[l], *weights[l], *bias[l])
             states[l] = new_states
-            data = out
-        seq_output.append(out)
+            outputs.append(out)
+        if bidirectional:
+            for x in reversed(xs):
+                if mode == "rnn":
+                    out, new_states = _rnn_cell(
+                        x, back_states[l], *back_weights[l], *back_bias[l], activation)
+                elif mode == "gru":
+                    out, new_states = _gru_cell(
+                        x, back_states[l], *back_weights[l], *back_bias[l])
+                else: # mode == "lstm"
+                    out, new_states = _lstm_cell(
+                        x, back_states[l], *back_weights[l], *back_bias[l])
+                back_states[l] = new_states
+                back_outputs.append(out)
+            back_outputs.reverse()
+            concat_outputs = []
+            for t, out in enumerate(outputs):
+                new_out = _op.concatenate([out, back_outputs[t]], axis=-1)
+                concat_outputs.append(new_out)
+            outputs = concat_outputs
+        xs = outputs
 
-    outputs = [_op.stack(seq_output, axis=0)]
+    ret = [_op.stack(outputs, axis=0)]
     if output_states:
         for i in range(num_states):
-            outputs.append(_op.stack([s[i] for s in states], axis=0))
-    return outputs
+            inputs = []
+            for l, s in enumerate(states):
+                inputs.append(s[i])
+                if bidirectional:
+                    inputs.append(back_states[l][i])
+            ret.append(_op.stack(inputs, axis=0))
+    return ret
 
 # Note: due to attribute conversion constraint
......
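Reviewer sketch of the new unroll (assumed semantics, inferred from the patch): each layer runs the forward cell over the sequence and, when bidirectional, the backward cell over the reversed sequence, re-aligns the backward outputs to forward time order, and concatenates the two on the feature axis before feeding the next layer. Note the patch also reads `state_outputs` from the op attributes up front instead of inferring it from whether initial states were supplied. In plain NumPy the pattern looks like this, with `step_fw`/`step_bw` standing in for `_rnn_cell`/`_gru_cell`/`_lstm_cell`:

    import numpy as np

    def bidirectional_unroll(xs, step_fw, step_bw, state_fw, state_bw):
        # Forward cell walks the sequence left to right.
        outs_fw = []
        for x in xs:
            out, state_fw = step_fw(x, state_fw)
            outs_fw.append(out)
        # Backward cell walks it right to left...
        outs_bw = []
        for x in reversed(xs):
            out, state_bw = step_bw(x, state_bw)
            outs_bw.append(out)
        # ...and its outputs are re-aligned to forward time order.
        outs_bw.reverse()
        # Per-step concat on the feature axis: the next layer sees
        # 2 * hidden_size features per time step.
        return [np.concatenate([f, b], axis=-1)
                for f, b in zip(outs_fw, outs_bw)]

    # Toy usage with identity "cells" (hypothetical stand-ins):
    step = lambda x, s: (x + s, s)
    xs = [np.ones((1, 4)) * t for t in range(3)]
    ys = bidirectional_unroll(xs, step, step, np.zeros((1, 4)), np.zeros((1, 4)))
    assert ys[0].shape == (1, 8)

The test diff below exercises exactly this path through Gluon's `bidirectional=True` layers.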
@@ -536,29 +536,31 @@ def test_forward_bilinear_resize():
     verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
 
 def test_forward_rnn_layer():
-    def verify(mode, input_size, seq_len, hidden_size, num_layers, init_states=True):
+    def verify(mode, seq_len, input_size, hidden_size, num_layers,
+               batch=1, init_states=True, bidirectional=False):
         if mode == "rnn":
-            layer = gluon.rnn.RNN(hidden_size, num_layers)
+            layer = gluon.rnn.RNN(hidden_size, num_layers, bidirectional=bidirectional)
         elif mode == "gru":
-            layer = gluon.rnn.GRU(hidden_size, num_layers)
+            layer = gluon.rnn.GRU(hidden_size, num_layers, bidirectional=bidirectional)
         else: # mode == "lstm"
-            layer = gluon.rnn.LSTM(hidden_size, num_layers)
+            layer = gluon.rnn.LSTM(hidden_size, num_layers, bidirectional=bidirectional)
         num_states = 2 if mode == "lstm" else 1
         layer.initialize()
         layer.hybridize()
 
         dtype = "float32"
-        batch = 1
+        directions = 2 if bidirectional else 1
         data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
         data_mx = mx.nd.array(data_np)
 
         if init_states:
             shape_dict = {'data0': data_np.shape}
             inputs = {'data0': data_np}
+            state_shape = (num_layers*directions, batch, hidden_size)
             states_np = []
             states_mx = []
             for i in range(num_states):
-                s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
+                s = np.random.uniform(size=state_shape).astype(dtype)
                 states_np.append(s)
                 states_mx.append(mx.nd.array(s))
                 shape_dict['data%s' % (i+1)] = s.shape
@@ -592,10 +594,13 @@ def test_forward_rnn_layer():
             op_res.asnumpy(), mx_res.asnumpy(), rtol=1e-3)
 
     for mode in ["rnn", "gru", "lstm"]:
-        verify(mode, 64, 10, 64, 1)
-        verify(mode, 64, 10, 64, 2)
-        verify(mode, 64, 10, 32, 2)
-        verify(mode, 64, 10, 64, 2, init_states=False)
+        verify(mode, 1, 64, 64, 1)
+        verify(mode, 10, 64, 64, 2)
+        verify(mode, 10, 64, 32, 2)
+        verify(mode, 10, 64, 32, 2, batch=2)
+        verify(mode, 10, 64, 64, 3, init_states=False)
+        verify(mode, 10, 32, 64, 1, bidirectional=True)
+        verify(mode, 10, 64, 64, 3, batch=2, bidirectional=True, init_states=False)
 
 def test_forward_Crop():
     def verify(xshape, yshape, offset=None):
......
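As a sanity check on the shapes the updated test assumes (a minimal standalone sketch against Gluon itself, independent of the TVM harness; parameter values here are illustrative):

    import numpy as np
    import mxnet as mx
    from mxnet import gluon

    seq_len, batch, input_size, hidden_size, num_layers = 10, 2, 32, 64, 3
    layer = gluon.rnn.LSTM(hidden_size, num_layers, bidirectional=True)
    layer.initialize()

    data = mx.nd.array(np.random.uniform(size=(seq_len, batch, input_size)))
    states = layer.begin_state(batch_size=batch)
    out, out_states = layer(data, states)

    # Both directions are concatenated on the feature axis of the output.
    assert out.shape == (seq_len, batch, 2 * hidden_size)
    # States stack layers x directions on axis 0, matching
    # state_shape = (num_layers*directions, batch, hidden_size) in the test.
    assert out_states[0].shape == (num_layers * 2, batch, hidden_size)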