lambda_lift.cc 7.08 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/backend/vm/lambda_lift.cc
 * \brief Lift all local functions into global functions.
 */

#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
27
#include <tvm/support/logging.h>
Zhi committed
28
#include <tvm/relay/analysis.h>
29
#include <tvm/relay/transform.h>
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>

using namespace tvm::runtime;

namespace tvm {
namespace relay {
namespace vm {

/*!
 * \brief Derive a deterministic global name for a function that is about to
 * be lifted, based on its structural hash.
 * \param func The function being lifted.
 * \return The string "lifted_name" followed by the decimal hash value.
 */
inline std::string GenerateName(const Function& func) {
  const size_t func_hash = StructuralHash()(func);
  std::string lifted = "lifted_name";
  lifted += std::to_string(func_hash);
  return lifted;
}

bool IsClosure(const Function& func) {
46
  ObjectRef res = FunctionGetAttr(func, attr::kClosure);
47
  const tir::IntImmNode* pval = res.as<tir::IntImmNode>();
48 49 50 51
  return pval && pval->value != 0;
}

/*!
 * \brief Tag a function as a closure allocator.
 * \param func The function to tag.
 * \return The result of FunctionSetAttr with attr::kClosure set to the integer 1.
 */
Function MarkClosure(const Function& func) {
  const tvm::Integer closure_flag(1);
  return FunctionSetAttr(func, attr::kClosure, closure_flag);
}

55 56 57 58 59 60
/* The goal of this class is to lift out any nested functions into top-level
 * functions.
 *
 * We will lift a function out into a global which takes the set of the free
 * vars and then return the new created function.
 */
61 62
class LambdaLifter : public ExprMutator {
 public:
63
  explicit LambdaLifter(const IRModule& module) : module_(module) {}
64

65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
  Expr VisitExpr_(const LetNode* let_node) final {
    bool is_lambda = false;
    if (auto func = let_node->value.as<FunctionNode>()) {
      if (!func->IsPrimitive()) {
        is_lambda = true;
        letrec_.push_back(let_node->var);
      }
    }
    auto value = VisitExpr(let_node->value);
    if (is_lambda) {
      letrec_.pop_back();
    }
    auto body = VisitExpr(let_node->body);
    return LetNode::make(let_node->var, value, body);
  }

  Expr VisitExpr_(const CallNode* call_node) final {
    auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
    if (auto var_node = call_node->op.as<VarNode>()) {
      auto var = GetRef<Var>(var_node);
      if (!letrec_.empty() && var == letrec_.back()) {
        auto it = lambda_map_.find(var);
        CHECK(it != lambda_map_.end());
        return CallNode::make(it->second, call->args, call_node->attrs,
                              call_node->type_args);
      }
    }
    return std::move(call);
  }

95 96 97 98 99 100 101 102
  Expr VisitExpr_(const FunctionNode* func_node) final {
    auto func = GetRef<Function>(func_node);

    // We should not transform primitive functions.
    if (func->IsPrimitive()) {
      return std::move(func);
    }

103
    auto name = GenerateName(func);
104
    auto global = GlobalVar(name);
105 106
    auto free_vars = FreeVars(func);
    auto free_type_vars = FreeTypeVars(func, module_);
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127

    Array<Var> captured_vars;
    bool recursive = false;
    for (const auto& var : free_vars) {
      if (!letrec_.empty() && var == letrec_.back()) {
        recursive = true;
        continue;
      }
      captured_vars.push_back(var);
    }
    if (recursive) {
      if (!captured_vars.empty()) {
        Array<Expr> fvs;
        for (auto fv : captured_vars) {
          fvs.push_back(fv);
        }
        lambda_map_.emplace(letrec_.back(), CallNode::make(global, fvs));
      } else {
        lambda_map_.emplace(letrec_.back(), global);
      }
    }
128 129
    auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));

130
    // When performing this optimization there are two cases.
131 132 133 134 135 136 137
    //
    // The first case in which we have no free variables
    // we can just lift the function into the global
    // environment without needing to allocate a closure.
    //
    //
    // The second case requires that we generate a special
138
    // function which makes a distinction between allocating
139 140 141 142 143 144 145 146 147 148 149
    // a closure, and then the code for the closure.
    //
    // We represent a closure allocation by lifting the
    // closure to a global function which takes its
    // captured arguments and then directly returns
    // the function representing the closure's code.
    //
    // When we generate code later on a call to the "outer"
    // function marked as a closure is used to emit allocation
    // code for the closure's environment.
    //
150
    // The "inner" function should be used to generate the
151 152
    // code for the closure.
    Function lifted_func;
153
    if (captured_vars.size() == 0 && free_type_vars.size() == 0) {
154
      lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
155 156
    } else {
      lifted_func =
157
          FunctionNode::make(captured_vars, body, func->func_type_annotation(), free_type_vars);
158 159 160 161 162 163
      lifted_func = MarkClosure(lifted_func);
    }

    CHECK(lifted_func.defined());


164 165 166 167 168 169 170 171 172
    if (module_->ContainGlobalVar(name)) {
      const auto existing_func = module_->Lookup(name);
      CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision";
      // If an identical function already exists, use its global var.
      global = module_->GetGlobalVar(name);
    } else {
      // Add the lifted function to the module.
      module_->Add(global, lifted_func);
    }
173

174
    if (captured_vars.size() == 0) {
175 176
      return std::move(global);
    } else {
177 178
      // If we need to allocate a closure,
      // we pass the variables in its environment here.
179
      Array<Expr> fvs;
180
      for (auto fv : captured_vars) {
181 182 183 184 185 186
        fvs.push_back(fv);
      }
      return CallNode::make(global, fvs);
    }
  }

187
  IRModule Lift() {
188 189 190
    // There is an ordering bug here.
    auto glob_funcs = module_->functions;
    for (auto pair : glob_funcs) {
191 192 193 194 195 196 197 198 199
      if (auto* n = pair.second.as<FunctionNode>()) {
        auto func = GetRef<Function>(n);
        func = FunctionNode::make(func->params,
                                  VisitExpr(func->body),
                                  func->ret_type,
                                  func->type_params,
                                  func->attrs);
        module_->Add(pair.first, func, true);
      }
200 201
    }
    return module_;
202
  }
203 204

 private:
205
  std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> lambda_map_;
206
  std::vector<Var> letrec_;
207
  IRModule module_;
208 209
};

210
}  // namespace vm
211

212
namespace transform {
213

214
Pass LambdaLift() {
215 216
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
    [=](IRModule m, PassContext pc) {
217 218 219 220
    return relay::vm::LambdaLifter(m).Lift();
  };
  return CreateModulePass(pass_func, 1, "LambdaLift", {});
}
221

222
TVM_REGISTER_GLOBAL("relay._transform.LambdaLift")
223
.set_body_typed(LambdaLift);
224

225
}  // namespace transform
226 227 228

}  // namespace relay
}  // namespace tvm