Commit 2672aad4 by Zhi Committed by Haichen Shen

[fix][pass] Save the function when it is used as a call arg (#4389)

parent 3338af7c
......@@ -54,12 +54,8 @@ struct CallTracer : ExprVisitor {
called_funcs_{},
visiting_{} {}
void VisitExpr_(const CallNode* call_node) final {
Expr op = call_node->op;
for (auto param : call_node->args) {
VisitExpr(param);
}
if (auto func_node = op.as<FunctionNode>()) {
void CheckExpr(const Expr& expr) {
if (auto func_node = expr.as<FunctionNode>()) {
auto func = GetRef<Function>(func_node);
auto it = visiting_.find(func);
if (it != visiting_.end()) {
......@@ -67,7 +63,7 @@ struct CallTracer : ExprVisitor {
}
visiting_.insert(func);
VisitExpr(func);
} else if (auto global = op.as<GlobalVarNode>()) {
} else if (auto global = expr.as<GlobalVarNode>()) {
called_funcs_.insert(global->name_hint);
auto func = module_->Lookup(global->name_hint);
auto it = visiting_.find(func);
......@@ -76,6 +72,15 @@ struct CallTracer : ExprVisitor {
}
visiting_.insert(func);
VisitExpr(func);
} else {
VisitExpr(expr);
}
}
void VisitExpr_(const CallNode* call_node) final {
CheckExpr(call_node->op);
for (auto param : call_node->args) {
CheckExpr(param);
}
}
......
......@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
from tvm import relay
from tvm.relay import transform
......@@ -71,5 +72,21 @@ def test_multiple_entry_functions():
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1'])
def test_globalvar_as_call_arg():
mod = relay.Module()
p = Prelude(mod)
tensor_array = p.get_var('tensor_array', 'int32')
tensor1 = p.get_var('tensor1', 'int32')
write = p.get_var('tensor_array_write', 'int32')
stack = p.get_var('tensor_array_stack', 'int32')
v = relay.var('v')
init_tensor_array = tensor_array(relay.const(3))
tensor_array1 = write(init_tensor_array, relay.const(0), tensor1(v))
tensor_array2 = stack(tensor_array1)
mod["main"] = relay.Function([v], tensor_array2)
mod = relay.transform.RemoveUnusedFunctions()(mod)
l = set([x[0].name_hint for x in mod.functions.items()])
assert 'tensor_array_int32' in l
if __name__ == '__main__':
pytest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment