Commit 7db5779f by ttyang1018 Committed by Tianqi Chen

[Relay][Frontend] Fix tensorflow frontend lstm forget bias adding order (#3410)

parent 6c43019b
......@@ -1437,9 +1437,8 @@ def _LSTMBlockCell():
gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1)
in_gate = _op.sigmoid(gate_list[0])
in_transform = _op.tanh(gate_list[1])
forget_gate = _op.sigmoid(gate_list[2])
forget_gate = _op.add(forget_gate,
tvm.relay.const(forget_bias, attr['T'].name))
forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr['T'].name))
forget_gate = _op.sigmoid(forget_gate)
out_gate = _op.sigmoid(gate_list[3])
next_c = _op.add(_op.multiply(forget_gate, in_state_c),
_op.multiply(in_gate, in_transform))
......
......@@ -1183,7 +1183,7 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
def test_forward_lstm():
'''test LSTM block cell'''
_test_lstm_cell(1, 2, 1, 0.0, 'float32')
_test_lstm_cell(1, 2, 1, 0.5, 'float32')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment