Commit 71cff3e8 by Pariksheet Pinjari Committed by Tianqi Chen

[MXNET] LRN support in MXNET frontend (#1520)

parent a2870fef
......@@ -263,6 +263,15 @@ def _expand_dims(inputs, attrs):
new_attrs['axis'] = _required_attr(attrs, 'axis')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _lrn(inputs, attrs):
op_name, new_attrs = "lrn", {}
new_attrs['alpha'] = attrs.get('alpha', 0.0001)
new_attrs['beta'] = attrs.get('beta', 0.75)
new_attrs['bias'] = attrs.get('knorm', 2)
# NCHW format and normalization along channel axis
new_attrs['axis'] = 1
new_attrs['size'] = _required_attr(attrs, 'nsize')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
......@@ -314,7 +323,8 @@ _convert_map = {
'sum_axis' : _rename('sum'),
'UpSampling' : _upsampling,
'clip' : _clip,
'expand_dims' : _expand_dims
'expand_dims' : _expand_dims,
'LRN' : _lrn
}
def _convert_symbol(op_name, inputs, attrs,
......
......@@ -149,6 +149,11 @@ def test_forward_pooling():
mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max')
verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
def test_forward_lrn():
data = mx.sym.var('data')
mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5)
verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24))
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
......@@ -163,3 +168,4 @@ if __name__ == '__main__':
test_forward_split_squeeze()
test_forward_expand_dims()
test_forward_pooling()
test_forward_lrn()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment