Commit c4f03de3 by Hao Jin Committed by Tianqi Chen

add MXNet converter for where operator for both NNVM and Relay (#2647)

parent e20ef0d4
......@@ -305,7 +305,7 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
'ones_like', 'relu', 'sigmoid', 'slice_like', 'softmax',
'sum', 'tanh', 'transpose', 'zeros_like', 'gather_nd',
'reshape_like']
'reshape_like', 'where']
_convert_map = {
'_copy' : _rename('copy'),
......
......@@ -158,7 +158,7 @@ def test_forward_ones():
ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.elemwise_add(data, ones)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_zeros():
data = mx.sym.var('data')
zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32')
......@@ -184,7 +184,42 @@ def test_forward_argmin():
data = mx.sym.var('data')
mx_sym = mx.sym.argmin(data, axis=0)
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
def test_forward_where():
cond = mx.sym.var('cond')
x = mx.sym.var('x')
y = mx.sym.var('y')
dshape = (2, 2)
dtype = 'float32'
mx_sym = mx.sym.where(cond, x, y)
np_cond = np.array([[0, 1], [-1, 0]]).astype(dtype)
np_x = np.random.uniform(size=dshape).astype(dtype)
np_y = np.random.uniform(size=dshape).astype(dtype)
mx_cond = mx.nd.array(np_cond)
mx_x = mx.nd.array(np_x)
mx_y = mx.nd.array(np_y)
mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
mod.bind(data_shapes=[('cond', dshape), ('x', dshape), ('y', dshape)], for_training=False)
mod.init_params()
args, auxs = mod.get_params()
mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy()
out_shape = dshape
new_sym, params = frontend.from_mxnet(mx_sym, args, auxs)
shape_dict = {'cond': dshape, 'x': dshape, 'y': dshape}
for target, ctx in ctx_list():
with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input("cond", tvm.nd.array(np_cond))
m.set_input("x", tvm.nd.array(np_x))
m.set_input("y", tvm.nd.array(np_y))
m.set_input(**params)
m.run()
# get outputs
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
......@@ -206,4 +241,5 @@ if __name__ == '__main__':
test_forward_zeros_like()
test_forward_argmax()
test_forward_argmin()
test_forward_where()
......@@ -290,6 +290,7 @@ _identity_list = [
"slice_like",
"zeros_like",
"ones_like",
"where",
]
_convert_map = {
......
......@@ -190,6 +190,44 @@ def test_forward_argmin():
mx_sym = mx.sym.argmin(data, axis=0)
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
def test_forward_where():
cond = mx.sym.var('cond')
x = mx.sym.var('x')
y = mx.sym.var('y')
dshape = (2, 2)
dtype = 'float32'
mx_sym = mx.sym.where(cond, x, y)
np_cond = np.array([[0, 1], [-1, 0]]).astype(dtype)
np_x = np.random.uniform(size=dshape).astype(dtype)
np_y = np.random.uniform(size=dshape).astype(dtype)
mx_cond = mx.nd.array(np_cond)
mx_x = mx.nd.array(np_x)
mx_y = mx.nd.array(np_y)
mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
mod.bind(data_shapes=[('cond', dshape), ('x', dshape), ('y', dshape)], for_training=False)
mod.init_params()
args, auxs = mod.get_params()
mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy()
out_shape = dshape
shape_dict = {'cond': dshape, 'x': dshape, 'y': dshape}
new_sym, params = relay.frontend.from_mxnet(mx_sym,
shape_dict,
arg_params=args,
aux_params=auxs)
for target, ctx in ctx_list():
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(new_sym, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input("cond", tvm.nd.array(np_cond))
m.set_input("x", tvm.nd.array(np_x))
m.set_input("y", tvm.nd.array(np_y))
m.set_input(**params)
m.run()
# get outputs
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
if __name__ == '__main__':
test_forward_mlp()
......@@ -212,3 +250,4 @@ if __name__ == '__main__':
test_forward_zeros_like()
test_forward_argmax()
test_forward_argmin()
test_forward_where()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment