Commit 8037fc82 by Zhi Chen Committed by Wei Chen

fix RemoveUnusedFunctions pass

parent 56996378
......@@ -53,33 +53,20 @@ struct CallTracer : ExprVisitor {
called_funcs_{},
visiting_{} {}
void CheckExpr(const Expr& expr) {
if (auto func_node = expr.as<FunctionNode>()) {
auto func = GetRef<Function>(func_node);
auto it = visiting_.find(func);
if (it != visiting_.end()) {
return;
}
visiting_.insert(func);
VisitExpr(func);
} else if (auto global = expr.as<GlobalVarNode>()) {
called_funcs_.insert(global->name_hint);
auto func = module_->Lookup(global->name_hint);
auto it = visiting_.find(func);
if (it != visiting_.end()) {
return;
}
visiting_.insert(func);
VisitExpr(func);
} else {
VisitExpr(expr);
}
void VisitExpr_(const GlobalVarNode* op) final {
called_funcs_.insert(op->name_hint);
auto func = module_->Lookup(op->name_hint);
VisitExpr(func);
}
void VisitExpr_(const CallNode* call_node) final {
CheckExpr(call_node->op);
for (auto param : call_node->args) {
CheckExpr(param);
void VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);
if (visiting_.find(func) == visiting_.end()) {
visiting_.insert(func);
for (auto param : func_node->params) {
ExprVisitor::VisitExpr(param);
}
ExprVisitor::VisitExpr(func_node->body);
}
}
......
......@@ -20,6 +20,7 @@ from tvm import relay
from tvm.relay import transform
from tvm.relay.prelude import Prelude
def test_remove_all_prelude_functions():
mod = relay.Module()
p = Prelude(mod)
......@@ -29,6 +30,7 @@ def test_remove_all_prelude_functions():
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['main'])
def test_remove_all_prelude_functions_but_referenced_functions():
mod = relay.Module()
p = Prelude(mod)
......@@ -42,6 +44,7 @@ def test_remove_all_prelude_functions_but_referenced_functions():
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['id_func', 'main'])
def test_keep_only_referenced_prelude_functions():
mod = relay.Module()
p = Prelude(mod)
......@@ -54,6 +57,7 @@ def test_keep_only_referenced_prelude_functions():
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['tl', 'hd', 'main'])
def test_multiple_entry_functions():
mod = relay.Module()
p = Prelude(mod)
......@@ -72,6 +76,7 @@ def test_multiple_entry_functions():
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1'])
def test_globalvar_as_call_arg():
mod = relay.Module()
p = Prelude(mod)
......@@ -88,5 +93,24 @@ def test_globalvar_as_call_arg():
l = set([x[0].name_hint for x in mod.functions.items()])
assert 'tensor_array_int32' in l
def test_call_globalvar_without_args():
def get_mod():
mod = relay.Module({})
fn1 = relay.Function([], relay.const(1))
fn2 = relay.Function([], relay.const(2))
g1 = relay.GlobalVar('g1')
g2 = relay.GlobalVar('g2')
mod[g1] = fn1
mod[g2] = fn2
p = relay.var('p', 'bool')
mod['main'] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
return mod
mod = get_mod()
ref_mod = get_mod()
mod = relay.transform.RemoveUnusedFunctions()(mod)
assert relay.alpha_equal(mod, ref_mod)
if __name__ == '__main__':
pytest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment