/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file tvm/relay/backend/vm/inline_primitives.cc
 * \brief Ensure that primitives only appear in the call position.
 */

#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>

using namespace tvm::runtime;

namespace tvm {
namespace relay {
namespace vm {

// TODO(@jroesch): write verifier

/* This pass eliminates primitive functions that the ANF transform has lifted
 * into let bindings, by inlining them directly into their call sites.
 *
 * This makes VM-related code generation easier, as a call to a primitive
 * then always has the primitive function itself as its call target.
 *
 * let prim = fn(...) { ... };
 * prim(...)
 *
 * will become:
 *
 * (fn(...) { ... })(...)
 */
struct PrimitiveInliner : ExprMutator {
  Module module_;
  std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map;

  explicit PrimitiveInliner(const Module& module) : module_(module) {}

  Expr VisitExpr_(const LetNode* let_node) {
    var_map.insert({let_node->var, VisitExpr(let_node->value)});
    return ExprMutator::VisitExpr_(let_node);
  }

  Expr VisitExpr_(const CallNode* call) {
    Expr op = call->op;
    // For now just collapse the chain of variables to see if
    // they point to a primitive function.
    const VarNode* var_node;

    // Collapse a chain of let bindings
    //
    // let x = fn (..) { .. };
    // let y = x
    // let w = y
    // in w(...)
    while ((var_node = op.as<VarNode>())) {
      auto var = GetRef<Var>(var_node);
      DLOG(INFO) << "Var: " << var << std::endl;
      auto it = var_map.find(GetRef<Var>(var_node));
      if (it != var_map.end()) {
        op = it->second;
      } else {
        return ExprMutator::VisitExpr_(call);
      }
    }

    if (auto func = op.as<FunctionNode>()) {
      if (func->IsPrimitive()) {
        return CallNode::make(GetRef<Function>(func), call->args, call->attrs, call->type_args);
      }
    }

    if (auto global = op.as<GlobalVarNode>()) {
      return CallNode::make(GetRef<GlobalVar>(global), call->args, call->attrs, call->type_args);
    }

    return ExprMutator::VisitExpr_(call);
  }

  Expr VisitExpr_(const FunctionNode* func) {
    if (func->IsPrimitive()) {
      return GetRef<Function>(func);
    } else {
      return ExprMutator::VisitExpr_(func);
    }
  }

110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
  Module Inline() {
    auto gvar_funcs = module_->functions;
    for (auto pair : gvar_funcs) {
      auto global = pair.first;
      auto func = pair.second;
      DLOG(INFO) << "Before inlining primitives: " << global
                 << std::endl << AsText(func, false);

      func = FunctionNode::make(func->params,
                                VisitExpr(func->body),
                                func->ret_type,
                                func->type_params,
                                func->attrs);
      module_->Add(global, func, true);

      DLOG(INFO) << "After inlining primitives: " << global
                 << std::endl << AsText(func, false);
    }
    return module_;
129 130 131
  }
};

}  // namespace vm

namespace transform {

/*!
 * \brief Build the InlinePrimitives pass: a module pass running the
 *        primitive inliner over every function, followed by dead code
 *        elimination.
 * \return A sequential pass named "InlinePrimitives".
 */
Pass InlinePrimitives() {
  runtime::TypedPackedFunc<Module(Module, PassContext)> inline_fn =
      [](Module mod, PassContext ctx) {
        relay::vm::PrimitiveInliner inliner(mod);
        return inliner.Inline();
      };
  Pass inliner_pass = CreateModulePass(inline_fn, 1, "Inline", {});
  // Run dead code elimination after inlining (e.g. to drop let bindings the
  // inliner left unused).
  return Sequential({inliner_pass, DeadCodeElimination()}, "InlinePrimitives");
}

// Register InlinePrimitives as the global packed function
// "relay._transform.InlinePrimitives" (presumably consumed by the Python
// bindings -- verify against the frontend).
TVM_REGISTER_API("relay._transform.InlinePrimitives")
.set_body_typed(InlinePrimitives);

}  // namespace transform

}  // namespace relay
}  // namespace tvm