Commit 7ea06e6e by Siju Committed by Tianqi Chen

[ONNX]onnx gather bug fix (#1543)

parent 60da4705
......@@ -489,15 +489,11 @@ class Slice(OnnxOpConverter):
class Gather(OnnxOpConverter):
""" Operator converter for Gather.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr['axis']
indices = np.array(attr['indices'], dtype='int32')
name = 'gather_indices'
gather_indices = _sym.Variable(name=name, init=indices)
params[name] = indices
return _sym.take(inputs[0], gather_indices, axis=axis)
axis = attr.get('axis', 0)
return AttrCvt(op_name='take',
extras={'axis':axis})(inputs, attr)
class LRN(OnnxOpConverter):
""" Operator converter for Local Response Normalization.
......
......@@ -8,21 +8,50 @@ import onnx
from model_zoo import super_resolution, squeezenet1_1, lenet, resnet18_1_0
from onnx import helper, TensorProto
def get_tvm_output(model, x, target, ctx, out_shape, dtype='float32'):
new_sym, params = nnvm.frontend.from_onnx(model)
input_name = model.graph.input[0].name
shape_dict = {input_name: x.shape}
dtype_dict = {input_name: dtype}
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, dtype_dict, params=params)
def get_tvm_output(graph_def, input_data, target, ctx, output_shape, output_dtype='float32'):
""" Generic function to execute and get tvm output"""
sym, params = nnvm.frontend.from_onnx(graph_def)
target = 'llvm'
if isinstance(input_data, list):
input_names = {}
shape_dict = {}
dtype_dict = {}
for i, _ in enumerate(input_data):
input_names[i] = graph_def.graph.input[i].name
shape_dict[input_names[i]] = input_data[i].shape
dtype_dict[input_names[i]] = input_data[i].dtype
else:
input_names = graph_def.graph.input[0].name
shape_dict = {input_names: input_data.shape}
dtype_dict = {input_names: input_data.dtype}
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
dtype=dtype_dict, params=params)
ctx = tvm.cpu(0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input(input_name, tvm.nd.array(x.astype(dtype)))
if isinstance(input_data, list):
for i, e in enumerate(input_names):
m.set_input(input_names[i], tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
else:
m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
return out.asnumpy()
if isinstance(output_shape, list) and isinstance(output_dtype, list):
tvm_output_list = []
for i, s in enumerate(output_shape):
tvm_output = m.get_output(i, tvm.nd.empty((s), output_dtype[i]))
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
else:
tvm_output = m.get_output(0, tvm.nd.empty((output_shape), output_dtype))
return tvm_output.asnumpy()
def get_caffe2_output(model, x, dtype='float32'):
import caffe2.python.onnx.backend
......@@ -70,13 +99,15 @@ def test_reshape():
graph = helper.make_graph([ref_node, reshape_node],
"reshape_test",
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))])
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(ref_shape))])
model = helper.make_model(graph, producer_name='reshape_test')
for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape)
x = np.random.uniform(size=in_shape).astype('int32')
tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
np.testing.assert_allclose(ref_shape, tvm_out.shape)
......@@ -98,13 +129,15 @@ def test_reshape_like():
graph = helper.make_graph([ref_node, copy_node, reshape_node],
"reshape_like_test",
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))])
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(ref_shape))])
model = helper.make_model(graph, producer_name='reshape_like_test')
for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape)
x = np.random.uniform(size=in_shape).astype('float32')
tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
np.testing.assert_allclose(ref_shape, tvm_out.shape)
......@@ -122,31 +155,18 @@ def _test_power_iteration(x_shape, y_shape):
graph = helper.make_graph([res],
'power_test',
inputs = [helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(np_res.shape))])
inputs = [helper.make_tensor_value_info("x",
TensorProto.FLOAT, list(x_shape)),
helper.make_tensor_value_info("y",
TensorProto.FLOAT, list(y_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(np_res.shape))])
model = helper.make_model(graph, producer_name='power_test')
for target, ctx in ctx_list():
new_sym, params = nnvm.frontend.from_onnx(model)
input_name = model.graph.input[0].name
input_name1 = model.graph.input[1].name
shape_dict = {input_name: x.shape, input_name1: y.shape}
dtype_dict = {input_name: x.dtype, input_name1: y.dtype}
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, dtype_dict, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input(input_name, tvm.nd.array(x))
m.set_input(input_name1, tvm.nd.array(y))
m.set_input(**params)
m.run()
# get outputs
tvm_out = m.get_output(0, tvm.nd.empty(np_res.shape, np_res.dtype))
np.testing.assert_allclose(np_res, tvm_out.asnumpy(), rtol=1e-5, atol=1e-5)
tvm_out = get_tvm_output(model, [x, y], target, ctx, np_res.shape)
np.testing.assert_allclose(np_res, tvm_out, rtol=1e-5, atol=1e-5)
def test_power():
_test_power_iteration((1, 3), (1))
......@@ -160,13 +180,15 @@ def test_squeeze():
graph = helper.make_graph([y],
'squeeze_test',
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='squeeze_test')
for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape)
x = np.random.uniform(size=in_shape).astype('float32')
tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32')
np.testing.assert_allclose(out_shape, tvm_out.shape)
......@@ -179,44 +201,47 @@ def test_unsqueeze():
graph = helper.make_graph([y],
'squeeze_test',
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='squeeze_test')
for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape)
x = np.random.uniform(size=in_shape).astype('float32')
tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32')
np.testing.assert_allclose(out_shape, tvm_out.shape)
def verify_gather(in_shape, indices, axis=0):
indices_src = np.array(indices, dtype="int32")
x = np.random.uniform(size=in_shape)
out_np = np.take(x, indices_src, axis=axis)
def verify_gather(in_shape, indices, axis, dtype):
x = np.random.uniform(size=in_shape).astype(dtype)
indices = np.array(indices, dtype="int32")
out_np = np.take(x, indices, axis=axis)
y = helper.make_node("Gather", ['in'], ['out'], indices=indices, axis=axis)
y = helper.make_node("Gather", ['in', 'indices'], ['out'], axis=axis)
graph = helper.make_graph([y],
'gather_test',
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("indices",
TensorProto.INT32, list(indices.shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_np.shape))])
model = helper.make_model(graph, producer_name='gather_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, x, target, ctx, out_np.shape, 'float32')
tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape)
np.testing.assert_allclose(out_np, tvm_out)
def test_gather():
verify_gather((4,), [1])
verify_gather((4,), [0, 1, 2, 3])
verify_gather((4, 2), [1], 1)
verify_gather((4, 3, 5, 6), [2, 1, 0, 0], -2)
verify_gather((4,), [1], 0, 'int32')
verify_gather((1,4), [0], 0, 'int32')
verify_gather((4,), [[[1,0],[0,1]]], 0, 'float32')
verify_gather((2,2), [[[1,0],[0,1]]], 1, 'int32')
verify_gather((3,3,3), [[[1,0]]], -1, 'int32')
verify_gather((4,3,5,6), [[2,1,0,0]], 0, 'float32')
def _test_slice_iteration(indata, outdata, starts, ends, axes=None):
if axes:
......@@ -226,8 +251,10 @@ def _test_slice_iteration(indata, outdata, starts, ends, axes=None):
graph = helper.make_graph([y],
'slice_test',
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))])
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='slice_test')
......@@ -251,8 +278,10 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
graph = helper.make_graph([y],
opname+'_test',
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))])
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name=opname+'_test')
......@@ -278,40 +307,27 @@ def test_clip():
def test_matmul():
a_shape = (4, 3)
b_shape = (3, 4)
out_shape = (4, 4)
a_array = np.random.uniform(size=a_shape).astype('float32')
b_array = np.random.uniform(size=b_shape).astype('float32')
out_np = np.matmul(a_array, b_array)
mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
graph = helper.make_graph([mul_node],
"matmul_test",
inputs = [helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
inputs = [helper.make_tensor_value_info("a",
TensorProto.FLOAT, list(a_shape)),
helper.make_tensor_value_info("b",
TensorProto.FLOAT, list(b_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_np.shape))])
model = helper.make_model(graph, producer_name='matmul_test')
for target, ctx in ctx_list():
new_sym, params = nnvm.frontend.from_onnx(model)
input_name = model.graph.input[0].name
input_name1 = model.graph.input[1].name
shape_dict = {input_name: a_array.shape, input_name1: b_array.shape}
dtype_dict = {input_name: 'float32', input_name1: 'float32'}
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, dtype_dict, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input(input_name, tvm.nd.array(a_array.astype('float32')))
m.set_input(input_name1, tvm.nd.array(b_array.astype('float32')))
m.set_input(**params)
m.run()
# get outputs
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, 'float32'))
np.testing.assert_allclose(np.matmul(a_array, b_array), tvm_out.asnumpy(), rtol=1e-5, atol=1e-5)
tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, out_np.shape)
np.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
in_array = np.random.uniform(size=shape).astype(dtype)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment