diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index c72612791b5211208d74c8989ea9eea5698fde77..887d28b0fa9f8e7f6b11e9587819cb3a2abd37c3 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -458,12 +458,14 @@ inline const TTypeNode* ExprNode::type_as() const {
 /*!
  * \brief Print node as text format.
  * \param node The node to be printed.
+ * \param show_meta_data Whether to print meta data section.
  * \param annotate An optional callback function for attaching
  *        additional comment block to an expr.
  * \return The text representation.
  */
 std::string RelayPrint(
     const NodeRef& node,
+    bool show_meta_data = true,
     runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
 }  // namespace relay
 }  // namespace tvm
diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py
index b5454031cb4ab95b50629d6202622cb89ce48d6a..a51cc8072aaca36dd885582b2bae94a05ba6faac 100644
--- a/python/tvm/relay/backend/_backend.py
+++ b/python/tvm/relay/backend/_backend.py
@@ -55,6 +55,7 @@ def build(funcs, target, target_host=None):
     funcs : List[tvm.LoweredFunc]
          The list of lowered functions.
 
+
     target : tvm.Target
          The target to run the code on.
 
diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py
index 4bbab957ab1d1f1d41082997d64d21bb039e086d..50568b58607b8e80cebfd6549e5b05cb9858ed55 100644
--- a/python/tvm/relay/backend/graph_runtime_codegen.py
+++ b/python/tvm/relay/backend/graph_runtime_codegen.py
@@ -21,6 +21,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system.
 from __future__ import absolute_import
 import json
 import attr
+from . import _backend
 from . import compile_engine
 from ..op import Op
 from ..expr import Function, GlobalVar, ExprFunctor
@@ -103,11 +104,12 @@ class GraphRuntimeCodegen(ExprFunctor):
         self.nodes = []
         self.var_map = {}
         self.params = {}
+        self.storage_map = None
         self.compile_engine = compile_engine.get()
         self.lowered_funcs = set()
         self._name_map = {}
 
-    def add_node(self, node, checked_type):
+    def add_node(self, node, expr):
         """
         Add a node to the graph.
 
@@ -116,14 +118,21 @@ class GraphRuntimeCodegen(ExprFunctor):
         node: Node
             The node to add to the graph.
 
-        checked_type: Type
-            The type of the node.
+        expr: tvm.relay.Expr
+            The corresponding expression.
 
         Returns
         -------
         node_ref: Union[NodeRef, List[NodeRef]]
             A reference to the node.
         """
+        checked_type = expr.checked_type
+        # setup storage ids
+        assert expr in self.storage_map
+        node.attrs["storage_id"] = [
+            x.value for x in self.storage_map[expr]
+        ]
+
         node_id = len(self.nodes)
         self.nodes.append(node)
         # Tuple return value, flatten as tuple
@@ -168,7 +177,7 @@ class GraphRuntimeCodegen(ExprFunctor):
         name = "p%d" % index
         self.params[name] = op.data
         node = InputNode(name, {})
-        return self.add_node(node, op.checked_type)
+        return self.add_node(node, op)
 
     def visit_function(self, _):
         raise RuntimeError("function not supported")
@@ -244,7 +253,7 @@ class GraphRuntimeCodegen(ExprFunctor):
         op_name = cached_func.func_name
         op_node = OpNode(self._get_unique_name(op_name), {},
                          op_name, inputs, {})
-        return self.add_node(op_node, call.checked_type)
+        return self.add_node(op_node, call)
 
     def _get_json(self):
         """
@@ -281,8 +290,7 @@ class GraphRuntimeCodegen(ExprFunctor):
             assert node.num_outputs == len(node.attrs["shape"])
             shapes += node.attrs["shape"]
             dltypes += node.attrs["dtype"]
-            for i in range(node.num_outputs):
-                storage_ids.append(i + num_entry)
+            storage_ids += node.attrs["storage_id"]
             num_entry += node.num_outputs
             node_row_ptr.append(num_entry)
 
@@ -302,6 +310,14 @@ class GraphRuntimeCodegen(ExprFunctor):
 
         return json.dumps(json_dict, indent=2)
 
+    def debug_dump_memory_plan(self, func):
+        """Debug function to dump memory plan."""
+        def _annotate(expr):
+            if expr in self.storage_map:
+                return str(self.storage_map[expr])
+            return ""
+        return func.astext(show_meta_data=False, annotate=_annotate)
+
     def codegen(self, func):
         """Compile a single function into a graph.
 
@@ -321,11 +337,12 @@ class GraphRuntimeCodegen(ExprFunctor):
         params : Dict[str, tvm.nd.NDArray]
             Additional constant parameters.
         """
+        self.storage_map = _backend.GraphPlanMemory(func)
         # First we convert all the parameters into input nodes.
         for param in func.params:
             node = InputNode(param.name_hint, {})
             self.var_map[param] = self.add_node(
-                node, param.type_annotation)
+                node, param)
 
         # Then we compile the body into a graph which can depend
         # on input variables.
diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py
index 012315b40f5104b3809e079ba4db42698384c878..0feffeb809c548aa4bb7ee7d5aa308cb3c1a755f 100644
--- a/python/tvm/relay/base.py
+++ b/python/tvm/relay/base.py
@@ -23,7 +23,7 @@ def register_relay_node(type_key=None):
 
 class RelayNode(NodeBase):
     """Base class of all relay node."""
-    def astext(self, annotate=None):
+    def astext(self, show_meta_data=True, annotate=None):
         """Get the text format of the expression.
 
         Returns
@@ -31,11 +31,21 @@ class RelayNode(NodeBase):
         text : str
             The text format of the expression.
 
+        show_meta_data : bool
+            Whether to include meta data section in the text
+            if there is meta data.
+
         annotate: Optional[relay.Expr->str]
             Optional annotate function to provide additional
             information in the comment block.
+
+        Note
+        ----
+        meta data section is necessary to fully parse the text format.
+        However, it can contain dumps that are big(constat weights),
+        so it can be helpful to skip printing the meta data section.
         """
-        return _expr.RelayPrint(self, annotate)
+        return _expr.RelayPrint(self, show_meta_data, annotate)
 
 
 @register_relay_node
diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f3c3e2935d224b866da51e0263e60d00f3d5d3ed
--- /dev/null
+++ b/src/relay/backend/graph_plan_memory.cc
@@ -0,0 +1,349 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file relay/backend/graph_mem_alloca.cc
+ * \brief Memory index assignment pass for executing
+ *   the program in the graph runtime.
+ */
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include "../../common/arena.h"
+
+namespace tvm {
+namespace relay {
+
+struct StorageToken {
+  /*! \brief Reference counter */
+  int ref_counter{0};
+  /*! \brief number of bytes */
+  size_t max_bytes{0};
+  /*! \brief The corresponding tensor type node. */
+  const TensorTypeNode* ttype{nullptr};
+  /*! \brief virtual device index */
+  int device_id{0};
+  /*! \brief The storage id */
+  int64_t storage_id{-1};
+};
+
+class StorageAllocaBaseVisitor : public ExprVisitor {
+ public:
+  // run the visitor on a function.
+  void Run(const Function& func) {
+    for (Var param : func->params) {
+      CreateToken(param.operator->(), false);
+    }
+    this->VisitExpr(func->body);
+  }
+
+  void VisitExpr_(const ConstantNode* op) final {
+    this->CreateToken(op, false);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    // Do nothing.
+  }
+
+  void VisitExpr_(const FunctionNode* op) final {
+    // do not recursive into sub function.
+  }
+
+  void VisitExpr_(const GlobalVarNode* op) final {
+    // Do nothing.
+  }
+
+  void VisitExpr_(const OpNode* op) final {
+    // Do nothing.
+  }
+
+  void VisitExpr_(const TupleNode* op) final {
+    std::vector<StorageToken*> fields;
+    for (Expr field : op->fields) {
+      auto tok = GetToken(field);
+      CHECK_EQ(tok.size(), 1U);
+      fields.push_back(tok[0]);
+    }
+    token_map_[op] = fields;
+  }
+
+  void VisitExpr_(const TupleGetItemNode* op) final {
+    const auto& tok = GetToken(op->tuple);
+    CHECK_LT(static_cast<size_t>(op->index), tok.size());
+    token_map_[op] = {tok[op->index]};
+  }
+
+  void VisitExpr_(const IfNode* op) final {
+    LOG(FATAL) << "if is not supported.";
+  }
+
+  void VisitExpr_(const LetNode* op) final {
+    auto token = GetToken(op->value);
+    token_map_[op->var.operator->()] = token;
+    token_map_[op] = GetToken(op->body);
+  }
+
+ protected:
+  /*! \brief internal token map */
+  std::unordered_map<const ExprNode*, std::vector<StorageToken*> > token_map_;
+
+  /*!
+   * \brief Get the necessary token.
+   * \param expr The expression.
+   * \return The corresponding token.
+   */
+  const std::vector<StorageToken*>& GetToken(const Expr& expr) {
+    this->VisitExpr(expr);
+    auto it = token_map_.find(expr.operator->());
+    CHECK(it != token_map_.end());
+    return it->second;
+  }
+  /*!
+   * \brief Populate the token map to set op's tokens
+   * \param op The node to be processed.
+   * \param can_realloc Whether we can re-allocate the memory.
+   */
+  virtual void CreateToken(const ExprNode* op, bool can_realloc) = 0;
+};
+
+
+class StorageAllocaInit : protected StorageAllocaBaseVisitor {
+ public:
+  explicit StorageAllocaInit(common::Arena* arena)
+      : arena_(arena) {}
+
+
+  /*! \return The internal token map */
+  std::unordered_map<const ExprNode*, std::vector<StorageToken*> >
+  GetInitTokenMap(const Function& func) {
+    this->Run(func);
+    return std::move(token_map_);
+  }
+
+
+ protected:
+  using StorageAllocaBaseVisitor::VisitExpr_;
+
+  void CreateToken(const ExprNode* op, bool can_realloc)  final {
+    CHECK(!token_map_.count(op));
+    std::vector<StorageToken*> tokens;
+    if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        const auto* ttype = t.as<TensorTypeNode>();
+        CHECK(ttype);
+        StorageToken* token = arena_->make<StorageToken>();
+        token->ttype = ttype;
+        tokens.push_back(token);
+      }
+    } else {
+      const auto* ttype = op->checked_type().as<TensorTypeNode>();
+      CHECK(ttype);
+      StorageToken* token = arena_->make<StorageToken>();
+      token->ttype = ttype;
+      tokens.push_back(token);
+    }
+    token_map_[op] = tokens;
+  }
+
+  void VisitExpr_(const CallNode* op) final {
+    // create token for the call node.
+    CreateToken(op, true);
+    // for each input, visit argument token.
+    for (Expr arg : op->args) {
+      for (StorageToken* tok : GetToken(arg)) {
+        tok->ref_counter += 1;
+      }
+    }
+  }
+
+ private:
+  // allocator
+  common::Arena* arena_;
+};
+
+
+class StorageAllocator : public StorageAllocaBaseVisitor {
+ public:
+  /*!
+   * \return totoal number of bytes allocated
+   */
+  size_t TotalAllocBytes() const {
+    size_t total = 0;
+    for (const auto* p : data_) {
+      total += p->max_bytes;
+    }
+    return total;
+  }
+
+  // Run storage allocation for a function.
+  Map<Expr, Array<Integer> > Plan(const Function& func) {
+    prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func);
+    this->Run(func);
+
+    Map<Expr, Array<Integer> > smap;
+
+    for (const auto& kv : token_map_) {
+      Array<Integer> vec;
+      for (StorageToken* tok : kv.second) {
+        vec.push_back(tok->storage_id);
+      }
+      smap.Set(GetRef<Expr>(kv.first), vec);
+    }
+    return smap;
+  }
+
+
+ protected:
+  using StorageAllocaBaseVisitor::VisitExpr_;
+  // override create token by getting token as prototype requirements.
+  void CreateToken(const ExprNode* op, bool can_realloc)  final {
+    CHECK(!token_map_.count(op));
+    auto it = prototype_.find(op);
+    CHECK(it != prototype_.end());
+    std::vector<StorageToken*> tokens;
+    for (StorageToken* tok : it->second) {
+      if (can_realloc) {
+        tokens.push_back(Request(tok));
+      } else {
+        // Allocate a new token,
+        StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok));
+        // ensure it never get de-allocated.
+        allocated_tok->ref_counter += 1;
+        tokens.push_back(allocated_tok);
+      }
+    }
+    token_map_[op] = tokens;
+  }
+  // The call map
+  void VisitExpr_(const CallNode* op) final {
+    std::vector<StorageToken*> args;
+    // for each input, visit argument token.
+    for (Expr arg : op->args) {
+      for (StorageToken* tok : GetToken(arg)) {
+        args.push_back(tok);
+      }
+    }
+    // create token for the call node.
+    CreateToken(op, true);
+    // check if there is orphaned output that can be released immediately.
+    for (StorageToken* tok : token_map_.at(op)) {
+      CheckForRelease(tok);
+    }
+    for (StorageToken* tok : args) {
+      tok->ref_counter -= 1;
+      CheckForRelease(tok);
+    }
+  }
+  /*!
+   * \brief ceil(size/word_size) to get number of words.
+   * \param size The original size.
+   * \param word_size The element size.
+   */
+  static size_t DivRoundUp(size_t size, size_t word_size) {
+    return (size + word_size - 1) / word_size;
+  }
+  /*!
+   * \brief Get the memory requirement.
+   * \param prototype The prototype token.
+   * \return The required memory size.
+   */
+  size_t GetMemorySize(StorageToken* prototype) {
+    const TensorTypeNode* ttype = prototype->ttype;
+    CHECK(ttype != nullptr);
+    size_t size = 1;
+    for (IndexExpr dim : ttype->shape) {
+      const int64_t* pval = as_const_int(dim);
+      CHECK(pval != nullptr)
+          << "Cannot allocate memory symbolic tensor shape "
+          << ttype->shape;
+      size *= static_cast<size_t>(pval[0]);
+    }
+    size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8);
+    return size;
+  }
+  /*!
+   * \brief Request a storage token for a given prototype.
+   * \param prototype. The prototype storage token.
+   * \return The result token.
+   */
+  StorageToken* Request(StorageToken* prototype) {
+    // calculate the size;
+    size_t size = GetMemorySize(prototype);
+    // search memory block in [size / match_range_, size * match_range_)
+    if (match_range_ == 0) {
+      return this->Alloc(prototype, size);
+    }
+    auto begin = free_.lower_bound(size / match_range_);
+    auto mid = free_.lower_bound(size);
+    auto end = free_.upper_bound(size * match_range_);
+    // search for memory blocks larger than requested
+    for (auto it = mid; it != end; ++it) {
+      StorageToken *tok = it->second;
+      if (tok->device_id != prototype->device_id) continue;
+      CHECK_EQ(tok->ref_counter, 0);
+      // Use exect matching strategy
+      tok->max_bytes = std::max(size, tok->max_bytes);
+      tok->ref_counter = prototype->ref_counter;
+      // find a exact match, erase from map and return
+      free_.erase(it);
+      return tok;
+    }
+    // then search for memory blocks smaller than requested space
+    for (auto it = mid; it != begin;) {
+      --it;
+      StorageToken *tok = it->second;
+      if (tok->device_id != prototype->device_id) continue;
+      CHECK_EQ(tok->ref_counter, 0);
+      // Use exect matching strategy
+      tok->max_bytes = std::max(size, tok->max_bytes);
+      tok->ref_counter = prototype->ref_counter;
+      // erase from map and return
+      free_.erase(it);
+      return tok;
+    }
+    // cannot find anything return a new one.
+    return this->Alloc(prototype, size);
+  }
+  /*!
+   * \brief Allocate a storage token by consuming prototype
+   * \param prototype The prototype token.
+   * \param size The size of memory being requested.
+   */
+  StorageToken* Alloc(StorageToken* prototype, size_t size) {
+    prototype->max_bytes = size;
+    prototype->storage_id = static_cast<int64_t>(data_.size());
+    data_.push_back(prototype);
+    return prototype;
+  }
+  /*!
+   * \brief Check if we can release token.
+   * \tok The token to be released.
+   */
+  void CheckForRelease(StorageToken* tok) {
+    CHECK_GE(tok->storage_id, 0);
+    CHECK_GE(tok->ref_counter, 0);
+    if (tok->ref_counter == 0) {
+      free_.insert({tok->max_bytes, tok});
+    }
+  }
+
+ private:
+  // allocator
+  common::Arena arena_;
+  // scale used for rough match
+  size_t match_range_{16};
+  // free list of storage entry
+  std::multimap<size_t, StorageToken*> free_;
+  // all the storage resources available
+  std::vector<StorageToken*> data_;
+  /*! \brief internal prototype token map */
+  std::unordered_map<const ExprNode*, std::vector<StorageToken*> > prototype_;
+};
+
+
+Map<Expr, Array<Integer> > GraphPlanMemory(const Function& func) {
+  return StorageAllocator().Plan(func);
+}
+
+TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory")
+.set_body_typed<Map<Expr, Array<Integer> >(const Function&)>(GraphPlanMemory);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc
index bfc5f0db52b72d37ac2ae319b25501ee755832c2..5e97ce1010ade363f1d44264eb930a4939eca374 100644
--- a/src/relay/ir/text_printer.cc
+++ b/src/relay/ir/text_printer.cc
@@ -113,6 +113,11 @@ class TextMetaDataContext {
     return SaveJSON(Array<NodeRef>(meta_data_));
   }
 
+  /*! \return whether the meta data context is empty. */
+  bool empty() const {
+    return meta_data_.empty();
+  }
+
  private:
   /*! \brief additional metadata stored in TVM json format */
   std::vector<NodeRef> meta_data_;
@@ -125,8 +130,9 @@ class TextPrinter :
     public TypeFunctor<void (const Type&, std::ostream& os)>,  // NOLINT(*)
     public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
  public:
-  explicit TextPrinter(runtime::TypedPackedFunc<std::string(Expr)> annotate)
-      : annotate_(annotate) {}
+  explicit TextPrinter(bool show_meta_data,
+                       runtime::TypedPackedFunc<std::string(Expr)> annotate)
+      : show_meta_data_(show_meta_data), annotate_(annotate) {}
   /*!
    * \brief Print a node to string.
    * \param node.
@@ -144,13 +150,17 @@ class TextPrinter :
     } else {
       stream_ << node;
     }
-    std::string meta_json = meta_.GetMetaSection();
-    if (meta_json.length() != 0) {
-      // append meta data in the end.
-      stream_ << "# meta data\n"
-              << "r\"\"\"\n"
-              << meta_json << "\n"
-              << "\"\"\"";
+    if (!meta_.empty()) {
+      if (show_meta_data_) {
+        std::string meta_json = meta_.GetMetaSection();
+        // append meta data in the end.
+        stream_ << "# meta data\n"
+                << "r\"\"\"\n"
+                << meta_json << "\n"
+                << "\"\"\"";
+      } else {
+        stream_ << "# meta data omitted. you can use show_meta_data=True to include meta-data\n";
+      }
     }
     return stream_.str();
   }
@@ -227,7 +237,9 @@ class TextPrinter :
     TextValue id = this->AllocTempVar();
     this->PrintIndent();
     stream_ << id << " = " << meta_.GetMetaNode(GetRef<NodeRef>(op));
-    this->PrintEndInst("\n");
+    this->PrintEndInst("");
+    this->PrintOptionalInfo(GetRef<Expr>(op));
+    stream_ << '\n';
     return id;
   }
 
@@ -697,6 +709,8 @@ class TextPrinter :
  private:
   class AttrPrinter;
   friend class AttrPrinter;
+  /*! \brief Whether to print meta data. */
+  bool show_meta_data_;
   /*! \brief additional comment function */
   runtime::TypedPackedFunc<std::string(Expr)> annotate_;
   /*! \brief meta data context */
@@ -790,13 +804,14 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
 }
 
 std::string RelayPrint(const NodeRef& node,
+                       bool show_meta_data,
                        runtime::TypedPackedFunc<std::string(Expr)> annotate) {
-  return TextPrinter(annotate).Print(node);
+  return TextPrinter(show_meta_data, annotate).Print(node);
 }
 
 TVM_REGISTER_API("relay._expr.RelayPrint")
 .set_body_typed<std::string(
-    const NodeRef&,
+    const NodeRef&, bool,
     runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint);
 
 }  // namespace relay
diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index cb5f86f4b525d08d6a02a06705582e2e776043b1..b9e0823e88fac5443ab8ca2f4fa580a123d8da04 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -749,7 +749,7 @@ class FuseMutator : private ExprMutator {
   }
   // Debug function, dump the group assignment in text.
   void DebugDumpGroup(const Expr& body) {
-    std::string text = RelayPrint(body, [this](const Expr& expr) -> std::string {
+    std::string text = RelayPrint(body, false, [this](const Expr& expr) -> std::string {
         auto it = gmap_.find(expr.get());
         if (it == gmap_.end()) return "";
         std::ostringstream os;
diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py
index 7b610f82f6a53d799c0121128095e39f8a391a20..7baa906abacc47e3dd5187b58a60074188225679 100644
--- a/tests/python/relay/test_backend_graph_runtime.py
+++ b/tests/python/relay/test_backend_graph_runtime.py
@@ -77,7 +77,9 @@ def test_add_op_broadcast():
 def test_with_params():
     x = relay.var('x', shape=(10, 5))
     y = relay.var('y', shape=(1, 5))
-    func = relay.Function([x, y], add(x, y))
+    z = relay.add(x, y)
+    z = relay.exp(z)
+    func = relay.Function([x, y], z)
     x_data = np.random.rand(10, 5).astype('float32')
     y_data = np.random.rand(1, 5).astype('float32')
     params = {"y": y_data}
@@ -87,11 +89,40 @@ def test_with_params():
     mod.set_input(x=x_data)
     mod.run()
     res = mod.get_output(0).asnumpy()
-    ref_res = y_data + x_data
+    ref_res = np.exp(y_data + x_data)
     tvm.testing.assert_allclose(res, ref_res)
 
 
+def test_plan_memory():
+    # it is sufficient to cycle through two memories.
+
+    x = relay.var("x", shape=(10,))
+    y = relay.var("x", shape=(1,))
+    y2 = relay.exp(y)
+    z = relay.add(x, y2)
+    z = relay.exp(z)
+    z = relay.exp(z)
+    z = relay.exp(z)
+    z = relay.exp(z)
+    z = relay.exp(z)
+    func = relay.Function([x, y], z)
+    func = relay.ir_pass.infer_type(func)
+    func = relay.ir_pass.fuse_ops(func, opt_level=0)
+    func = relay.ir_pass.infer_type(func)
+    smap = relay.backend._backend.GraphPlanMemory(func)
+    storage_ids = set()
+    for k, v in smap.items():
+        for x in v:
+            storage_ids.add(x.value)
+
+    # Current rule requires vars have unique storage id
+    # because we don't do inplace, we will need another
+    # two alternating temporary space.
+    assert len(storage_ids) == 4
+
+
 if __name__ == "__main__":
+    test_plan_memory()
     test_with_params()
     test_add_op_scalar()
     test_add_op_tensor()