Commit 3f472f94 by Trevor Morris Committed by Tianqi Chen

[Relay][Frontend][Tensorflow] Fix GatherV2, Add StopGradient (#4238)

* Add StopGradient. Add batch_dims attr to ignore list for GatherV2

* Trigger CI
parent 996cf30e
......@@ -872,11 +872,14 @@ def _gather():
axis = _get_num_param(params, inputs.pop(2))
else:
axis = 0
if int(attr.get('batch_dims', 0)) != 0:
raise tvm.error.OpAttributeUnImplemented(
'Attribute batch_dims is not supported')
new_input = inputs[0:2]
return AttrCvt(op_name="take",
extras={'axis': tvm.const(axis, 'int32')},
ignores=['Tindices', 'Tparams', 'validate_indices',
'Taxis', '_class'])(new_input, attr)
'Taxis', '_class', 'batch_dims'])(new_input, attr)
return _impl
def _gather_nd():
......@@ -1472,6 +1475,7 @@ _convert_map = {
'Square' : _square(),
'SquaredDifference' : _squared_difference(),
'Squeeze' : _squeeze(),
'StopGradient' : _identity(),
'StridedSlice' : _stridedSlice(),
'Sub' : _elemwise('subtract'),
'Sum' : _sum(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment