Commit f216b25e by Tatsuya Nishiyama Committed by Tianqi Chen

[FRONTEND][MXNET] Add expand_dims supoort (#1317)

* [FRONTEND][MXNET] Add expand_dims supoort

* fix lint
parent a83e1e1e
...@@ -241,6 +241,12 @@ def _elemwise_sum(inputs, _): ...@@ -241,6 +241,12 @@ def _elemwise_sum(inputs, _):
return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs) return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs)
def _expand_dims(inputs, attrs):
op_name, new_attrs = "expand_dims", {}
new_attrs['axis'] = _required_attr(attrs, 'axis')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__', '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
...@@ -288,7 +294,8 @@ _convert_map = { ...@@ -288,7 +294,8 @@ _convert_map = {
'reshape' : _reshape, 'reshape' : _reshape,
'sum_axis' : _rename('sum'), 'sum_axis' : _rename('sum'),
'UpSampling' : _upsampling, 'UpSampling' : _upsampling,
'clip' : _clip 'clip' : _clip,
'expand_dims' : _expand_dims
} }
def _convert_symbol(op_name, inputs, attrs, def _convert_symbol(op_name, inputs, attrs,
......
...@@ -136,6 +136,11 @@ def test_forward_split_squeeze(): ...@@ -136,6 +136,11 @@ def test_forward_split_squeeze():
mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True) mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True)
verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1)) verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1))
def test_forward_expand_dims():
data = mx.sym.var('data')
mx_sym = mx.sym.expand_dims(data, axis=1)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 1, 3, 4))
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
test_forward_vgg() test_forward_vgg()
...@@ -148,3 +153,4 @@ if __name__ == '__main__': ...@@ -148,3 +153,4 @@ if __name__ == '__main__':
test_forward_clip() test_forward_clip()
test_forward_split() test_forward_split()
test_forward_split_squeeze() test_forward_split_squeeze()
test_forward_expand_dims()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment