Commit 3b328b90 by Hao Jin Committed by Yizhi Liu

add converter for MXNet slice in nnvm and relay (#2662)

parent 3e765edc
...@@ -189,6 +189,19 @@ def _reshape(inputs, attrs): ...@@ -189,6 +189,19 @@ def _reshape(inputs, attrs):
new_attrs['shape'] = _required_attr(attrs, 'shape') new_attrs['shape'] = _required_attr(attrs, 'shape')
return _get_nnvm_op(op_name)(*inputs, **new_attrs) return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _slice(inputs, attrs):
begin = attrs.get('begin', None)
end = attrs.get('end', None)
stride = attrs.get('step', None)
if begin is None or end is None:
raise RuntimeError('begin and end are required params')
if 'None' in begin or 'None' in end:
raise RuntimeError('None in begin or end not supported yet...')
new_attrs = {'begin': begin, 'end': end}
if stride is not None:
new_attrs['stride'] = stride
return _get_nnvm_op('strided_slice')(inputs[0], **new_attrs)
def _split(inputs, attrs): def _split(inputs, attrs):
op_name, new_attrs = 'split', {} op_name, new_attrs = 'split', {}
axis = attrs.get('axis', 1) axis = attrs.get('axis', 1)
...@@ -337,6 +350,7 @@ _convert_map = { ...@@ -337,6 +350,7 @@ _convert_map = {
'Pooling' : _pooling, 'Pooling' : _pooling,
'Pooling_v1' : _pooling, 'Pooling_v1' : _pooling,
'Reshape' : _reshape, 'Reshape' : _reshape,
'slice' : _slice,
'SliceChannel' : _split, 'SliceChannel' : _split,
'split' : _split, 'split' : _split,
'Softmax' : _rename('softmax'), 'Softmax' : _rename('softmax'),
......
...@@ -220,6 +220,14 @@ def test_forward_where(): ...@@ -220,6 +220,14 @@ def test_forward_where():
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy() tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_slice():
data = mx.sym.var('data')
mx_sym = mx.sym.slice(data, begin=(0, 1), end=(2, 4))
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 3))
mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2))
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2))
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
test_forward_vgg() test_forward_vgg()
...@@ -242,4 +250,5 @@ if __name__ == '__main__': ...@@ -242,4 +250,5 @@ if __name__ == '__main__':
test_forward_argmax() test_forward_argmax()
test_forward_argmin() test_forward_argmin()
test_forward_where() test_forward_where()
test_forward_slice()
...@@ -172,6 +172,21 @@ def _mx_batch_norm(inputs, attrs): ...@@ -172,6 +172,21 @@ def _mx_batch_norm(inputs, attrs):
return _op.nn.batch_norm(*inputs, **new_attrs) return _op.nn.batch_norm(*inputs, **new_attrs)
def _mx_slice(inputs, attrs):
new_attrs = {}
begin = attrs.get_int_tuple('begin', None)
end = attrs.get_int_tuple('end', None)
stride = attrs.get_int_tuple('step', None)
if begin is None or end is None:
raise RuntimeError("begin and end are required parameters.")
if None in begin or None in end:
raise RuntimeError("None in begin or end is not supported yet.")
new_attrs = {'begin': begin, 'end': end}
if stride is not None:
new_attrs['strides'] = stride
return _op.strided_slice(inputs[0], **new_attrs)
def _mx_split(inputs, attrs): def _mx_split(inputs, attrs):
axis = attrs.get_int("axis", 1) axis = attrs.get_int("axis", 1)
new_attrs = {} new_attrs = {}
...@@ -368,6 +383,7 @@ _convert_map = { ...@@ -368,6 +383,7 @@ _convert_map = {
"BatchNorm" : _mx_batch_norm, "BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm, "BatchNorm_v1" : _mx_batch_norm,
"LRN" : _mx_lrn, "LRN" : _mx_lrn,
"slice" : _mx_slice,
"SliceChannel" : _mx_split, "SliceChannel" : _mx_split,
"split" : _mx_split, "split" : _mx_split,
"expand_dims" : _mx_expand_dims, "expand_dims" : _mx_expand_dims,
......
...@@ -190,6 +190,13 @@ def test_forward_argmin(): ...@@ -190,6 +190,13 @@ def test_forward_argmin():
mx_sym = mx.sym.argmin(data, axis=0) mx_sym = mx.sym.argmin(data, axis=0)
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,)) verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
def test_forward_slice():
data = mx.sym.var('data')
mx_sym = mx.sym.slice(data, begin=(0, 1), end=(2, 4))
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 3))
mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2))
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2))
def test_forward_where(): def test_forward_where():
cond = mx.sym.var('cond') cond = mx.sym.var('cond')
x = mx.sym.var('x') x = mx.sym.var('x')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment