Commit 4e7b548e by Siju Committed by Tianqi Chen

[DARKNET FRONTEND] Batchnorm added as part of Dense op for running rnn model for next wo… (#1385)

parent 6cdc18e2
@@ -226,13 +226,18 @@ def _darknet_dense(inputs, attrs):
     """Process the dense operation."""
     op_name, new_attrs = 'dense', {}
     new_attrs['units'] = _darknet_required_attr(attrs, 'num_hidden')
+    out_name = {}
     if attrs.get('use_bias', False) is True:
         new_attrs['use_bias'] = True
     if attrs.get('use_flatten', False) is True:
         inputs[0] = _sym.flatten(inputs[0])
     sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs)
-    out_name = sym.list_output_names()[0].replace('_output', '')
+    out_name[0] = sym.list_output_names()[0].replace('_output', '')
+    if 'use_batchNorm' in attrs:
+        op_name, new_attrs = 'batch_norm', {}
+        new_attrs['epsilon'] = 0.000001
+        sym = _darknet_get_nnvm_op(op_name)(*sym, **new_attrs)
+        out_name[1] = sym.list_output_names()[0].replace('_output', '')
     if 'activation' in attrs:
         new_attrs = {}
         new_attrs['activation'] = attrs['activation']
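
Note on the hunk above: out_name changes from a string to a dict because, with batchnorm enabled, the graph now carries two named ops (the dense op and the appended batch_norm), and the weight loader must be able to address each by its own name. A standalone sketch of that bookkeeping, with illustrative stand-in op names rather than the real NNVM symbols:

# Sketch only: shows the two-level naming the new out_name dict enables.
# 'dense0' / 'batch_norm0' are assumed example names, not fixed values.
def apply_dense(attrs):
    ops = ['dense0']                    # dense op is always emitted
    if 'use_batchNorm' in attrs:
        ops.append('batch_norm0')       # batchnorm appended, epsilon = 0.000001
    return {i: name for i, name in enumerate(ops)}

print(apply_dense({'use_batchNorm': True}))   # {0: 'dense0', 1: 'batch_norm0'}
print(apply_dense({}))                        # {0: 'dense0'}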
@@ -430,13 +435,16 @@ def _get_connected_weights(layer, opname, params, dtype):
     weights = _read_memory_buffer((layer.outputs, layer.inputs), layer.weights, dtype)
     biases = _read_memory_buffer((layer.outputs, ), layer.biases, dtype)

-    k = _get_tvm_params_name(opname, 'weight')
+    k = _get_tvm_params_name(opname[0], 'weight')
     params[k] = tvm.nd.array(weights)
-    k = _get_tvm_params_name(opname, 'bias')
-    params[k] = tvm.nd.array(biases)

     if layer.batch_normalize == 1 and layer.dontloadscales != 1:
-        _get_batchnorm_weights(layer, opname, params, layer.outputs, dtype)
+        _get_batchnorm_weights(layer, opname[1], params, layer.outputs, dtype)
+        k = _get_tvm_params_name(opname[1], 'beta')
+        params[k] = tvm.nd.array(biases)
+    else:
+        k = _get_tvm_params_name(opname[0], 'bias')
+        params[k] = tvm.nd.array(biases)

 def _get_batchnorm_weights(layer, opname, params, size, dtype):
     """Parse the weights for batchnorm, which includes, scales, moving mean
...
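
The indexing into opname mirrors the out_name dict from the first hunk: when darknet fuses batchnorm into a connected layer, the layer's biases act as the batchnorm shift (beta) rather than as a dense bias, so they must be stored under the batchnorm op's name. A minimal sketch of that routing, assuming _get_tvm_params_name simply joins the op name and argument name with an underscore (consistent with the keys this frontend uses elsewhere):

# Sketch of the parameter routing above; a plain dict stands in for params.
def _params_name(opname, arg_name):
    return opname + '_' + arg_name        # assumed naming convention

def route_connected_params(opname, weights, biases, batch_normalize):
    params = {_params_name(opname[0], 'weight'): weights}
    if batch_normalize:
        params[_params_name(opname[1], 'beta')] = biases   # bias becomes shift
    else:
        params[_params_name(opname[0], 'bias')] = biases   # plain dense bias
    return params

print(route_connected_params({0: 'dense0', 1: 'batch_norm0'}, 'W', 'b', True))
# {'dense0_weight': 'W', 'batch_norm0_beta': 'b'}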
@@ -169,6 +169,20 @@ def test_forward_dense():
     test_forward(net)
     LIB.free_network(net)

+def test_forward_dense_batchnorm():
+    '''test fully connected layer with batchnorm'''
+    net = LIB.make_network(1)
+    layer = LIB.make_connected_layer(1, 12, 2, 1, 1, 0)
+    for i in range(5):
+        layer.rolling_mean[i] = np.random.rand(1)
+        layer.rolling_variance[i] = np.random.rand(1)
+        layer.scales[i] = np.random.rand(1)
+    net.layers[0] = layer
+    net.w = net.h = 2
+    LIB.resize_network(net, 2, 2)
+    test_forward(net)
+    LIB.free_network(net)
+
 def test_forward_maxpooling():
     '''test maxpooling layer'''
     net = LIB.make_network(1)
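
At inference time the appended batchnorm applies the usual normalization with the layer's rolling statistics. A quick numpy check of that formula, using the epsilon from the first hunk and shapes matching the 2-unit layer built in the new test (purely illustrative):

import numpy as np

x = np.random.rand(1, 2).astype('float32')      # dense output, 2 units
gamma = np.random.rand(2).astype('float32')     # layer.scales
mean = np.random.rand(2).astype('float32')      # layer.rolling_mean
var = np.random.rand(2).astype('float32')       # layer.rolling_variance
beta = np.random.rand(2).astype('float32')      # layer.biases, loaded as beta

y = gamma * (x - mean) / np.sqrt(var + 0.000001) + beta
print(y.shape)                                  # (1, 2)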
@@ -264,6 +278,7 @@ if __name__ == '__main__':
     test_forward_batch_norm()
     test_forward_shortcut()
     test_forward_dense()
+    test_forward_dense_batchnorm()
     test_forward_reorg()
     test_forward_region()
     test_forward_elu()