Commit 2e0dbaa6 by 雾雨魔理沙 Committed by Haichen Shen

[Relay] Fix memory leak in the interpreter (#4155)

* save

lint

* address reviewer comment
parent b08fe810
......@@ -119,6 +119,32 @@ class ClosureNode : public ValueNode {
RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);
/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;
/*! \brief The container type of RecClosure. */
class RecClosureNode : public ValueNode {
public:
/*! \brief The closure. */
Closure clos;
/*! \brief variable the closure bind to. */
Var bind;
RecClosureNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("clos", &clos);
v->Visit("bind", &bind);
}
TVM_DLL static RecClosure make(Closure clos, Var bind);
static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value);
/*! \brief A tuple value. */
class TupleValue;
......
......@@ -73,6 +73,11 @@ class Closure(Value):
@register_relay_node
class RecClosure(Value):
"""A recursive closure produced by the interpreter."""
@register_relay_node
class ConstructorValue(Value):
def __init__(self, tag, fields, constructor):
self.__init_handle_by_constructor__(
......
......@@ -56,9 +56,27 @@ TVM_REGISTER_API("relay._make.Closure")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
p->stream << "ClosureNode(" << node->func << ")";
p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
});
// TODO(@jroesch): this doesn't support mutual letrec
/* Value Implementation */
RecClosure RecClosureNode::make(Closure clos, Var bind) {
NodePtr<RecClosureNode> n = make_node<RecClosureNode>();
n->clos = std::move(clos);
n->bind = std::move(bind);
return RecClosure(n);
}
TVM_REGISTER_API("relay._make.RecClosure")
.set_body_typed(RecClosureNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RecClosureNode>([](const RecClosureNode* node, tvm::IRPrinter* p) {
p->stream << "RecClosureNode(" << node->clos << ")";
});
TupleValue TupleValueNode::make(tvm::Array<Value> value) {
NodePtr<TupleValueNode> n = make_node<TupleValueNode>();
n->fields = value;
......@@ -281,7 +299,6 @@ class Interpreter :
return TupleValueNode::make(values);
}
// TODO(@jroesch): this doesn't support mutual letrec
inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func);
......@@ -298,10 +315,8 @@ class Interpreter :
// We must use mutation here to build a self referential closure.
auto closure = ClosureNode::make(captured_mod, func);
auto mut_closure =
static_cast<ClosureNode*>(const_cast<Node*>(closure.get()));
if (letrec_name.defined()) {
mut_closure->env.Set(letrec_name, closure);
return RecClosureNode::make(closure, letrec_name);
}
return std::move(closure);
}
......@@ -559,7 +574,7 @@ class Interpreter :
}
// Invoke the closure
Value Invoke(const Closure& closure, const tvm::Array<Value>& args) {
Value Invoke(const Closure& closure, const tvm::Array<Value>& args, const Var& bind = Var()) {
// Get a reference to the function inside the closure.
if (closure->func->IsPrimitive()) {
return InvokePrimitiveOp(closure->func, args);
......@@ -575,12 +590,16 @@ class Interpreter :
locals.Set(func->params[i], args[i]);
}
// Add the var to value mappings from the Closure's modironment.
// Add the var to value mappings from the Closure's environment.
for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
CHECK_EQ(locals.count((*it).first), 0);
locals.Set((*it).first, (*it).second);
}
if (bind.defined()) {
locals.Set(bind, RecClosureNode::make(closure, bind));
}
return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); });
}
......@@ -607,6 +626,8 @@ class Interpreter :
if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
auto closure = GetRef<Closure>(closure_node);
return this->Invoke(closure, args);
} else if (const RecClosureNode* closure_node = fn_val.as<RecClosureNode>()) {
return this->Invoke(closure_node->clos, args, closure_node->bind);
} else {
LOG(FATAL) << "internal error: type error, expected function value in the call "
<< "position";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment