Commit ba4cc7ba by Siju Committed by Tianqi Chen

Onnx Gather operator added (#1513)

parent d6f3bf16
......@@ -446,7 +446,6 @@ class Unsqueeze(OnnxOpConverter):
inputs[0] = _sym.expand_dims(inputs[0], axis=axes, num_newaxis=1)
return inputs[0]
class Slice(OnnxOpConverter):
""" Operator converter for Slice.
"""
......@@ -487,6 +486,19 @@ class Slice(OnnxOpConverter):
'ends': 'end'},
ignores=['axes'])(inputs, attr)
class Gather(OnnxOpConverter):
""" Operator converter for Gather.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr['axis']
indices = np.array(attr['indices'], dtype='int32')
name = 'gather_indices'
gather_indices = _sym.Variable(name=name, init=indices)
params[name] = indices
return _sym.take(inputs[0], gather_indices, axis=axis)
# compatible operators that do NOT require any conversion.
_identity_list = []
......@@ -593,7 +605,7 @@ def _get_convert_map(opset):
'Split': AttrCvt('split', {'split': 'indices_or_sections'}),
'Slice': Slice.get_converter(opset),
'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
# 'Gather'
'Gather': Gather.get_converter(opset),
'Squeeze': Renamer('squeeze'),
'Unsqueeze': Unsqueeze.get_converter(opset),
'Pad': Pad.get_converter(opset),
......
......@@ -189,6 +189,34 @@ def test_unsqueeze():
np.testing.assert_allclose(out_shape, tvm_out.shape)
def verify_gather(in_shape, indices, axis=0):
indices_src = np.array(indices, dtype="int32")
x = np.random.uniform(size=in_shape)
out_np = np.take(x, indices_src, axis=axis)
y = helper.make_node("Gather", ['in'], ['out'], indices=indices, axis=axis)
graph = helper.make_graph([y],
'gather_test',
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_np.shape))])
model = helper.make_model(graph, producer_name='gather_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, x, target, ctx, out_np.shape, 'float32')
np.testing.assert_allclose(out_np, tvm_out)
def test_gather():
verify_gather((4,), [1])
verify_gather((4,), [0, 1, 2, 3])
verify_gather((4, 2), [1], 1)
verify_gather((4, 3, 5, 6), [2, 1, 0, 0], -2)
def _test_slice_iteration(indata, outdata, starts, ends, axes=None):
if axes:
y = helper.make_node("Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends)
......@@ -299,3 +327,4 @@ if __name__ == '__main__':
test_ceil()
test_clip()
test_matmul()
test_gather()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment