Commit 3b328b90 by Hao Jin Committed by Yizhi Liu

add converter for MXNet slice in nnvm and relay (#2662)

parent 3e765edc
......@@ -189,6 +189,19 @@ def _reshape(inputs, attrs):
new_attrs['shape'] = _required_attr(attrs, 'shape')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _slice(inputs, attrs):
begin = attrs.get('begin', None)
end = attrs.get('end', None)
stride = attrs.get('step', None)
if begin is None or end is None:
raise RuntimeError('begin and end are required params')
if 'None' in begin or 'None' in end:
raise RuntimeError('None in begin or end not supported yet...')
new_attrs = {'begin': begin, 'end': end}
if stride is not None:
new_attrs['stride'] = stride
return _get_nnvm_op('strided_slice')(inputs[0], **new_attrs)
def _split(inputs, attrs):
op_name, new_attrs = 'split', {}
axis = attrs.get('axis', 1)
......@@ -337,6 +350,7 @@ _convert_map = {
'Pooling' : _pooling,
'Pooling_v1' : _pooling,
'Reshape' : _reshape,
'slice' : _slice,
'SliceChannel' : _split,
'split' : _split,
'Softmax' : _rename('softmax'),
......
......@@ -220,6 +220,14 @@ def test_forward_where():
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_slice():
data = mx.sym.var('data')
mx_sym = mx.sym.slice(data, begin=(0, 1), end=(2, 4))
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 3))
mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2))
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2))
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
......@@ -242,4 +250,5 @@ if __name__ == '__main__':
test_forward_argmax()
test_forward_argmin()
test_forward_where()
test_forward_slice()
......@@ -172,6 +172,21 @@ def _mx_batch_norm(inputs, attrs):
return _op.nn.batch_norm(*inputs, **new_attrs)
def _mx_slice(inputs, attrs):
new_attrs = {}
begin = attrs.get_int_tuple('begin', None)
end = attrs.get_int_tuple('end', None)
stride = attrs.get_int_tuple('step', None)
if begin is None or end is None:
raise RuntimeError("begin and end are required parameters.")
if None in begin or None in end:
raise RuntimeError("None in begin or end is not supported yet.")
new_attrs = {'begin': begin, 'end': end}
if stride is not None:
new_attrs['strides'] = stride
return _op.strided_slice(inputs[0], **new_attrs)
def _mx_split(inputs, attrs):
axis = attrs.get_int("axis", 1)
new_attrs = {}
......@@ -368,6 +383,7 @@ _convert_map = {
"BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm,
"LRN" : _mx_lrn,
"slice" : _mx_slice,
"SliceChannel" : _mx_split,
"split" : _mx_split,
"expand_dims" : _mx_expand_dims,
......
......@@ -190,6 +190,13 @@ def test_forward_argmin():
mx_sym = mx.sym.argmin(data, axis=0)
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
def test_forward_slice():
data = mx.sym.var('data')
mx_sym = mx.sym.slice(data, begin=(0, 1), end=(2, 4))
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 3))
mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2))
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2))
def test_forward_where():
cond = mx.sym.var('cond')
x = mx.sym.var('x')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment