Commit 3bbfc2dc by Wei Chen Committed by Tianqi Chen

[RELAY] Support recursive call syntax (#2352)

parent 6ab05082
...@@ -62,6 +62,15 @@ class ModuleNode : public RelayNode { ...@@ -62,6 +62,15 @@ class ModuleNode : public RelayNode {
void Add(const GlobalVar& var, const Function& func, bool update = false); void Add(const GlobalVar& var, const Function& func, bool update = false);
/*! /*!
* \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
*
* It does not do type inference as Add does.
*/
void AddUnchecked(const GlobalVar& var, const Function& func);
/*!
* \brief Update a function in the global environment. * \brief Update a function in the global environment.
* \param var The name of the global function to update. * \param var The name of the global function to update.
* \param func The new function. * \param func The new function.
......
...@@ -87,6 +87,7 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -87,6 +87,7 @@ class ParseTreeToRelayIR(RelayVisitor):
# Adding an empty scope allows naked lets without pain. # Adding an empty scope allows naked lets without pain.
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
self.global_var_scope = deque() # type: Scope[expr.GlobalVar]
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
super(ParseTreeToRelayIR, self).__init__() super(ParseTreeToRelayIR, self).__init__()
...@@ -111,6 +112,14 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -111,6 +112,14 @@ class ParseTreeToRelayIR(RelayVisitor):
self.var_scopes[0].appendleft((name, var)) self.var_scopes[0].appendleft((name, var))
return var return var
def mk_global_var(self, name):
# type: (str) -> expr.GlobalVar
"""Create a new GlobalVar and add it to the GlobalVar scope."""
var = expr.GlobalVar(name)
self.global_var_scope.append((name, var))
return var
def enter_type_param_scope(self): def enter_type_param_scope(self):
# type: () -> None # type: () -> None
"""Enter a new TypeVar scope so it can be popped off later.""" """Enter a new TypeVar scope so it can be popped off later."""
...@@ -140,7 +149,7 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -140,7 +149,7 @@ class ParseTreeToRelayIR(RelayVisitor):
# variables # variables
if node_type == RelayLexer.GLOBAL_VAR: if node_type == RelayLexer.GLOBAL_VAR:
return expr.GlobalVar(node_text[1:]) return lookup([self.global_var_scope], node_text[1:])
elif node_type == RelayLexer.LOCAL_VAR: elif node_type == RelayLexer.LOCAL_VAR:
name = node_text[1:] name = node_text[1:]
var = lookup(self.var_scopes, name) var = lookup(self.var_scopes, name)
...@@ -313,7 +322,8 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -313,7 +322,8 @@ class ParseTreeToRelayIR(RelayVisitor):
ident = ctx.ident().GLOBAL_VAR() ident = ctx.ident().GLOBAL_VAR()
if ident is None: if ident is None:
raise ParseError('Only global ids may be used in `def`s.') raise ParseError('Only global ids may be used in `def`s.')
ident = expr.GlobalVar(ident.getText()[1:]) ident_name = ident.getText()[1:]
ident = self.mk_global_var(ident_name)
self.module[ident] = self.mk_func(ctx) self.module[ident] = self.mk_func(ctx)
......
...@@ -33,10 +33,26 @@ GlobalVar ModuleNode::GetGlobalVar(const std::string& name) { ...@@ -33,10 +33,26 @@ GlobalVar ModuleNode::GetGlobalVar(const std::string& name) {
return (*it).second; return (*it).second;
} }
void ModuleNode::AddUnchecked(const GlobalVar& var,
const Function& func) {
auto mod = GetRef<Module>(this);
this->functions.Set(var, func);
auto it = global_var_map_.find(var->name_hint);
if (it != global_var_map_.end()) {
CHECK_EQ((*it).second, var);
} else {
CHECK(!global_var_map_.count(var->name_hint))
<< "Duplicate global function name " << var->name_hint;
}
global_var_map_.Set(var->name_hint, var);
}
void ModuleNode::Add(const GlobalVar& var, void ModuleNode::Add(const GlobalVar& var,
const Function& func, const Function& func,
bool update) { bool update) {
// Type check the item before we add it to the modironment. // Type check the item before we add it to the module.
auto mod = GetRef<Module>(this); auto mod = GetRef<Module>(this);
Function checked_func = InferType(func, mod, var); Function checked_func = InferType(func, mod, var);
auto type = checked_func->checked_type(); auto type = checked_func->checked_type();
...@@ -48,18 +64,7 @@ void ModuleNode::Add(const GlobalVar& var, ...@@ -48,18 +64,7 @@ void ModuleNode::Add(const GlobalVar& var,
CHECK(AlphaEqual(type, old_type)) CHECK(AlphaEqual(type, old_type))
<< "Module#update changes type, not possible in this mode."; << "Module#update changes type, not possible in this mode.";
} }
this->functions.Set(var, checked_func); AddUnchecked(var, checked_func);
auto it = global_var_map_.find(var->name_hint);
if (it != global_var_map_.end()) {
CHECK_EQ((*it).second, var);
} else {
// set global var map
CHECK(!global_var_map_.count(var->name_hint))
<< "Duplicate global function name " << var->name_hint;
}
global_var_map_.Set(var->name_hint, var);
} }
void ModuleNode::Update(const GlobalVar& var, const Function& func) { void ModuleNode::Update(const GlobalVar& var, const Function& func) {
......
...@@ -543,10 +543,9 @@ Function InferType(const Function& func, ...@@ -543,10 +543,9 @@ Function InferType(const Function& func,
const GlobalVar& var) { const GlobalVar& var) {
Function func_copy = Function(make_node<FunctionNode>(*func.operator->())); Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
func_copy->checked_type_ = func_copy->func_type_annotation(); func_copy->checked_type_ = func_copy->func_type_annotation();
mod->functions.Set(var, func_copy); mod->AddUnchecked(var, func_copy);
Expr func_ret = TypeInferencer(mod).Infer(func_copy); Expr func_ret = TypeInferencer(mod).Infer(func_copy);
auto map_node = mod->functions.CopyOnWrite(); mod->Remove(var);
map_node->data.erase(var.node_);
CHECK(WellFormed(func_ret)); CHECK(WellFormed(func_ret));
return Downcast<Function>(func_ret); return Downcast<Function>(func_ret);
} }
......
...@@ -282,6 +282,16 @@ def test_defn(): ...@@ -282,6 +282,16 @@ def test_defn():
assert isinstance(id_defn, relay.Module) assert isinstance(id_defn, relay.Module)
@if_parser_enabled @if_parser_enabled
def test_recursive_call():
id_defn = relay.fromtext(
"""
def @id(%x: int32) -> int32 {
@id(%x)
}
""")
assert isinstance(id_defn, relay.Module)
@if_parser_enabled
def test_ifelse(): def test_ifelse():
assert alpha_equal( assert alpha_equal(
relay.fromtext( relay.fromtext(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment