Commit 71cff3e8 by Pariksheet Pinjari Committed by Tianqi Chen

[MXNET] LRN support in MXNET frontend (#1520)

parent a2870fef
...@@ -263,6 +263,15 @@ def _expand_dims(inputs, attrs): ...@@ -263,6 +263,15 @@ def _expand_dims(inputs, attrs):
new_attrs['axis'] = _required_attr(attrs, 'axis') new_attrs['axis'] = _required_attr(attrs, 'axis')
return _get_nnvm_op(op_name)(*inputs, **new_attrs) return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _lrn(inputs, attrs):
op_name, new_attrs = "lrn", {}
new_attrs['alpha'] = attrs.get('alpha', 0.0001)
new_attrs['beta'] = attrs.get('beta', 0.75)
new_attrs['bias'] = attrs.get('knorm', 2)
# NCHW format and normalization along channel axis
new_attrs['axis'] = 1
new_attrs['size'] = _required_attr(attrs, 'nsize')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
...@@ -314,7 +323,8 @@ _convert_map = { ...@@ -314,7 +323,8 @@ _convert_map = {
'sum_axis' : _rename('sum'), 'sum_axis' : _rename('sum'),
'UpSampling' : _upsampling, 'UpSampling' : _upsampling,
'clip' : _clip, 'clip' : _clip,
'expand_dims' : _expand_dims 'expand_dims' : _expand_dims,
'LRN' : _lrn
} }
def _convert_symbol(op_name, inputs, attrs, def _convert_symbol(op_name, inputs, attrs,
......
...@@ -149,6 +149,11 @@ def test_forward_pooling(): ...@@ -149,6 +149,11 @@ def test_forward_pooling():
mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max') mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max')
verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8)) verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
def test_forward_lrn():
data = mx.sym.var('data')
mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5)
verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24))
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
test_forward_vgg() test_forward_vgg()
...@@ -163,3 +168,4 @@ if __name__ == '__main__': ...@@ -163,3 +168,4 @@ if __name__ == '__main__':
test_forward_split_squeeze() test_forward_split_squeeze()
test_forward_expand_dims() test_forward_expand_dims()
test_forward_pooling() test_forward_pooling()
test_forward_lrn()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment