Commit 2e0dbaa6 by 雾雨魔理沙 Committed by Haichen Shen

[Relay] Fix memory leak in the interpreter (#4155)

* save

lint

* address reviewer comment
parent b08fe810
...@@ -119,6 +119,32 @@ class ClosureNode : public ValueNode { ...@@ -119,6 +119,32 @@ class ClosureNode : public ValueNode {
RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value); RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);
/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;
/*! \brief The container type of RecClosure. */
class RecClosureNode : public ValueNode {
public:
/*! \brief The closure. */
Closure clos;
/*! \brief variable the closure bind to. */
Var bind;
RecClosureNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("clos", &clos);
v->Visit("bind", &bind);
}
TVM_DLL static RecClosure make(Closure clos, Var bind);
static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value);
/*! \brief A tuple value. */ /*! \brief A tuple value. */
class TupleValue; class TupleValue;
......
...@@ -73,6 +73,11 @@ class Closure(Value): ...@@ -73,6 +73,11 @@ class Closure(Value):
@register_relay_node @register_relay_node
class RecClosure(Value):
"""A recursive closure produced by the interpreter."""
@register_relay_node
class ConstructorValue(Value): class ConstructorValue(Value):
def __init__(self, tag, fields, constructor): def __init__(self, tag, fields, constructor):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
......
...@@ -56,9 +56,27 @@ TVM_REGISTER_API("relay._make.Closure") ...@@ -56,9 +56,27 @@ TVM_REGISTER_API("relay._make.Closure")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) { .set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
p->stream << "ClosureNode(" << node->func << ")"; p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
}); });
// TODO(@jroesch): this doesn't support mutual letrec
/* Value Implementation */
RecClosure RecClosureNode::make(Closure clos, Var bind) {
NodePtr<RecClosureNode> n = make_node<RecClosureNode>();
n->clos = std::move(clos);
n->bind = std::move(bind);
return RecClosure(n);
}
TVM_REGISTER_API("relay._make.RecClosure")
.set_body_typed(RecClosureNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RecClosureNode>([](const RecClosureNode* node, tvm::IRPrinter* p) {
p->stream << "RecClosureNode(" << node->clos << ")";
});
TupleValue TupleValueNode::make(tvm::Array<Value> value) { TupleValue TupleValueNode::make(tvm::Array<Value> value) {
NodePtr<TupleValueNode> n = make_node<TupleValueNode>(); NodePtr<TupleValueNode> n = make_node<TupleValueNode>();
n->fields = value; n->fields = value;
...@@ -281,7 +299,6 @@ class Interpreter : ...@@ -281,7 +299,6 @@ class Interpreter :
return TupleValueNode::make(values); return TupleValueNode::make(values);
} }
// TODO(@jroesch): this doesn't support mutual letrec
inline Value MakeClosure(const Function& func, Var letrec_name = Var()) { inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
tvm::Map<Var, Value> captured_mod; tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func); Array<Var> free_vars = FreeVars(func);
...@@ -298,10 +315,8 @@ class Interpreter : ...@@ -298,10 +315,8 @@ class Interpreter :
// We must use mutation here to build a self referential closure. // We must use mutation here to build a self referential closure.
auto closure = ClosureNode::make(captured_mod, func); auto closure = ClosureNode::make(captured_mod, func);
auto mut_closure =
static_cast<ClosureNode*>(const_cast<Node*>(closure.get()));
if (letrec_name.defined()) { if (letrec_name.defined()) {
mut_closure->env.Set(letrec_name, closure); return RecClosureNode::make(closure, letrec_name);
} }
return std::move(closure); return std::move(closure);
} }
...@@ -559,7 +574,7 @@ class Interpreter : ...@@ -559,7 +574,7 @@ class Interpreter :
} }
// Invoke the closure // Invoke the closure
Value Invoke(const Closure& closure, const tvm::Array<Value>& args) { Value Invoke(const Closure& closure, const tvm::Array<Value>& args, const Var& bind = Var()) {
// Get a reference to the function inside the closure. // Get a reference to the function inside the closure.
if (closure->func->IsPrimitive()) { if (closure->func->IsPrimitive()) {
return InvokePrimitiveOp(closure->func, args); return InvokePrimitiveOp(closure->func, args);
...@@ -575,12 +590,16 @@ class Interpreter : ...@@ -575,12 +590,16 @@ class Interpreter :
locals.Set(func->params[i], args[i]); locals.Set(func->params[i], args[i]);
} }
// Add the var to value mappings from the Closure's modironment. // Add the var to value mappings from the Closure's environment.
for (auto it = closure->env.begin(); it != closure->env.end(); ++it) { for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
CHECK_EQ(locals.count((*it).first), 0); CHECK_EQ(locals.count((*it).first), 0);
locals.Set((*it).first, (*it).second); locals.Set((*it).first, (*it).second);
} }
if (bind.defined()) {
locals.Set(bind, RecClosureNode::make(closure, bind));
}
return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); }); return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); });
} }
...@@ -607,6 +626,8 @@ class Interpreter : ...@@ -607,6 +626,8 @@ class Interpreter :
if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) { if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
auto closure = GetRef<Closure>(closure_node); auto closure = GetRef<Closure>(closure_node);
return this->Invoke(closure, args); return this->Invoke(closure, args);
} else if (const RecClosureNode* closure_node = fn_val.as<RecClosureNode>()) {
return this->Invoke(closure_node->clos, args, closure_node->bind);
} else { } else {
LOG(FATAL) << "internal error: type error, expected function value in the call " LOG(FATAL) << "internal error: type error, expected function value in the call "
<< "position"; << "position";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment