Commit 329378cf by 雾雨魔理沙 Committed by Wuwei Lin

[Relay] Fix ad for conditional expression (#3453)

* save

* fix
parent 8fe715fe
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -311,6 +311,12 @@ struct ReverseAD : ExprMutator {
return Pair(e, RefCreateNode::make(ZerosLike(e)));
}
Expr VisitExpr_(const IfNode* op) final {
return IfNode::make(TupleGetItemNode::make(VisitExpr(op->cond), 0),
VisitExpr(op->true_branch),
VisitExpr(op->false_branch));
}
Type VisitType(const Type& t) final {
return t.defined() ? ReverseADType()(t) : t;
}
......
......@@ -231,6 +231,16 @@ def test_square_second_order():
tvm.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))
def test_if():
x = relay.var("x", shape=(1, 16, 64, 64))
y = relay.var("y", shape=(1, 16, 64, 64))
cond = relay.var("cond", shape=(), dtype='uint1')
net = relay.If(cond, x, y)
net = relay.log(net)
net = relay.ir_pass.infer_type(relay.Function(relay.ir_pass.free_vars(net), net))
back_func = relay.ir_pass.infer_type(relay.ir_pass.gradient(net, mode='higher_order'))
if __name__ == "__main__":
test_id()
test_add()
......@@ -242,3 +252,4 @@ if __name__ == "__main__":
test_pow()
test_ref()
test_square_second_order()
test_if()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment