Commit f216b25e by Tatsuya Nishiyama Committed by Tianqi Chen

[FRONTEND][MXNET] Add expand_dims supoort (#1317)

* [FRONTEND][MXNET] Add expand_dims supoort

* fix lint
parent a83e1e1e
......@@ -241,6 +241,12 @@ def _elemwise_sum(inputs, _):
return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs)
def _expand_dims(inputs, attrs):
op_name, new_attrs = "expand_dims", {}
new_attrs['axis'] = _required_attr(attrs, 'axis')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
......@@ -288,7 +294,8 @@ _convert_map = {
'reshape' : _reshape,
'sum_axis' : _rename('sum'),
'UpSampling' : _upsampling,
'clip' : _clip
'clip' : _clip,
'expand_dims' : _expand_dims
}
def _convert_symbol(op_name, inputs, attrs,
......
......@@ -136,6 +136,11 @@ def test_forward_split_squeeze():
mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True)
verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1))
def test_forward_expand_dims():
data = mx.sym.var('data')
mx_sym = mx.sym.expand_dims(data, axis=1)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 1, 3, 4))
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
......@@ -148,3 +153,4 @@ if __name__ == '__main__':
test_forward_clip()
test_forward_split()
test_forward_split_squeeze()
test_forward_expand_dims()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment