From 1c87e009d416efc0b82866658ef9236a8d7f8018 Mon Sep 17 00:00:00 2001 From: Jared Roesch <roeschinc@gmail.com> Date: Sun, 28 Oct 2018 20:54:47 -0700 Subject: [PATCH] Do not mutate GlobalVar's checked_type field. (#2026) --- src/relay/pass/type_infer.cc | 2 +- tests/python/relay/test_type_infer.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index c1f6cdc63..c0f1db97b 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -366,7 +366,7 @@ class TypeInferencer::Resolver : public ExprMutator { } Expr VisitExpr_(const GlobalVarNode* op) final { - return AttachCheckedType(op); + return GetRef<GlobalVar>(op); } Expr VisitExpr_(const OpNode* op) final { diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 8f92fc0f5..b18230040 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -123,6 +123,16 @@ def test_self_reference(): assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a) assert relay.ir_pass.infer_type(fx).checked_type == a +def test_global_var_cow_issue(): + env = relay.env.Environment({}) + gv = relay.GlobalVar("foo") + x = relay.var('x', shape=[]) + func = relay.Function([x], relay.Call(gv, [x]), relay.TensorType([], 'float32')) + env[gv] = func + # They should both point to the same global variable if global variables are + # stable across type checking. + assert gv == func.body.op + if __name__ == "__main__": test_free_expr() test_dual_op() @@ -134,3 +144,4 @@ if __name__ == "__main__": test_free_expr() test_type_args() test_self_reference() + test_global_var_cow_issue() \ No newline at end of file -- GitLab