From 1c87e009d416efc0b82866658ef9236a8d7f8018 Mon Sep 17 00:00:00 2001
From: Jared Roesch <roeschinc@gmail.com>
Date: Sun, 28 Oct 2018 20:54:47 -0700
Subject: [PATCH] Do not mutate GlobalVar's checked_type field. (#2026)

---
 src/relay/pass/type_infer.cc          |  2 +-
 tests/python/relay/test_type_infer.py | 11 +++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index c1f6cdc63..c0f1db97b 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -366,7 +366,7 @@ class TypeInferencer::Resolver : public ExprMutator {
   }
 
   Expr VisitExpr_(const GlobalVarNode* op) final {
-    return AttachCheckedType(op);
+    return GetRef<GlobalVar>(op);
   }
 
   Expr VisitExpr_(const OpNode* op) final {
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index 8f92fc0f5..b18230040 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -123,6 +123,16 @@ def test_self_reference():
     assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a)
     assert relay.ir_pass.infer_type(fx).checked_type == a
 
+def test_global_var_cow_issue():
+    env = relay.env.Environment({})
+    gv = relay.GlobalVar("foo")
+    x = relay.var('x', shape=[])
+    func = relay.Function([x], relay.Call(gv, [x]), relay.TensorType([], 'float32'))
+    env[gv] = func
+    # They should both point to the same global variable if global variables are
+    # stable across type checking.
+    assert gv == func.body.op
+
 if __name__ == "__main__":
     test_free_expr()
     test_dual_op()
@@ -134,3 +144,4 @@ if __name__ == "__main__":
     test_free_expr()
     test_type_args()
     test_self_reference()
+    test_global_var_cow_issue()
\ No newline at end of file
-- 
GitLab