Commit 35277c2f by tqchen

Add safe destructor

parent 3e693f53
...@@ -105,6 +105,15 @@ class Node { ...@@ -105,6 +105,15 @@ class Node {
protected: protected:
// node ref can see this // node ref can see this
friend class NodeRef; friend class NodeRef;
/*!
* \brief optional: safe destruction function
* Can be called in destructor of composite types.
* This can be used to avoid stack overflow when
* recursive destruction long graph(1M nodes),
*
* It is totally OK to not call this in destructor.
*/
void Destroy();
/*! \brief the node type enum */ /*! \brief the node type enum */
NodeType node_type_{kOtherNodes}; NodeType node_type_{kOtherNodes};
}; };
...@@ -127,6 +136,7 @@ class NodeRef { ...@@ -127,6 +136,7 @@ class NodeRef {
template<typename T, typename> template<typename T, typename>
friend class Array; friend class Array;
friend class APIVariantValue; friend class APIVariantValue;
friend class Node;
NodeRef() = default; NodeRef() = default;
explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {} explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {}
/*! \brief the internal node */ /*! \brief the internal node */
......
...@@ -82,6 +82,9 @@ class UnaryOpNode : public ExprNode { ...@@ -82,6 +82,9 @@ class UnaryOpNode : public ExprNode {
node_type_ = kUnaryOpNode; node_type_ = kUnaryOpNode;
dtype_ = this->src.dtype(); dtype_ = this->src.dtype();
} }
~UnaryOpNode() {
this->Destroy();
}
const char* type_key() const override { const char* type_key() const override {
return "UnaryOpNode"; return "UnaryOpNode";
} }
...@@ -114,6 +117,9 @@ struct BinaryOpNode : public ExprNode { ...@@ -114,6 +117,9 @@ struct BinaryOpNode : public ExprNode {
node_type_ = kBinaryOpNode; node_type_ = kBinaryOpNode;
dtype_ = this->lhs.dtype(); dtype_ = this->lhs.dtype();
} }
~BinaryOpNode() {
this->Destroy();
}
const char* type_key() const override { const char* type_key() const override {
return "BinaryOpNode"; return "BinaryOpNode";
} }
......
...@@ -50,7 +50,7 @@ class NodeBase(object): ...@@ -50,7 +50,7 @@ class NodeBase(object):
check_call(_LIB.TVMNodeGetAttr( check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name), self.handle, c_str(name),
ctypes.byref(ret_val), ctypes.byref(ret_typeid))) ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
return RET_SWITCH[ret_typeid.value](ret_val) ret = RET_SWITCH[ret_typeid.value](ret_val)
def _type_key(handle): def _type_key(handle):
......
...@@ -11,6 +11,32 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg); ...@@ -11,6 +11,32 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
namespace tvm { namespace tvm {
void Node::Destroy() {
bool safe = true;
this->VisitNodeRefFields([&safe](const char* k, NodeRef* r) {
if (r->node_.get() != nullptr) safe = false;
});
if (!safe) {
// explicit deletion via DFS
// this is used to avoid stackoverflow caused by chain of deletions
std::vector<Node*> stack{this};
std::vector<std::shared_ptr<Node> > to_delete;
while (!stack.empty()) {
Node* n = stack.back();
stack.pop_back();
n->VisitNodeRefFields([&safe, &stack, &to_delete](const char* k, NodeRef* r) {
if (r->node_.unique()) {
stack.push_back(r->node_.get());
to_delete.emplace_back(std::move(r->node_));
} else {
r->node_.reset();
}
});
}
}
}
TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_NODE_TYPE(IntNode); TVM_REGISTER_NODE_TYPE(IntNode);
TVM_REGISTER_NODE_TYPE(FloatNode); TVM_REGISTER_NODE_TYPE(FloatNode);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment