Commit 1c87e009 by Jared Roesch Committed by Tianqi Chen

Do not mutate GlobalVar's checked_type field. (#2026)

parent 32f158e6
...@@ -366,7 +366,7 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -366,7 +366,7 @@ class TypeInferencer::Resolver : public ExprMutator {
} }
Expr VisitExpr_(const GlobalVarNode* op) final { Expr VisitExpr_(const GlobalVarNode* op) final {
return AttachCheckedType(op); return GetRef<GlobalVar>(op);
} }
Expr VisitExpr_(const OpNode* op) final { Expr VisitExpr_(const OpNode* op) final {
......
...@@ -123,6 +123,16 @@ def test_self_reference(): ...@@ -123,6 +123,16 @@ def test_self_reference():
assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a) assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a)
assert relay.ir_pass.infer_type(fx).checked_type == a assert relay.ir_pass.infer_type(fx).checked_type == a
def test_global_var_cow_issue():
env = relay.env.Environment({})
gv = relay.GlobalVar("foo")
x = relay.var('x', shape=[])
func = relay.Function([x], relay.Call(gv, [x]), relay.TensorType([], 'float32'))
env[gv] = func
# They should both point to the same global variable if global variables are
# stable across type checking.
assert gv == func.body.op
if __name__ == "__main__": if __name__ == "__main__":
test_free_expr() test_free_expr()
test_dual_op() test_dual_op()
...@@ -134,3 +144,4 @@ if __name__ == "__main__": ...@@ -134,3 +144,4 @@ if __name__ == "__main__":
test_free_expr() test_free_expr()
test_type_args() test_type_args()
test_self_reference() test_self_reference()
test_global_var_cow_issue()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment