Commit 3f472f94 by Trevor Morris Committed by Tianqi Chen

[Relay][Frontend][Tensorflow] Fix GatherV2, Add StopGradient (#4238)

* Add StopGradient. Add batch_dims attr to ignore list for GatherV2

* Trigger CI
parent 996cf30e
...@@ -872,11 +872,14 @@ def _gather(): ...@@ -872,11 +872,14 @@ def _gather():
axis = _get_num_param(params, inputs.pop(2)) axis = _get_num_param(params, inputs.pop(2))
else: else:
axis = 0 axis = 0
if int(attr.get('batch_dims', 0)) != 0:
raise tvm.error.OpAttributeUnImplemented(
'Attribute batch_dims is not supported')
new_input = inputs[0:2] new_input = inputs[0:2]
return AttrCvt(op_name="take", return AttrCvt(op_name="take",
extras={'axis': tvm.const(axis, 'int32')}, extras={'axis': tvm.const(axis, 'int32')},
ignores=['Tindices', 'Tparams', 'validate_indices', ignores=['Tindices', 'Tparams', 'validate_indices',
'Taxis', '_class'])(new_input, attr) 'Taxis', '_class', 'batch_dims'])(new_input, attr)
return _impl return _impl
def _gather_nd(): def _gather_nd():
...@@ -1472,6 +1475,7 @@ _convert_map = { ...@@ -1472,6 +1475,7 @@ _convert_map = {
'Square' : _square(), 'Square' : _square(),
'SquaredDifference' : _squared_difference(), 'SquaredDifference' : _squared_difference(),
'Squeeze' : _squeeze(), 'Squeeze' : _squeeze(),
'StopGradient' : _identity(),
'StridedSlice' : _stridedSlice(), 'StridedSlice' : _stridedSlice(),
'Sub' : _elemwise('subtract'), 'Sub' : _elemwise('subtract'),
'Sum' : _sum(), 'Sum' : _sum(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment