Commit 3bbfc2dc by Wei Chen Committed by Tianqi Chen

[RELAY] Support recursive call syntax (#2352)

parent 6ab05082
......@@ -62,6 +62,15 @@ class ModuleNode : public RelayNode {
void Add(const GlobalVar& var, const Function& func, bool update = false);
/*!
* \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
*
* It does not do type inference as Add does.
*/
void AddUnchecked(const GlobalVar& var, const Function& func);
/*!
* \brief Update a function in the global environment.
* \param var The name of the global function to update.
* \param func The new function.
......
......@@ -87,6 +87,7 @@ class ParseTreeToRelayIR(RelayVisitor):
# Adding an empty scope allows naked lets without pain.
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
self.global_var_scope = deque() # type: Scope[expr.GlobalVar]
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
super(ParseTreeToRelayIR, self).__init__()
......@@ -111,6 +112,14 @@ class ParseTreeToRelayIR(RelayVisitor):
self.var_scopes[0].appendleft((name, var))
return var
def mk_global_var(self, name):
# type: (str) -> expr.GlobalVar
"""Create a new GlobalVar and add it to the GlobalVar scope."""
var = expr.GlobalVar(name)
self.global_var_scope.append((name, var))
return var
def enter_type_param_scope(self):
# type: () -> None
"""Enter a new TypeVar scope so it can be popped off later."""
......@@ -140,7 +149,7 @@ class ParseTreeToRelayIR(RelayVisitor):
# variables
if node_type == RelayLexer.GLOBAL_VAR:
return expr.GlobalVar(node_text[1:])
return lookup([self.global_var_scope], node_text[1:])
elif node_type == RelayLexer.LOCAL_VAR:
name = node_text[1:]
var = lookup(self.var_scopes, name)
......@@ -313,7 +322,8 @@ class ParseTreeToRelayIR(RelayVisitor):
ident = ctx.ident().GLOBAL_VAR()
if ident is None:
raise ParseError('Only global ids may be used in `def`s.')
ident = expr.GlobalVar(ident.getText()[1:])
ident_name = ident.getText()[1:]
ident = self.mk_global_var(ident_name)
self.module[ident] = self.mk_func(ctx)
......
......@@ -33,10 +33,26 @@ GlobalVar ModuleNode::GetGlobalVar(const std::string& name) {
return (*it).second;
}
void ModuleNode::AddUnchecked(const GlobalVar& var,
const Function& func) {
auto mod = GetRef<Module>(this);
this->functions.Set(var, func);
auto it = global_var_map_.find(var->name_hint);
if (it != global_var_map_.end()) {
CHECK_EQ((*it).second, var);
} else {
CHECK(!global_var_map_.count(var->name_hint))
<< "Duplicate global function name " << var->name_hint;
}
global_var_map_.Set(var->name_hint, var);
}
void ModuleNode::Add(const GlobalVar& var,
const Function& func,
bool update) {
// Type check the item before we add it to the modironment.
const Function& func,
bool update) {
// Type check the item before we add it to the module.
auto mod = GetRef<Module>(this);
Function checked_func = InferType(func, mod, var);
auto type = checked_func->checked_type();
......@@ -48,18 +64,7 @@ void ModuleNode::Add(const GlobalVar& var,
CHECK(AlphaEqual(type, old_type))
<< "Module#update changes type, not possible in this mode.";
}
this->functions.Set(var, checked_func);
auto it = global_var_map_.find(var->name_hint);
if (it != global_var_map_.end()) {
CHECK_EQ((*it).second, var);
} else {
// set global var map
CHECK(!global_var_map_.count(var->name_hint))
<< "Duplicate global function name " << var->name_hint;
}
global_var_map_.Set(var->name_hint, var);
AddUnchecked(var, checked_func);
}
void ModuleNode::Update(const GlobalVar& var, const Function& func) {
......
......@@ -543,10 +543,9 @@ Function InferType(const Function& func,
const GlobalVar& var) {
Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
func_copy->checked_type_ = func_copy->func_type_annotation();
mod->functions.Set(var, func_copy);
mod->AddUnchecked(var, func_copy);
Expr func_ret = TypeInferencer(mod).Infer(func_copy);
auto map_node = mod->functions.CopyOnWrite();
map_node->data.erase(var.node_);
mod->Remove(var);
CHECK(WellFormed(func_ret));
return Downcast<Function>(func_ret);
}
......
......@@ -282,6 +282,16 @@ def test_defn():
assert isinstance(id_defn, relay.Module)
@if_parser_enabled
def test_recursive_call():
id_defn = relay.fromtext(
"""
def @id(%x: int32) -> int32 {
@id(%x)
}
""")
assert isinstance(id_defn, relay.Module)
@if_parser_enabled
def test_ifelse():
assert alpha_equal(
relay.fromtext(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment