Commit 80e4bc02 by Tatsuya Nishiyama Committed by Tianqi Chen

[FRONTEND][MXNET] Add squeeze_axis support to split operator (#1288)

parent d3441616
...@@ -188,12 +188,15 @@ def _reshape(inputs, attrs): ...@@ -188,12 +188,15 @@ def _reshape(inputs, attrs):
return _get_nnvm_op(op_name)(*inputs, **new_attrs) return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _split(inputs, attrs): def _split(inputs, attrs):
if _parse_bool_str(attrs, 'squeeze_axis'):
_raise_not_supported('squeeze_axis', 'split')
op_name, new_attrs = 'split', {} op_name, new_attrs = 'split', {}
axis = attrs.get('axis', 1)
new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs') new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs')
new_attrs['axis'] = attrs.get('axis', 1) new_attrs['axis'] = axis
return _get_nnvm_op(op_name)(*inputs, **new_attrs) outputs = _get_nnvm_op(op_name)(*inputs, **new_attrs)
if _parse_bool_str(attrs, 'squeeze_axis'):
squeeze_attrs = {'axis': axis}
outputs = _sym.Group([_get_nnvm_op('squeeze')(o, **squeeze_attrs) for o in outputs])
return outputs
def _softmax_activation(inputs, attrs): def _softmax_activation(inputs, attrs):
op_name, new_attrs = 'softmax', {} op_name, new_attrs = 'softmax', {}
......
...@@ -126,6 +126,16 @@ def test_forward_clip(): ...@@ -126,6 +126,16 @@ def test_forward_clip():
mx_sym = mx.sym.clip(data, a_min=0, a_max=1) mx_sym = mx.sym.clip(data, a_min=0, a_max=1)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_split():
data = mx.sym.var('data')
mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=False)
verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 1, 2, 1))
def test_forward_split_squeeze():
data = mx.sym.var('data')
mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True)
verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1))
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
test_forward_vgg() test_forward_vgg()
...@@ -136,3 +146,5 @@ if __name__ == '__main__': ...@@ -136,3 +146,5 @@ if __name__ == '__main__':
test_forward_softrelu() test_forward_softrelu()
test_forward_fc_flatten() test_forward_fc_flatten()
test_forward_clip() test_forward_clip()
test_forward_split()
test_forward_split_squeeze()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment