From 35277c2f95b0404322345d591e0d8798f7f35dc7 Mon Sep 17 00:00:00 2001
From: tqchen <tianqi.tchen@gmail.com>
Date: Sun, 16 Oct 2016 22:53:47 -0700
Subject: [PATCH] Add safe destructor

---
 include/tvm/base.h             | 10 ++++++++++
 include/tvm/expr_node.h        |  6 ++++++
 python/tvm/cpp/_ctypes/_api.py |  2 +-
 src/expr/expr_node.cc          | 26 ++++++++++++++++++++++++++
 4 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/include/tvm/base.h b/include/tvm/base.h
index 4bf7a8f69..d406b29f8 100644
--- a/include/tvm/base.h
+++ b/include/tvm/base.h
@@ -105,6 +105,15 @@ class Node {
  protected:
   // node ref can see this
   friend class NodeRef;
+  /*!
+   * \brief optional: safe destruction function
+   *  Can be called in destructor of composite types.
+   *  This can be used to avoid stack overflow when
+   *  recursive destruction long graph(1M nodes),
+   *
+   *  It is totally OK to not call this in destructor.
+   */
+  void Destroy();
   /*! \brief the node type enum */
   NodeType node_type_{kOtherNodes};
 };
@@ -127,6 +136,7 @@ class NodeRef {
   template<typename T, typename>
   friend class Array;
   friend class APIVariantValue;
+  friend class Node;
   NodeRef() = default;
   explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {}
   /*! \brief the internal node */
diff --git a/include/tvm/expr_node.h b/include/tvm/expr_node.h
index 899fea4c3..c6ea96f9b 100644
--- a/include/tvm/expr_node.h
+++ b/include/tvm/expr_node.h
@@ -82,6 +82,9 @@ class UnaryOpNode : public ExprNode {
     node_type_ = kUnaryOpNode;
     dtype_ = this->src.dtype();
   }
+  ~UnaryOpNode() {
+    this->Destroy();
+  }
   const char* type_key() const override {
     return "UnaryOpNode";
   }
@@ -114,6 +117,9 @@ struct BinaryOpNode : public ExprNode {
     node_type_ = kBinaryOpNode;
     dtype_ = this->lhs.dtype();
   }
+  ~BinaryOpNode() {
+    this->Destroy();
+  }
   const char* type_key() const override {
     return "BinaryOpNode";
   }
diff --git a/python/tvm/cpp/_ctypes/_api.py b/python/tvm/cpp/_ctypes/_api.py
index 0b35b9580..3794bedee 100644
--- a/python/tvm/cpp/_ctypes/_api.py
+++ b/python/tvm/cpp/_ctypes/_api.py
@@ -50,7 +50,7 @@ class NodeBase(object):
         check_call(_LIB.TVMNodeGetAttr(
             self.handle, c_str(name),
             ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
-        return RET_SWITCH[ret_typeid.value](ret_val)
+        ret = RET_SWITCH[ret_typeid.value](ret_val)
 
 
 def _type_key(handle):
diff --git a/src/expr/expr_node.cc b/src/expr/expr_node.cc
index 76aa8bf6d..c6626672e 100644
--- a/src/expr/expr_node.cc
+++ b/src/expr/expr_node.cc
@@ -11,6 +11,32 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
 
 namespace tvm {
 
+void Node::Destroy() {
+  bool safe = true;
+  this->VisitNodeRefFields([&safe](const char* k, NodeRef* r) {
+      if (r->node_.get() != nullptr) safe = false;
+    });
+
+  if (!safe) {
+    // explicit deletion via DFS
+    // this is used to avoid stackoverflow caused by chain of deletions
+    std::vector<Node*> stack{this};
+    std::vector<std::shared_ptr<Node> > to_delete;
+    while (!stack.empty()) {
+      Node* n = stack.back();
+      stack.pop_back();
+      n->VisitNodeRefFields([&safe, &stack, &to_delete](const char* k, NodeRef* r) {
+          if (r->node_.unique()) {
+            stack.push_back(r->node_.get());
+            to_delete.emplace_back(std::move(r->node_));
+          } else {
+            r->node_.reset();
+          }
+        });
+    }
+  }
+}
+
 TVM_REGISTER_NODE_TYPE(VarNode);
 TVM_REGISTER_NODE_TYPE(IntNode);
 TVM_REGISTER_NODE_TYPE(FloatNode);
-- 
GitLab