Unverified Commit 0fb48360 by Zhi Committed by GitHub

[Relay][Pass] Add inline pass (#4927)

* add inline pass

* IsInline -> IsMarkedInlined

* fix comment
parent 892dc91a
......@@ -222,6 +222,13 @@ class FunctionNode : public BaseFuncNode {
bool IsPrimitive() const;
/*!
* \brief Check whether the function is marked as inline.
*
* \return Whether the function should be inlined or not.
*/
bool IsMarkedInline() const;
/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
......@@ -563,6 +570,8 @@ constexpr const char* kExternalSymbol = "ExternalSymbol";
constexpr const char* kSkipOptimization = "SkipOptimization";
/*! \brief Treat the function as a composite operator. */
constexpr const char* kComposite = "Composite";
/*! \brief Mark the function to be inlined. */
constexpr const char* kInline = "Inline";
} // namespace attr
} // namespace relay
......
......@@ -324,6 +324,14 @@ TVM_DLL Pass PrintIR(bool show_meta_data = true);
*/
TVM_DLL Pass PartitionGraph();
/*!
* \brief Inline the global functions marked as `inline` in a given Relay
* IRModule.
*
* \return The pass.
*/
TVM_DLL Pass Inline();
} // namespace transform
/*!
......
......@@ -552,6 +552,19 @@ def PartitionGraph():
return _transform.PartitionGraph()
def Inline():
    """Perform inlining on the given Relay IR module. The global functions that
    are marked as `inline` should be always inlined. A cost model will be
    needed in the future to decide if it is profitable to inline the function.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that performs inlining for a Relay IR module.
    """
    return _transform.Inline()
def gradient(expr, mod=None, mode='higher_order'):
"""
Transform the input function,
......
......@@ -145,6 +145,12 @@ bool FunctionNode::IsPrimitive() const {
return pval && pval->value != 0;
}
bool FunctionNode::IsMarkedInline() const {
  // The "Inline" attribute (attr::kInline) is stored as an IntImm; any
  // non-zero value marks the function as a candidate for the Inline pass.
  // A missing attribute yields an undefined ObjectRef, so `pval` is null
  // and the function is treated as not inlinable.
  ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kInline);
  const tir::IntImmNode* pval = res.as<tir::IntImmNode>();
  return pval && pval->value != 0;
}
Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
return FunctionSetAttr(GetRef<Function>(this), attr::kParams, parameters);
}
......
......@@ -84,6 +84,13 @@ CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) {
return cit->second.get();
}
BaseFunc CallGraphNode::GetGlobalFunction(const GlobalVar& var) const {
  // Fail loudly instead of letting Lookup abort on a dangling GlobalVar;
  // the call graph and module can drift if callers update one but not the
  // other.
  CHECK(module->ContainGlobalVar(var->name_hint))
      << "GlobalVar " << var->name_hint
      << " not found in the current ir module";
  return module->Lookup(var);
}
// Query the existence of a GlobalVar in the call graph. It creates an entry if
// there is no such node available.
CallGraphEntry* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) {
......@@ -306,7 +313,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraph")
TVM_REGISTER_GLOBAL("relay._analysis.GetModule")
.set_body_typed([](CallGraph call_graph) {
return call_graph->GetModule();
return call_graph->module;
});
TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraphGlobalVar")
......
......@@ -124,10 +124,14 @@ class CallGraphNode : public Object {
return (*this)[module->GetGlobalVar(gvar_name)];
}
  /*!
   * \brief Return the IR module this call graph was built from.
   * \return The module (shared handle, not a copy).
   */
  IRModule GetModule() const {
    return module;
  }
/*!
* \brief Get the global function corresponding to the variable.
*
* \param var The global variable.
*
* \return The found global function.
*/
BaseFunc GetGlobalFunction(const GlobalVar& var) const;
/*!
* \brief Get the entries/root nodes of CallGraphNode.
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/pass/inline.cc
* \brief Global function inliner. It contains the following steps:
*
* - Preprocessing: eligibility checking. Only inline the functions that can
* be inlined. We currently only use simple rules to make the decision. No
 * profitability analysis is available for now.
*
* - Inline: replace the call with a function or the function body depending on
* the attribute of the callee function. For example, we return the function
* node when it doesn't use default compiler, i.e. llvm. This is because these
* functions are packed to be offloaded to external codegen.
*
* - Postprocessing: remove the replaced functions that have no reference.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/support/logging.h>

#include <algorithm>
#include <string>
#include <unordered_set>

#include "call_graph.h"
using namespace tvm::runtime;
namespace tvm {
namespace relay {
/*!
 * \brief Expression mutator that rewrites one caller function, replacing
 * references to inlinable global functions with the function (or its body).
 */
class Inliner : ExprMutator {
 public:
  /*!
   * \brief Construct an inliner for a single caller.
   * \param cur_node The call graph entry of the function being rewritten.
   * \param call_graph The call graph used to look up global functions.
   */
  explicit Inliner(CallGraphEntry* cur_node, CallGraphNode* call_graph)
      : cur_node_(cur_node), call_graph_(call_graph) {}

  Expr VisitExpr_(const CallNode* call_node) final {
    Expr op = call_node->op;
    const auto* gvn = op.as<GlobalVarNode>();

    if (gvn) {
      GlobalVar gv = GetRef<GlobalVar>(gvn);
      auto* cg_node = (*call_graph_)[gv->name_hint];
      if (CanInline(cg_node)) {
        // Rewrite the arguments first so nested inlinable refs are handled.
        tvm::Array<Expr> call_args;
        for (auto arg : call_node->args) {
          auto new_arg = VisitExpr(arg);
          call_args.push_back(new_arg);
        }
        // Drop the caller->callee edge; the callee's ref count drives the
        // cleanup phase in relay::Inline below.
        cur_node_->RemoveCallTo(gv);
        return MakeNewExpr(gv, call_args, GetRef<Call>(call_node));
      }
    }
    return ExprMutator::VisitExpr_(call_node);
  }

  Expr VisitExpr_(const GlobalVarNode* gvn) final {
    // A bare GlobalVar (e.g. used as a call argument or selected by an If)
    // is replaced by the function it refers to when that function can be
    // inlined. No call args are available in this position.
    GlobalVar gv = GetRef<GlobalVar>(gvn);
    auto* cg_node = (*call_graph_)[gv->name_hint];
    if (CanInline(cg_node)) {
      cur_node_->RemoveCallTo(gv);
      return MakeNewExpr(gv, {}, GetRef<GlobalVar>(gvn));
    }
    return ExprMutator::VisitExpr_(gvn);
  }

  /*!
   * \brief Rewrite the body of \p func; params, types, and attrs are kept.
   */
  Function Inline(const Function& func) {
    return FunctionNode::make(func->params,
                              VisitExpr(func->body),
                              func->ret_type,
                              func->type_params,
                              func->attrs);
  }

 private:
  /*!
   * \brief Decide whether the function behind \p cg_node may be inlined.
   * Requires: a leaf, non-recursive node, with a defined body, carrying the
   * "Inline" attribute.
   */
  bool CanInline(const CallGraphEntry* cg_node) {
    // The node must be a leaf node and it cannot be recursive.
    if (!cg_node->empty() || cg_node->IsRecursive()) return false;
    auto base_func = call_graph_->GetGlobalFunction(cg_node->GetGlobalVar());
    auto func = Downcast<Function>(base_func);
    // The body of a global function must be defined.
    if (!func->body.defined()) return false;
    // The function must be annotated with the inline attribute.
    if (!func->IsMarkedInline()) return false;
    // NOTE(review): this loop is unreachable — `cg_node->empty()` is
    // required by the first check, so there are never callees to visit.
    // It only matters if the leaf-only restriction is ever relaxed.
    for (const auto& it : *cg_node) {
      if (!CanInline(it.second)) {
        return false;
      }
    }
    return true;
  }

  // Make a new Relay expression to replace the callee. `callee` is the
  // original reference site (a Call whose op is the GlobalVar, or the bare
  // GlobalVar itself).
  Expr MakeNewExpr(const GlobalVar& global,
                   const Array<Expr>& args,
                   const Expr& callee) {
    CHECK(callee->IsInstance<CallNode>() ||
          callee->IsInstance<GlobalVarNode>());
    auto base_func = call_graph_->GetGlobalFunction(global);
    const auto* fn = base_func.as<FunctionNode>();
    CHECK(fn) << "Expected to work on a Relay function.";

    // Fresh copy so the module's stored function is not aliased.
    auto func = FunctionNode::make(fn->params,
                                   fn->body,
                                   fn->ret_type,
                                   fn->type_params,
                                   fn->attrs);
    // Inline the function body to the caller if this function uses default
    // compiler, i.e. no external codegen is needed.
    if (func->UseDefaultCompiler()) {
      CHECK_EQ(func->params.size(), args.size())
          << "Mismatch found in the number of parameters and call args";
      // Bind the parameters with call args.
      Map<Var, Expr> bind_map;
      for (size_t i = 0; i < args.size(); i++) {
        bind_map.Set(fn->params[i], args[i]);
      }
      if (const auto* gvn = callee.as<GlobalVarNode>()) {
        auto ret_type = gvn->checked_type();
        // Cannot replace TensorType/TensorTupleType with FuncType. Therefore,
        // we simply inline the function as a closure instead of directly using
        // its body when the global var returns FuncType.
        return ret_type->IsInstance<FuncTypeNode>() ? std::move(func)
                                                    : func->body;
      } else {
        CHECK(callee->IsInstance<CallNode>());
        return Bind(func->body, bind_map);
      }
    } else if (const auto* call_node = callee.as<CallNode>()) {
      // External-codegen function: keep it packed; call the closure instead
      // of splicing in the body.
      return CallNode::make(func, args, call_node->attrs, call_node->type_args);
    } else {
      return std::move(func);
    }
  }

  /*!
   * \brief The current call graph entry that is being handled. Each entry
   * contains a global function.
   */
  CallGraphEntry* cur_node_;
  /*! \brief The call graph that is used for global function lookup. */
  const CallGraphNode* call_graph_;
};
/*!
 * \brief Inline the global functions marked with the `Inline` attribute.
 *
 * \param module The module to transform.
 *
 * \return The module with eligible call sites inlined and fully-inlined,
 *         unreferenced global functions removed.
 */
IRModule Inline(const IRModule& module) {
  CallGraph cg(module);
  auto topo = cg->TopologicalOrder();
  // Get the reverse topological order of the global functions.
  std::reverse(topo.begin(), topo.end());
  // Cache the functions that are originally entries. These functions will
  // remain in the module after inlining.
  std::unordered_set<CallGraphEntry*> original_entry;

  for (auto* it : topo) {
    if (it->GetRefCount() == 0) original_entry.emplace(it);
    // Skip the leaf calls and the recursive calls that don't call other
    // functions.
    if (it->empty() || (it->IsRecursive() && it->size() == 1)) continue;
    auto base_func = module->Lookup(it->GetNameHint());
    if (const auto* fn = base_func.as<FunctionNode>()) {
      auto func = GetRef<Function>(fn);
      auto new_func = Inliner(it, cg.operator->()).Inline(func);
      // TODO(zhiics) Maybe move this to CallGraph, but updating function from
      // CallGraph arbitrarily may lead to incorrect CallGraph.
      cg->module->Update(it->GetGlobalVar(), new_func);
    }
  }

  // Clean up the functions that are inlined and have no reference.
  for (auto* cgn : topo) {
    // Skip recursive functions and entry functions even if they are marked as
    // `inline`.
    if (cgn->IsRecursive() || original_entry.count(cgn)) continue;
    auto base_func = cg->GetGlobalFunction(cgn->GetGlobalVar());
    if (const auto* fn = base_func.as<FunctionNode>()) {
      auto func = GetRef<Function>(fn);
      if (func->IsMarkedInline()) {
        // All call sites must have been rewritten by now, otherwise the
        // function cannot be removed from the module.
        CHECK_EQ(cgn->GetRefCount(), 0U)
            << cgn->GetNameHint() << " is marked as inline but not inlined.";
        cgn->CleanCallGraphEntries();
        cg->RemoveGlobalVarFromModule(cgn, /*update_call_graph*/ true);
      }
    }
  }

  return cg->module;
}
namespace transform {

/*!
 * \brief Create the module-level pass that inlines global functions
 * carrying the "Inline" attribute.
 */
Pass Inline() {
  // The lambda captures nothing; all state lives in the module argument.
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> inline_globals =
      [](IRModule mod, PassContext ctx) { return relay::Inline(mod); };
  return CreateModulePass(inline_globals, 1, "InlineGlobals", {});
}

TVM_REGISTER_GLOBAL("relay._transform.Inline")
.set_body_typed(Inline);

}  // namespace transform
} // namespace relay
} // namespace tvm
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, missing-docstring, too-many-statements
import tvm
from tvm import relay
def get_recursive_count_loop():
    """Build a module containing a self-recursive `sum_up` function that is
    marked `Inline`, plus a `main` calling it.

    Returns
    -------
    (mod, sum_up): the module and the GlobalVar of the recursive function.
    Used as a fixture: the inline pass is expected to leave recursive
    functions alone even when they carry the `Inline` attribute.
    """
    mod = tvm.IRModule({})
    sum_up = relay.GlobalVar('sum_up')
    i = relay.var('i', shape=[], dtype='int32')
    sb = relay.ScopeBuilder()
    with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
        sb.ret(i)
    with sb.else_scope():
        one_less = relay.subtract(i, relay.const(1, dtype='int32'))
        rec_call = relay.Call(sum_up, [one_less])
        sb.ret(relay.add(rec_call, i))
    func = relay.Function([i],
                          sb.get(),
                          ret_type=relay.TensorType([], 'int32'))
    # Marked inline although recursive — the pass must skip it.
    func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
    mod[sum_up] = func
    iarg = relay.var('i', shape=[], dtype='int32')
    mod["main"] = relay.Function([iarg], sum_up(iarg))
    return mod, sum_up
def test_call_chain_inline_leaf():
    """Test when only the leaf call is inlined.

    The call graph is like the following:
          main
         /    \\
        g1     g2
        |
       g11(inline)
    """
    def get_mod():
        mod = tvm.IRModule({})
        x11 = relay.var("x11", shape=(3, 5))
        g11 = relay.GlobalVar("g11")
        fn11 = relay.Function([x11], x11)
        fn11 = fn11.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        mod[g11] = fn11

        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1 + g11(x1))
        fn1 = relay.Function([x1, y1], sb.get())
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    def expected():
        # g11's body (identity) is substituted into g1; g11 is removed.
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1 + x1)
        fn1 = relay.Function([x1, y1], sb.get())
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, expected())
def test_call_chain_inline_multiple_levels():
    """Test when calls at multiple levels of the chain are inlined.

    The call graph is like the following:
            main
           /    \\
     g1(inline)  g2
        |
     g11(inline)
    """
    def get_mod():
        mod = tvm.IRModule({})
        x11 = relay.var("x11", shape=(3, 5))
        g11 = relay.GlobalVar("g11")
        fn11 = relay.Function([x11], x11)
        fn11 = fn11.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        mod[g11] = fn11

        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1 + g11(x1))
        fn1 = relay.Function([x1, y1], sb.get())
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    def expected():
        # Both g1 and g11 are inlined into main and removed; only g2 stays.
        mod = tvm.IRModule({})
        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = p0 + p1 + p0
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, expected())
def test_call_chain_inline_multiple_levels_extern_compiler():
    """Test multi-level inlining where the leaf uses an external compiler.

    The call graph is like the following:
            main
           /    \\
     g1(inline)  g2
        |
     g11(inline, external compiler)
    """
    def get_mod():
        mod = tvm.IRModule({})
        x11 = relay.var("x11", shape=(3, 5))
        g11 = relay.GlobalVar("g11")
        fn11 = relay.Function([x11], x11)
        fn11 = fn11.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        # External codegen: the pass must inline the closure, not the body.
        fn11 = fn11.set_attribute("Compiler", tvm.tir.StringImm("a"))
        mod[g11] = fn11

        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1 + g11(x1))
        fn1 = relay.Function([x1, y1], sb.get())
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    def expected():
        # g11 appears in main as a local function call (fn11(p0)), since
        # external-codegen functions stay packed.
        mod = tvm.IRModule({})
        x11 = relay.var("x11", shape=(3, 5))
        fn11 = relay.Function([x11], x11)
        fn11 = fn11.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn11 = fn11.set_attribute("Compiler", tvm.tir.StringImm("a"))

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = p0 + p1 + fn11(p0)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, expected())
def test_recursive_call_with_global():
    """A recursive function calls an inlinable global: the leaf `gx` is
    inlined into the recursive `sum_up`, but `sum_up` itself (recursive)
    is kept in the module."""
    def get_mod():
        mod = tvm.IRModule({})
        x = relay.var('x', shape=[], dtype='int32')
        fn0 = relay.Function([x], x)
        fn0 = fn0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        gx = relay.GlobalVar("gx")
        mod[gx] = fn0

        sum_up = relay.GlobalVar('sum_up')
        i = relay.var('i', shape=[], dtype='int32')
        sb = relay.ScopeBuilder()
        with sb.if_scope(relay.equal(i, relay.const(0, dtype="int32"))):
            sb.ret(i)
        with sb.else_scope():
            one_less = relay.subtract(i, relay.const(1, dtype="int32"))
            global_call = gx(i)
            rec_call = relay.Call(sum_up, [one_less]) + global_call
            sb.ret(relay.add(rec_call, i))
        func = relay.Function([i],
                              sb.get(),
                              ret_type=relay.TensorType([], "int32"))
        func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        mod[sum_up] = func
        iarg = relay.var("i", shape=[], dtype='int32')
        mod["main"] = relay.Function([iarg], sum_up(iarg))
        return mod

    def expected():
        # gx(i) collapsed to `i`; sum_up survives because it is recursive.
        mod = tvm.IRModule({})
        sum_up = relay.GlobalVar('sum_up')
        i = relay.var('i', shape=[], dtype='int32')
        sb = relay.ScopeBuilder()
        with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
            sb.ret(i)
        with sb.else_scope():
            one_less = relay.subtract(i, relay.const(1, dtype='int32'))
            rec_call = relay.Call(sum_up, [one_less]) + i
            sb.ret(relay.add(rec_call, i))
        func = relay.Function([i],
                              sb.get(),
                              ret_type=relay.TensorType([], 'int32'))
        func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        mod[sum_up] = func
        iarg = relay.var('i', shape=[], dtype='int32')
        mod["main"] = relay.Function([iarg], sum_up(iarg))
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, expected())
def test_recursive_called():
    """A recursive function reached from main is left untouched by the
    inline pass even though it is marked `Inline`."""
    mod, sum_up = get_recursive_count_loop()
    iarg = relay.var('i', shape=[], dtype='int32')
    mod["main"] = relay.Function([iarg], sum_up(iarg))
    # The pass should be a no-op on this module.
    ref_mod = mod
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, ref_mod)
def test_recursive_not_called():
    """The recursive `sum_up` is kept (recursion blocks inlining), while
    the non-recursive inline function `g1` is inlined into main and
    removed."""
    def get_mod():
        mod, sum_up = get_recursive_count_loop()
        x = relay.var("x", shape=(2, 2))
        y = relay.var("y", shape=(2, 2))
        x1 = relay.var("x1", shape=(2, 2))
        fn1 = relay.Function([x1], x1)
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1
        # Overwrites the main created by the fixture; sum_up is unreferenced.
        mod["main"] = relay.Function([x, y], x + y + g1(x))
        return mod

    def expected():
        mod, sum_up = get_recursive_count_loop()
        x = relay.var("x", shape=(2, 2))
        y = relay.var("y", shape=(2, 2))
        mod["main"] = relay.Function([x, y], x + y + x)
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    ref_mod = expected()
    assert relay.analysis.alpha_equal(mod, ref_mod)
def test_recursive_not_called_extern_compiler():
    """Same as test_recursive_not_called, but `g1` targets an external
    compiler, so it is inlined as a closure (fn1(x)) instead of its body."""
    def get_mod():
        mod, sum_up = get_recursive_count_loop()
        x = relay.var("x", shape=(2, 2))
        y = relay.var("y", shape=(2, 2))
        x1 = relay.var("x1", shape=(2, 2))
        fn1 = relay.Function([x1], x1)
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn1 = fn1.set_attribute("Compiler", tvm.tir.StringImm("a"))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1
        mod["main"] = relay.Function([x, y], x + y + g1(x))
        return mod

    def expected():
        mod, sum_up = get_recursive_count_loop()
        x = relay.var("x", shape=(2, 2))
        y = relay.var("y", shape=(2, 2))
        x1 = relay.var("x1", shape=(2, 2))
        fn1 = relay.Function([x1], x1)
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn1 = fn1.set_attribute("Compiler", tvm.tir.StringImm("a"))
        # fn1 stays packed as a local closure in main.
        mod["main"] = relay.Function([x, y], x + y + fn1(x))
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    ref_mod = expected()
    assert relay.analysis.alpha_equal(mod, ref_mod)
def test_globalvar_as_call_arg():
    """Both leaf globals g1 and g2 are marked `Inline`; after the pass their
    bodies replace the calls in main and both globals are removed."""
    def get_mod():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1)
        fn1 = relay.Function([x1, y1], sb.get())
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    def expected():
        # Bug fix: build a fresh module here. Previously `mod` resolved via
        # closure to the already-inlined module from the enclosing scope, so
        # the assertion compared the pass output against itself and the test
        # was vacuous.
        mod = tvm.IRModule({})
        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = p0 + p1
        call_fn2 = p2 - p3
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, expected())
def test_globalvar_as_call_arg_extern_compiler():
    """Both inline globals use external compilers, so the pass replaces the
    global calls with calls to local closures (relay.Call(fn, args)) and
    removes the globals from the module."""
    def get_mod():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1)
        fn1 = relay.Function([x1, y1], sb.get())
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn1 = fn1.set_attribute("Compiler", tvm.tir.StringImm("a"))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn2 = fn2.set_attribute("Compiler", tvm.tir.StringImm("b"))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    def expected():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1)
        fn1 = relay.Function([x1, y1], sb.get())
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn1 = fn1.set_attribute("Compiler", tvm.tir.StringImm("a"))

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn2 = fn2.set_attribute("Compiler", tvm.tir.StringImm("b"))

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        # The packed functions are called directly as closures.
        call_fn1 = relay.Call(fn1, [p0, p1])
        call_fn2 = relay.Call(fn2, [p2, p3])
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, expected())
def test_inline_globalvar_without_args():
    """GlobalVars that appear without a surrounding Call (selected by an If)
    are replaced with the functions themselves."""
    def get_mod():
        mod = tvm.IRModule({})
        fn1 = relay.Function([], relay.const(1))
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn2 = relay.Function([], relay.const(2))
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g1 = relay.GlobalVar('g1')
        g2 = relay.GlobalVar('g2')
        mod[g1] = fn1
        mod[g2] = fn2
        p = relay.var('p', 'bool')
        # g1/g2 are branch results of an If, not call operators.
        mod['main'] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
        return mod

    def expected():
        mod = tvm.IRModule({})
        fn1 = relay.Function([], relay.const(1))
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn2 = relay.Function([], relay.const(2))
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        p = relay.var('p', 'bool')
        mod['main'] = relay.Function([p], relay.Call(
            relay.If(p, fn1, fn2), []))
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, expected())
def test_inline_globalvar_without_args_extern_compiler():
    """Same as test_inline_globalvar_without_args, with external-compiler
    functions: the bare GlobalVars are still replaced by the (packed)
    functions themselves."""
    def get_mod():
        mod = tvm.IRModule({})
        fn1 = relay.Function([], relay.const(1))
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn1 = fn1.set_attribute("Compiler", tvm.tir.StringImm("a"))
        fn2 = relay.Function([], relay.const(2))
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn2 = fn2.set_attribute("Compiler", tvm.tir.StringImm("b"))
        g1 = relay.GlobalVar('g1')
        g2 = relay.GlobalVar('g2')
        mod[g1] = fn1
        mod[g2] = fn2
        p = relay.var('p', 'bool')
        mod['main'] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
        return mod

    def expected():
        mod = tvm.IRModule({})
        fn1 = relay.Function([], relay.const(1))
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn1 = fn1.set_attribute("Compiler", tvm.tir.StringImm("a"))
        fn2 = relay.Function([], relay.const(2))
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn2 = fn2.set_attribute("Compiler", tvm.tir.StringImm("b"))
        p = relay.var('p', 'bool')
        mod['main'] = relay.Function([p], relay.Call(
            relay.If(p, fn1, fn2), []))
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, expected())
def test_globalvar_called_by_multiple_functions():
    """An inline leaf called from several callers is inlined into each.

    The call graph is like the following:
      main   g0
        \\   /
      g1 g2(inline)
    """
    def get_mod():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1)
        fn1 = relay.Function([x1, y1], sb.get())
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        sb1 = relay.ScopeBuilder()
        sb1.ret(x2 - y2)
        fn2 = relay.Function([x2, y2], sb1.get())
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2

        # g0 is a second caller of g2 and an entry itself (nothing calls g0).
        x0 = relay.var("x0", shape=(3, 5))
        y0 = relay.var("y0", shape=(3, 5))
        z0 = relay.var("z0", shape=(3, 5))
        fn0 = relay.Function([x0, y0, z0], g2(x0, y0) + z0)
        g0 = relay.GlobalVar("g0")
        mod[g0] = fn0

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn1 = g1(p0, p1)
        call_fn2 = g2(p2, p3)
        mod["main"] = relay.Function([p0, p1, p2, p3], call_fn1 * call_fn2)
        return mod

    def expected():
        # g2 is inlined into both main and g0, then removed.
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1)
        fn1 = relay.Function([x1, y1], sb.get())
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))
        call_fn2 = p2 - p3
        mod["main"] = relay.Function([p0, p1, p2, p3], g1(p0, p1) * call_fn2)

        x0 = relay.var("x0", shape=(3, 5))
        y0 = relay.var("y0", shape=(3, 5))
        z0 = relay.var("z0", shape=(3, 5))
        fn0 = relay.Function([x0, y0, z0], x0 - y0 + z0)
        g0 = relay.GlobalVar("g0")
        mod[g0] = fn0
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, expected())
def test_entry_with_inline():
    """Entry functions are preserved even when marked `Inline`.

    The call graph is like the following (no callers at all):
      g1(inline)  g2(inline)
    """
    def get_mod():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        fn1 = relay.Function([x1, y1], x1 + y1)
        fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        fn2 = relay.Function([x2, y2], x2 - y2)
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2
        return mod

    mod = get_mod()
    # The pass must be a no-op: both functions are module entries.
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, get_mod())
def test_callee_not_inline():
    """An inline function whose callee is NOT inlinable stays untouched.

    The call graph is like the following:
        main
         |
      g2(inline)
         |
        g1
    """
    def get_mod():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        fn1 = relay.Function([x1, y1], x1 + y1)
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        # g2 is not a leaf (it calls g1), so it cannot be inlined.
        fn2 = relay.Function([x2, y2], x2 - g1(x2, y2))
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, get_mod())
def test_callee_not_inline_leaf_inline():
    """Only the inlinable leaf (g0) is inlined; g2 stays because its callee
    g1 is not inlinable.

    The call graph is like the following:
        main
         |
      g2(inline)
         |
        g1
         |
      g0(inline)
    """
    def get_mod():
        mod = tvm.IRModule({})
        x0 = relay.var("x0", shape=(3, 5))
        y0 = relay.var("y0", shape=(3, 5))
        fn0 = relay.Function([x0, y0], x0 * y0)
        fn0 = fn0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g0 = relay.GlobalVar("g0")
        mod[g0] = fn0

        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        fn1 = relay.Function([x1, y1], x1 + g0(x1, y1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        fn2 = relay.Function([x2, y2], x2 - g1(x2, y2))
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2
        return mod

    def expected():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        # g0's body replaced the call inside g1; g0 itself was removed.
        fn1 = relay.Function([x1, y1], x1 + x1 * y1)
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        fn2 = relay.Function([x2, y2], x2 - g1(x2, y2))
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, expected())
def test_callee_not_inline_leaf_inline_extern_compiler():
    """Same as test_callee_not_inline_leaf_inline, but g0 targets an
    external compiler, so it is inlined as a closure rather than its body.

    The call graph is like the following:
        main
         |
      g2(inline)
         |
        g1
         |
      g0(inline, external compiler)
    """
    def get_mod():
        mod = tvm.IRModule({})
        x0 = relay.var("x0", shape=(3, 5))
        y0 = relay.var("y0", shape=(3, 5))
        fn0 = relay.Function([x0, y0], x0 * y0)
        fn0 = fn0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn0 = fn0.set_attribute("Compiler", tvm.tir.StringImm("aa"))
        g0 = relay.GlobalVar("g0")
        mod[g0] = fn0

        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        fn1 = relay.Function([x1, y1], x1 + g0(x1, y1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        fn2 = relay.Function([x2, y2], x2 - g1(x2, y2))
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2
        return mod

    def expected():
        mod = tvm.IRModule({})
        x0 = relay.var("x0", shape=(3, 5))
        y0 = relay.var("y0", shape=(3, 5))
        fn0 = relay.Function([x0, y0], x0 * y0)
        fn0 = fn0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        fn0 = fn0.set_attribute("Compiler", tvm.tir.StringImm("aa"))

        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        # The packed fn0 is called directly inside g1.
        fn1 = relay.Function([x1, y1], x1 + fn0(x1, y1))
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        x2 = relay.var("x2", shape=(3, 5))
        y2 = relay.var("y2", shape=(3, 5))
        fn2 = relay.Function([x2, y2], x2 - g1(x2, y2))
        fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        g2 = relay.GlobalVar("g2")
        mod[g2] = fn2
        return mod

    mod = get_mod()
    mod = relay.transform.Inline()(mod)
    assert relay.analysis.alpha_equal(mod, expected())
if __name__ == '__main__':
    # Bug fix: `pytest` was used here without ever being imported (the file
    # only imports tvm/relay), so running this file as a script raised
    # NameError. Import locally, and restrict collection to this file.
    import pytest
    pytest.main([__file__])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment