Commit aa49e851 by Wuwei Lin Committed by Zhi

[Relay][Pass] Avoid FoldConstant folding some ops (#4245)

* [Relay][Pass] Avoid FoldConstant folding some ops

* rename
parent cd717dea
......@@ -102,6 +102,9 @@ class ConstantFolder : public ExprMutator {
Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};
auto origin_args = call->args;
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
......@@ -111,6 +114,9 @@ class ConstantFolder : public ExprMutator {
if (call->args.size() == 0) return res;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
if (skip_list.count(op->name)) {
return res;
}
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
// Try to evaluate shape_of op
......
......@@ -146,9 +146,25 @@ def test_fold_shape_of():
assert relay.analysis.graph_equal(zz, zexpected)
def test_fold_full():
c_shape = (8, 9, 10)
def before():
dtype = 'float32'
return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype)
def expected():
# expect no changes
return before()
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.graph_equal(zz, zexpected)
if __name__ == "__main__":
test_fold_const()
test_fold_let()
test_fold_tuple()
test_fold_concat()
test_fold_shape_of()
test_fold_full()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment