From 35277c2f95b0404322345d591e0d8798f7f35dc7 Mon Sep 17 00:00:00 2001 From: tqchen <tianqi.tchen@gmail.com> Date: Sun, 16 Oct 2016 22:53:47 -0700 Subject: [PATCH] Add safe destructor --- include/tvm/base.h | 10 ++++++++++ include/tvm/expr_node.h | 6 ++++++ python/tvm/cpp/_ctypes/_api.py | 2 +- src/expr/expr_node.cc | 26 ++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 1 deletion(-) diff --git a/include/tvm/base.h b/include/tvm/base.h index 4bf7a8f69..d406b29f8 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -105,6 +105,15 @@ class Node { protected: // node ref can see this friend class NodeRef; + /*! + * \brief optional: safe destruction function + * Can be called in destructor of composite types. + * This can be used to avoid stack overflow when + * recursive destruction long graph(1M nodes), + * + * It is totally OK to not call this in destructor. + */ + void Destroy(); /*! \brief the node type enum */ NodeType node_type_{kOtherNodes}; }; @@ -127,6 +136,7 @@ class NodeRef { template<typename T, typename> friend class Array; friend class APIVariantValue; + friend class Node; NodeRef() = default; explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {} /*! \brief the internal node */ diff --git a/include/tvm/expr_node.h b/include/tvm/expr_node.h index 899fea4c3..c6ea96f9b 100644 --- a/include/tvm/expr_node.h +++ b/include/tvm/expr_node.h @@ -82,6 +82,9 @@ class UnaryOpNode : public ExprNode { node_type_ = kUnaryOpNode; dtype_ = this->src.dtype(); } + ~UnaryOpNode() { + this->Destroy(); + } const char* type_key() const override { return "UnaryOpNode"; } @@ -114,6 +117,9 @@ struct BinaryOpNode : public ExprNode { node_type_ = kBinaryOpNode; dtype_ = this->lhs.dtype(); } + ~BinaryOpNode() { + this->Destroy(); + } const char* type_key() const override { return "BinaryOpNode"; } diff --git a/python/tvm/cpp/_ctypes/_api.py b/python/tvm/cpp/_ctypes/_api.py index 0b35b9580..3794bedee 100644 --- a/python/tvm/cpp/_ctypes/_api.py +++ b/python/tvm/cpp/_ctypes/_api.py @@ -50,7 +50,7 @@ class NodeBase(object): check_call(_LIB.TVMNodeGetAttr( self.handle, c_str(name), ctypes.byref(ret_val), ctypes.byref(ret_typeid))) - return RET_SWITCH[ret_typeid.value](ret_val) + ret = RET_SWITCH[ret_typeid.value](ret_val) def _type_key(handle): diff --git a/src/expr/expr_node.cc b/src/expr/expr_node.cc index 76aa8bf6d..c6626672e 100644 --- a/src/expr/expr_node.cc +++ b/src/expr/expr_node.cc @@ -11,6 +11,32 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg); namespace tvm { +void Node::Destroy() { + bool safe = true; + this->VisitNodeRefFields([&safe](const char* k, NodeRef* r) { + if (r->node_.get() != nullptr) safe = false; + }); + + if (!safe) { + // explicit deletion via DFS + // this is used to avoid stackoverflow caused by chain of deletions + std::vector<Node*> stack{this}; + std::vector<std::shared_ptr<Node> > to_delete; + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); + n->VisitNodeRefFields([&safe, &stack, &to_delete](const char* k, NodeRef* r) { + if (r->node_.unique()) { + stack.push_back(r->node_.get()); + to_delete.emplace_back(std::move(r->node_)); + } else { + r->node_.reset(); + } + }); + } + } +} + TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_NODE_TYPE(IntNode); TVM_REGISTER_NODE_TYPE(FloatNode); -- GitLab