Commit 54a115ef by Siju Committed by Tianqi Chen

[FRONTEND][DARKNET]LSTM and GRU support (#1576)

parent bc7431bd
...@@ -412,7 +412,12 @@ class GraphProto(object): ...@@ -412,7 +412,12 @@ class GraphProto(object):
self._sym_array = {} self._sym_array = {}
self._tvmparams = {} self._tvmparams = {}
self._outs = [] self._outs = []
self._rnn_state_ctr = 0 self._state_ctr = {}
self._state_ctr['rnn'] = 0
self._state_ctr['crnn'] = 0
self._state_ctr['lstm'] = 0
self._state_ctr['cell_state'] = 0
self._state_ctr['gru'] = 0
def _read_memory_buffer(self, shape, data): def _read_memory_buffer(self, shape, data):
length = 1 length = 1
...@@ -623,16 +628,16 @@ class GraphProto(object): ...@@ -623,16 +628,16 @@ class GraphProto(object):
"""Returs the layer name.""" """Returs the layer name."""
return layer.type return layer.type
def _new_rnn_state_sym(self, state=None): def _new_rnn_state_sym(self, state=None, name='rnn'):
"""Returs a symbol for state""" """Returs a symbol for state"""
name = "rnn%d_state" % (self._rnn_state_ctr) sym_name = name + "%d_state" % self._state_ctr[name]
self._rnn_state_ctr += 1 self._state_ctr[name] += 1
return _sym.Variable(name=name, init=state) return _sym.Variable(name=sym_name, init=state)
def _get_rnn_state_buffer(self, layer): def _get_rnn_state_buffer(self, layer, name):
"""Get the state buffer for rnn.""" """Get the state buffer for rnn."""
buffer = np.zeros((1, layer.outputs), self.dtype) buffer = np.zeros((1, layer.outputs), self.dtype)
return self._new_rnn_state_sym(buffer) return self._new_rnn_state_sym(buffer, name)
def _get_darknet_rnn_attrs(self, layer, sym): def _get_darknet_rnn_attrs(self, layer, sym):
"""Get the rnn converted symbol from attributes.""" """Get the rnn converted symbol from attributes."""
...@@ -653,7 +658,7 @@ class GraphProto(object): ...@@ -653,7 +658,7 @@ class GraphProto(object):
attr.update({'batch' : layer.batch}) attr.update({'batch' : layer.batch})
attr.update({'num_hidden' : str(layer.outputs)}) attr.update({'num_hidden' : str(layer.outputs)})
state = self._get_rnn_state_buffer(layer) state = self._get_rnn_state_buffer(layer, 'rnn')
for _ in range(layer.steps): for _ in range(layer.steps):
input_layer = layer.input_layer input_layer = layer.input_layer
...@@ -678,7 +683,7 @@ class GraphProto(object): ...@@ -678,7 +683,7 @@ class GraphProto(object):
attr.update({'batch' : layer.batch}) attr.update({'batch' : layer.batch})
attr.update({'num_hidden' : str(layer.outputs)}) attr.update({'num_hidden' : str(layer.outputs)})
state = self._get_rnn_state_buffer(layer) state = self._get_rnn_state_buffer(layer, 'crnn')
for _ in range(layer.steps): for _ in range(layer.steps):
input_layer = layer.input_layer input_layer = layer.input_layer
...@@ -698,6 +703,123 @@ class GraphProto(object): ...@@ -698,6 +703,123 @@ class GraphProto(object):
self._sym_array[layer_num] = sym self._sym_array[layer_num] = sym
processed = True processed = True
elif LAYERTYPE.LSTM == layer.type:
if layer.steps > 1:
raise NotImplementedError("Currently support only single step GRU")
op_name_add = 'elemwise_add'
op_name_mul = 'elemwise_mul'
attrs = {}
act_attr = {}
h_state = self._get_rnn_state_buffer(layer, 'lstm')
c_state = self._get_rnn_state_buffer(layer, 'cell_state')
for _ in range(layer.steps):
sym_wf = self._get_darknet_rnn_attrs(layer.wf, h_state)
sym_wi = self._get_darknet_rnn_attrs(layer.wi, h_state)
sym_wg = self._get_darknet_rnn_attrs(layer.wg, h_state)
sym_wo = self._get_darknet_rnn_attrs(layer.wo, h_state)
input_sym = sym
sym_uf = self._get_darknet_rnn_attrs(layer.uf, input_sym)
sym_ui = self._get_darknet_rnn_attrs(layer.ui, input_sym)
sym_ug = self._get_darknet_rnn_attrs(layer.ug, input_sym)
sym_uo = self._get_darknet_rnn_attrs(layer.uo, input_sym)
new_inputs = _as_list([sym_wf, sym_uf])
add_f = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs)
new_inputs = _as_list([sym_wi, sym_ui])
add_i = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs)
new_inputs = _as_list([sym_wg, sym_ug])
add_g = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs)
new_inputs = _as_list([sym_wo, sym_uo])
add_o = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs)
act_attr['activation'] = ACTIVATION.LOGISTIC
act_f, _ = _darknet_activations(_as_list(add_f), act_attr)
act_attr['activation'] = ACTIVATION.LOGISTIC
act_i, _ = _darknet_activations(_as_list(add_i), act_attr)
act_attr['activation'] = ACTIVATION.TANH
act_g, _ = _darknet_activations(_as_list(add_g), act_attr)
act_attr['activation'] = ACTIVATION.LOGISTIC
act_o, _ = _darknet_activations(_as_list(add_o), act_attr)
new_inputs = _as_list([act_i, act_g])
mul_t = _darknet_get_nnvm_op(op_name_mul)(*new_inputs, **attrs)
new_inputs = _as_list([act_f, c_state])
c_state = _darknet_get_nnvm_op(op_name_mul)(*new_inputs, **attrs)
new_inputs = _as_list([mul_t, c_state])
c_state = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs)
act_attr['activation'] = ACTIVATION.TANH
h_state, _ = _darknet_activations(_as_list(c_state), act_attr)
new_inputs = _as_list([act_o, h_state])
h_state = _darknet_get_nnvm_op(op_name_mul)(*new_inputs, **attrs)
self._outs = self._outs + [c_state, h_state]
sym = h_state
self._sym_array[layer_num] = sym
processed = True
elif LAYERTYPE.GRU == layer.type:
if layer.steps > 1:
raise NotImplementedError("Currently support only single step GRU")
op_name_add = 'elemwise_add'
op_name_mul = 'elemwise_mul'
attrs = {}
act_attr = {}
state = self._get_rnn_state_buffer(layer, "gru")
for _ in range(layer.steps):
sym_wz = self._get_darknet_rnn_attrs(layer.wz, state)
sym_wr = self._get_darknet_rnn_attrs(layer.wr, state)
input_sym = sym
sym_uz = self._get_darknet_rnn_attrs(layer.uz, input_sym)
sym_ur = self._get_darknet_rnn_attrs(layer.ur, input_sym)
sym_uh = self._get_darknet_rnn_attrs(layer.uh, input_sym)
new_inputs = _as_list([sym_uz, sym_wz])
add_z = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs)
new_inputs = _as_list([sym_ur, sym_wr])
add_r = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs)
act_attr['activation'] = ACTIVATION.LOGISTIC
act_z, _ = _darknet_activations(_as_list(add_z), act_attr)
act_attr['activation'] = ACTIVATION.LOGISTIC
act_r, _ = _darknet_activations(_as_list(add_r), act_attr)
new_inputs = _as_list([act_r, state])
forgot = _darknet_get_nnvm_op(op_name_mul)(*new_inputs, **attrs)
sym_wh = self._get_darknet_rnn_attrs(layer.wh, forgot)
new_inputs = _as_list([sym_uh, sym_wh])
h_state = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs)
if layer.tanh == 1:
act_attr['activation'] = ACTIVATION.TANH
else:
act_attr['activation'] = ACTIVATION.LOGISTIC
h_state, _ = _darknet_activations(_as_list(h_state), act_attr)
sym = act_z * state + (1 - act_z) * h_state
self._outs = self._outs + [sym]
self._sym_array[layer_num] = sym
processed = True
return processed, sym return processed, sym
def from_darknet(self): def from_darknet(self):
......
...@@ -491,6 +491,9 @@ layer make_reorg_layer(int batch, int w, int h, int c, int stride, int reverse, ...@@ -491,6 +491,9 @@ layer make_reorg_layer(int batch, int w, int h, int c, int stride, int reverse,
layer make_region_layer(int batch, int w, int h, int n, int classes, int coords); layer make_region_layer(int batch, int w, int h, int n, int classes, int coords);
layer make_softmax_layer(int batch, int inputs, int groups); layer make_softmax_layer(int batch, int inputs, int groups);
layer make_rnn_layer(int batch, int inputs, int outputs, int steps, ACTIVATION activation, int batch_normalize, int adam); layer make_rnn_layer(int batch, int inputs, int outputs, int steps, ACTIVATION activation, int batch_normalize, int adam);
layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, ACTIVATION activation, int batch_normalize);
layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize, int adam);
layer make_gru_layer(int batch, int inputs, int outputs, int steps, int batch_normalize, int adam);
void free_network(network *net); void free_network(network *net);
""" """
) )
...@@ -306,7 +306,7 @@ def test_forward_softmax_temperature(): ...@@ -306,7 +306,7 @@ def test_forward_softmax_temperature():
LIB.free_network(net) LIB.free_network(net)
def test_forward_rnn(): def test_forward_rnn():
'''test softmax layer''' '''test RNN layer'''
net = LIB.make_network(1) net = LIB.make_network(1)
batch = 1 batch = 1
inputs = 256 inputs = 256
...@@ -325,7 +325,7 @@ def test_forward_rnn(): ...@@ -325,7 +325,7 @@ def test_forward_rnn():
LIB.free_network(net) LIB.free_network(net)
def test_forward_crnn(): def test_forward_crnn():
'''test softmax layer''' '''test CRNN layer'''
net = LIB.make_network(1) net = LIB.make_network(1)
batch = 1 batch = 1
c = 3 c = 3
...@@ -349,6 +349,42 @@ def test_forward_crnn(): ...@@ -349,6 +349,42 @@ def test_forward_crnn():
test_forward(net) test_forward(net)
LIB.free_network(net) LIB.free_network(net)
def test_forward_lstm():
'''test LSTM layer'''
net = LIB.make_network(1)
batch = 1
inputs = 256
outputs = 256
steps = 1
batch_normalize = 0
adam = 0
layer_1 = LIB.make_lstm_layer(batch, inputs, outputs, steps, batch_normalize, adam)
net.layers[0] = layer_1
net.inputs = inputs
net.outputs = outputs
net.w = net.h = 0
LIB.resize_network(net, net.w, net.h)
test_rnn_forward(net)
LIB.free_network(net)
def test_forward_gru():
'''test GRU layer'''
net = LIB.make_network(1)
batch = 1
inputs = 256
outputs = 256
steps = 1
batch_normalize = 0
adam = 0
layer_1 = LIB.make_gru_layer(batch, inputs, outputs, steps, batch_normalize, adam)
net.layers[0] = layer_1
net.inputs = inputs
net.outputs = outputs
net.w = net.h = 0
LIB.resize_network(net, net.w, net.h)
test_rnn_forward(net)
LIB.free_network(net)
def test_forward_activation_logistic(): def test_forward_activation_logistic():
'''test logistic activation layer''' '''test logistic activation layer'''
net = LIB.make_network(1) net = LIB.make_network(1)
...@@ -395,4 +431,6 @@ if __name__ == '__main__': ...@@ -395,4 +431,6 @@ if __name__ == '__main__':
test_forward_elu() test_forward_elu()
test_forward_rnn() test_forward_rnn()
test_forward_crnn() test_forward_crnn()
test_forward_activation_logistic() test_forward_lstm()
\ No newline at end of file test_forward_gru()
test_forward_activation_logistic()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment