Commit 7ea06e6e by Siju Committed by Tianqi Chen

[ONNX]onnx gather bug fix (#1543)

parent 60da4705
......@@ -489,15 +489,11 @@ class Slice(OnnxOpConverter):
class Gather(OnnxOpConverter):
""" Operator converter for Gather.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr['axis']
indices = np.array(attr['indices'], dtype='int32')
name = 'gather_indices'
gather_indices = _sym.Variable(name=name, init=indices)
params[name] = indices
return _sym.take(inputs[0], gather_indices, axis=axis)
axis = attr.get('axis', 0)
return AttrCvt(op_name='take',
extras={'axis':axis})(inputs, attr)
class LRN(OnnxOpConverter):
""" Operator converter for Local Response Normalization.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment