Commit 303b00eb by Siju Committed by Tianqi Chen

C-RNN layer support is added (#1492)

parent b65a13dd
...@@ -673,6 +673,31 @@ class GraphProto(object): ...@@ -673,6 +673,31 @@ class GraphProto(object):
self._sym_array[layer_num] = sym self._sym_array[layer_num] = sym
processed = True processed = True
elif LAYERTYPE.CRNN == layer.type:
attr.update({'n' : layer.n})
attr.update({'batch' : layer.batch})
attr.update({'num_hidden' : str(layer.outputs)})
state = self._get_rnn_state_buffer(layer)
for _ in range(layer.steps):
input_layer = layer.input_layer
sym = self._get_darknet_rnn_attrs(input_layer, sym)
self_layer = layer.self_layer
state = self._get_darknet_rnn_attrs(self_layer, state)
op_name, new_attrs = 'elemwise_add', {}
new_inputs = _as_list([sym, state])
state = _darknet_get_nnvm_op(op_name)(*new_inputs, **new_attrs)
self._outs.append(state)
output_layer = layer.output_layer
sym = self._get_darknet_rnn_attrs(output_layer, state)
self._sym_array[layer_num] = sym
processed = True
return processed, sym return processed, sym
def from_darknet(self): def from_darknet(self):
......
...@@ -324,6 +324,31 @@ def test_forward_rnn(): ...@@ -324,6 +324,31 @@ def test_forward_rnn():
test_rnn_forward(net) test_rnn_forward(net)
LIB.free_network(net) LIB.free_network(net)
def test_forward_crnn():
'''test softmax layer'''
net = LIB.make_network(1)
batch = 1
c = 3
h = 224
w = 224
hidden_filters = c
output_filters = c
steps = 1
activation = 0
batch_normalize = 0
inputs = 256
outputs = 256
layer_1 = LIB.make_crnn_layer(batch, h, w, c, hidden_filters, output_filters,
steps, activation, batch_normalize)
net.layers[0] = layer_1
net.inputs = inputs
net.outputs = output_filters * h * w
net.w = w
net.h = h
LIB.resize_network(net, net.w, net.h)
test_forward(net)
LIB.free_network(net)
def test_forward_activation_logistic(): def test_forward_activation_logistic():
'''test logistic activation layer''' '''test logistic activation layer'''
net = LIB.make_network(1) net = LIB.make_network(1)
...@@ -369,4 +394,5 @@ if __name__ == '__main__': ...@@ -369,4 +394,5 @@ if __name__ == '__main__':
test_forward_region() test_forward_region()
test_forward_elu() test_forward_elu()
test_forward_rnn() test_forward_rnn()
test_forward_crnn()
test_forward_activation_logistic() test_forward_activation_logistic()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment