Commit 7ea06e6e by Siju Committed by Tianqi Chen

[ONNX]onnx gather bug fix (#1543)

parent 60da4705
...@@ -489,15 +489,11 @@ class Slice(OnnxOpConverter): ...@@ -489,15 +489,11 @@ class Slice(OnnxOpConverter):
class Gather(OnnxOpConverter): class Gather(OnnxOpConverter):
""" Operator converter for Gather. """ Operator converter for Gather.
""" """
@classmethod @classmethod
def _impl_v1(cls, inputs, attr, params): def _impl_v1(cls, inputs, attr, params):
axis = attr['axis'] axis = attr.get('axis', 0)
indices = np.array(attr['indices'], dtype='int32') return AttrCvt(op_name='take',
name = 'gather_indices' extras={'axis':axis})(inputs, attr)
gather_indices = _sym.Variable(name=name, init=indices)
params[name] = indices
return _sym.take(inputs[0], gather_indices, axis=axis)
class LRN(OnnxOpConverter): class LRN(OnnxOpConverter):
""" Operator converter for Local Response Normalization. """ Operator converter for Local Response Normalization.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment