Commit 8037fc82 by Zhi Chen Committed by Wei Chen

fix RemoveUnusedFunctions pass

parent 56996378
...@@ -53,33 +53,20 @@ struct CallTracer : ExprVisitor { ...@@ -53,33 +53,20 @@ struct CallTracer : ExprVisitor {
called_funcs_{}, called_funcs_{},
visiting_{} {} visiting_{} {}
void CheckExpr(const Expr& expr) { void VisitExpr_(const GlobalVarNode* op) final {
if (auto func_node = expr.as<FunctionNode>()) { called_funcs_.insert(op->name_hint);
auto func = GetRef<Function>(func_node); auto func = module_->Lookup(op->name_hint);
auto it = visiting_.find(func);
if (it != visiting_.end()) {
return;
}
visiting_.insert(func);
VisitExpr(func); VisitExpr(func);
} else if (auto global = expr.as<GlobalVarNode>()) {
called_funcs_.insert(global->name_hint);
auto func = module_->Lookup(global->name_hint);
auto it = visiting_.find(func);
if (it != visiting_.end()) {
return;
} }
void VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);
if (visiting_.find(func) == visiting_.end()) {
visiting_.insert(func); visiting_.insert(func);
VisitExpr(func); for (auto param : func_node->params) {
} else { ExprVisitor::VisitExpr(param);
VisitExpr(expr);
} }
} ExprVisitor::VisitExpr(func_node->body);
void VisitExpr_(const CallNode* call_node) final {
CheckExpr(call_node->op);
for (auto param : call_node->args) {
CheckExpr(param);
} }
} }
......
...@@ -20,6 +20,7 @@ from tvm import relay ...@@ -20,6 +20,7 @@ from tvm import relay
from tvm.relay import transform from tvm.relay import transform
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
def test_remove_all_prelude_functions(): def test_remove_all_prelude_functions():
mod = relay.Module() mod = relay.Module()
p = Prelude(mod) p = Prelude(mod)
...@@ -29,6 +30,7 @@ def test_remove_all_prelude_functions(): ...@@ -29,6 +30,7 @@ def test_remove_all_prelude_functions():
l = set([x[0].name_hint for x in mod.functions.items()]) l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['main']) assert l == set(['main'])
def test_remove_all_prelude_functions_but_referenced_functions(): def test_remove_all_prelude_functions_but_referenced_functions():
mod = relay.Module() mod = relay.Module()
p = Prelude(mod) p = Prelude(mod)
...@@ -42,6 +44,7 @@ def test_remove_all_prelude_functions_but_referenced_functions(): ...@@ -42,6 +44,7 @@ def test_remove_all_prelude_functions_but_referenced_functions():
l = set([x[0].name_hint for x in mod.functions.items()]) l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['id_func', 'main']) assert l == set(['id_func', 'main'])
def test_keep_only_referenced_prelude_functions(): def test_keep_only_referenced_prelude_functions():
mod = relay.Module() mod = relay.Module()
p = Prelude(mod) p = Prelude(mod)
...@@ -54,6 +57,7 @@ def test_keep_only_referenced_prelude_functions(): ...@@ -54,6 +57,7 @@ def test_keep_only_referenced_prelude_functions():
l = set([x[0].name_hint for x in mod.functions.items()]) l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['tl', 'hd', 'main']) assert l == set(['tl', 'hd', 'main'])
def test_multiple_entry_functions(): def test_multiple_entry_functions():
mod = relay.Module() mod = relay.Module()
p = Prelude(mod) p = Prelude(mod)
...@@ -72,6 +76,7 @@ def test_multiple_entry_functions(): ...@@ -72,6 +76,7 @@ def test_multiple_entry_functions():
l = set([x[0].name_hint for x in mod.functions.items()]) l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1']) assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1'])
def test_globalvar_as_call_arg(): def test_globalvar_as_call_arg():
mod = relay.Module() mod = relay.Module()
p = Prelude(mod) p = Prelude(mod)
...@@ -88,5 +93,24 @@ def test_globalvar_as_call_arg(): ...@@ -88,5 +93,24 @@ def test_globalvar_as_call_arg():
l = set([x[0].name_hint for x in mod.functions.items()]) l = set([x[0].name_hint for x in mod.functions.items()])
assert 'tensor_array_int32' in l assert 'tensor_array_int32' in l
def test_call_globalvar_without_args():
def get_mod():
mod = relay.Module({})
fn1 = relay.Function([], relay.const(1))
fn2 = relay.Function([], relay.const(2))
g1 = relay.GlobalVar('g1')
g2 = relay.GlobalVar('g2')
mod[g1] = fn1
mod[g2] = fn2
p = relay.var('p', 'bool')
mod['main'] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
return mod
mod = get_mod()
ref_mod = get_mod()
mod = relay.transform.RemoveUnusedFunctions()(mod)
assert relay.alpha_equal(mod, ref_mod)
if __name__ == '__main__': if __name__ == '__main__':
pytest.main() pytest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment