inline.cc 8.12 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
21
 * \file src/relay/transforms/inline.cc
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
 * \brief Global function inliner. It contains the following steps:
 *
 *  - Preprocessing: eligibility checking. Only inline the functions that can
 *  be inlined. We currently only use simple rules to make the decision. No
 *  profitibility analysis is available for now.
 *
 *  - Inline: replace the call with a function or the function body depending on
 *  the attribute of the callee function. For example, we return the function
 *  node when it doesn't use default compiler, i.e. llvm. This is because these
 *  functions are packed to be offloaded to external codegen.
 *
 *  - Postprocessing: remove the replaced functions that have no reference.
 */

#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/support/logging.h>
#include <tvm/relay/transform.h>
#include <string>
#include <unordered_set>

43
#include "../analysis/call_graph.h"
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85

using namespace tvm::runtime;

namespace tvm {
namespace relay {

class Inliner : ExprMutator {
 public:
  explicit Inliner(CallGraphEntry* cur_node, CallGraphNode* call_graph)
      : cur_node_(cur_node), call_graph_(call_graph) {}

  Expr VisitExpr_(const CallNode* call_node) final {
    Expr op = call_node->op;
    const auto* gvn = op.as<GlobalVarNode>();

    if (gvn) {
      GlobalVar gv = GetRef<GlobalVar>(gvn);
      auto* cg_node = (*call_graph_)[gv->name_hint];
      if (CanInline(cg_node)) {
        tvm::Array<Expr> call_args;
        for (auto arg : call_node->args) {
          auto new_arg = VisitExpr(arg);
          call_args.push_back(new_arg);
        }
        cur_node_->RemoveCallTo(gv);
        return MakeNewExpr(gv, call_args, GetRef<Call>(call_node));
      }
    }
    return ExprMutator::VisitExpr_(call_node);
  }

  Expr VisitExpr_(const GlobalVarNode* gvn) final {
    GlobalVar gv = GetRef<GlobalVar>(gvn);
    auto* cg_node = (*call_graph_)[gv->name_hint];
    if (CanInline(cg_node)) {
      cur_node_->RemoveCallTo(gv);
      return MakeNewExpr(gv, {}, GetRef<GlobalVar>(gvn));
    }
    return ExprMutator::VisitExpr_(gvn);
  }

  Function Inline(const Function& func) {
86
    return Function(func->params,
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
                              VisitExpr(func->body),
                              func->ret_type,
                              func->type_params,
                              func->attrs);
  }

 private:
  bool CanInline(const CallGraphEntry* cg_node) {
    // The node must be a leaf node and it cannot be recursive.
    if (!cg_node->empty() || cg_node->IsRecursive()) return false;

    auto base_func = call_graph_->GetGlobalFunction(cg_node->GetGlobalVar());
    auto func = Downcast<Function>(base_func);
    // The body of a global functions must be defined.
    if (!func->body.defined()) return false;

    // The function must be annotated with the inline attribute.
104
    if (!func->HasNonzeroAttr(attr::kInline)) return false;
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126

    // The function is not abled to be inlined if any callee under the CallGraph
    // of this function cannot be inlined.
    for (const auto& it : *cg_node) {
      if (!CanInline(it.second)) {
        return false;
      }
    }

    return true;
  }

  // Make a new Relay expression to replace the callee.
  Expr MakeNewExpr(const GlobalVar& global,
                   const Array<Expr>& args,
                   const Expr& callee) {
    CHECK(callee->IsInstance<CallNode>() ||
          callee->IsInstance<GlobalVarNode>());
    auto base_func = call_graph_->GetGlobalFunction(global);
    const auto* fn = base_func.as<FunctionNode>();
    CHECK(fn) << "Expected to work on a Relay function.";

127
    auto func = Function(fn->params,
128 129 130 131
                         fn->body,
                         fn->ret_type,
                         fn->type_params,
                         fn->attrs);
132 133
    // Inline the function body to the caller if this function uses default
    // compiler, i.e. no external codegen is needed.
134
    if (!func->GetAttr<String>(attr::kCompiler).defined()) {
135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
      CHECK_EQ(func->params.size(), args.size())
          << "Mismatch found in the number of parameters and call args";
      // Bind the parameters with call args.
      Map<Var, Expr> bind_map;
      for (size_t i = 0; i < args.size(); i++) {
        bind_map.Set(fn->params[i], args[i]);
      }
      if (const auto* gvn = callee.as<GlobalVarNode>()) {
        auto ret_type = gvn->checked_type();
        // Cannot replace TensorType/TensorTupleType with FuncType. Therefore,
        // we simply inline the function as a closure instead of directly using
        // its body when the global var returns FuncType.
        return ret_type->IsInstance<FuncTypeNode>() ? std::move(func)
                                                    : func->body;
      } else {
        CHECK(callee->IsInstance<CallNode>());
        return Bind(func->body, bind_map);
      }
    } else if (const auto* call_node = callee.as<CallNode>()) {
154
        return Call(func, args, call_node->attrs, call_node->type_args);
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
    } else {
      return std::move(func);
    }
  }

  /*!
   * \brief The current call graph entry that is being handled. Each entry
   * contains a global function.
   */
  CallGraphEntry* cur_node_;
  /*! \brief The call graph that is used for global function lookup. */
  const CallGraphNode* call_graph_;
};

IRModule Inline(const IRModule& module) {
  CallGraph cg(module);
  auto topo = cg->TopologicalOrder();
  // Get the reverse topological order of the global functions.
  std::reverse(topo.begin(), topo.end());
  // Cache the functions that are originally entries. These functions will
  // remain in the module after inlining.
  std::unordered_set<CallGraphEntry*> original_entry;

  for (auto* it : topo) {
    if (it->GetRefCount() == 0) original_entry.emplace(it);
    // Skip the leaf calls and the recursive calls that don't call other
    // functions.
    if (it->empty() || (it->IsRecursive() && it->size() == 1)) continue;
    auto base_func = module->Lookup(it->GetNameHint());
    if (const auto* fn = base_func.as<FunctionNode>()) {
      auto func = GetRef<Function>(fn);
      auto new_func = Inliner(it, cg.operator->()).Inline(func);
      // TODO(zhiics) Maybe move this to CallGraph, but updating function from
      // CallGraph arbitarily may lead to incorrect CallGraph.
      cg->module->Update(it->GetGlobalVar(), new_func);
    }
  }

  // Clean up the functions that are inlined and have no reference.
  for (auto* cgn : topo) {
    // Skip recursive functions and entry functions even if they are marked as
    // `inline`.
    if (cgn->IsRecursive() || original_entry.count(cgn)) continue;
    auto base_func = cg->GetGlobalFunction(cgn->GetGlobalVar());
    if (const auto* fn = base_func.as<FunctionNode>()) {
      auto func = GetRef<Function>(fn);
201
      if (func->HasNonzeroAttr(attr::kInline)) {
202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229
        CHECK_EQ(cgn->GetRefCount(), 0U)
            << cgn->GetNameHint() << " is marked as inline but not inlined.";
        cgn->CleanCallGraphEntries();
        cg->RemoveGlobalVarFromModule(cgn, /*update_call_graph*/ true);
      }
    }
  }

  return cg->module;
}

namespace transform {

Pass Inline() {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
    [=](IRModule m, PassContext pc) {
      return relay::Inline(m);
  };
  return CreateModulePass(pass_func, 1, "InlineGlobals", {});
}

TVM_REGISTER_GLOBAL("relay._transform.Inline")
.set_body_typed(Inline);

}  // namespace transform

}  // namespace relay
}  // namespace tvm