Unverified Commit fccf2268 by hcyang Committed by GitHub

[RELAY][FRONTEND][TF] Fix FuseBatchNorm output cast error if need_cast is True (#4894)

parent 406b5f76
...@@ -897,6 +897,7 @@ def _fused_batch_norm(): ...@@ -897,6 +897,7 @@ def _fused_batch_norm():
disables=['momentum'])(inputs, attr) disables=['momentum'])(inputs, attr)
if need_cast: if need_cast:
out = _expr.TupleGetItem(out.astuple(), 0)
out = _op.cast(out, dtype=attr['T'].name) out = _op.cast(out, dtype=attr['T'].name)
return out return out
return _impl return _impl
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment