llvm_module.cc 11.6 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27
/*!
 *  Copyright (c) 2017 by Contributors
 * \file llvm_module.cc
 * \brief LLVM runtime module for TVM
 */
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/packed_func.h>
#include <tvm/codegen.h>
28
#include <mutex>
29 30
#include "llvm_common.h"
#include "codegen_llvm.h"
31
#include "../../runtime/file_util.h"
32
#include "../../runtime/module_util.h"
33 34 35 36 37 38 39 40

namespace tvm {
namespace codegen {

using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;

41
class LLVMModuleNode final : public runtime::ModuleNode {
42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
 public:
  ~LLVMModuleNode() {
    module_.reset();
    if (ee_ != nullptr) {
      ee_->runStaticConstructorsDestructors(true);
      delete ee_;
    }
  }

  const char* type_key() const {
    return "llvm";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
Tianqi Chen committed
58 59 60 61 62 63 64
    if (name == "__tvm_is_system_module") {
      bool flag =
          (mptr_->getFunction("__tvm_module_startup") != nullptr);
      return PackedFunc([flag](TVMArgs args, TVMRetValue *rv) {
          * rv = flag;
        });
    }
65 66
    if (ee_ == nullptr) LazyInitJIT();
    std::lock_guard<std::mutex> lock(mutex_);
67 68
    const std::string& fname = (name == runtime::symbol::tvm_module_main ?
                                entry_func_ : name);
69

70
    BackendPackedCFunc faddr =
71
        reinterpret_cast<BackendPackedCFunc>(GetFunctionAddr(fname));
72
    if (faddr == nullptr) return PackedFunc();
nhynes committed
73
    return WrapPackedFunc(faddr, sptr_to_self);
74 75 76 77 78 79 80 81 82 83
  }

  void SaveToFile(const std::string& file_name,
                  const std::string& format) final {
    std::string fmt = runtime::GetFileFormat(file_name, format);
    std::error_code ecode;
    llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None);
    CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name
                               << " " << ecode.message();
    if (fmt == "o" || fmt == "obj") {
HungMingWu committed
84
#if TVM_LLVM_VERSION <= 60
85
      std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
HungMingWu committed
86 87 88
#else
      std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
#endif
89 90
      llvm::legacy::PassManager pass;
      CHECK(tm_);
91
#if TVM_LLVM_VERSION <= 60
92 93 94
      CHECK(tm_->addPassesToEmitFile(
          pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0)
          << "Cannot emit target CGFT_ObjectFile";
95 96 97 98 99
#else
      CHECK(tm_->addPassesToEmitFile(
          pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0)
          << "Cannot emit target CGFT_ObjectFile";
#endif
100
      pass.run(*m);
101
    } else if (fmt == "s" || fmt == "asm") {
HungMingWu committed
102
#if TVM_LLVM_VERSION <= 60
103
      std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
HungMingWu committed
104 105 106
#else
      std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
#endif
107 108
      llvm::legacy::PassManager pass;
      CHECK(tm_);
109
#if TVM_LLVM_VERSION <= 60
110 111 112
      CHECK(tm_->addPassesToEmitFile(
          pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
          << "Cannot emit target CGFT_AssemblyFile";
113 114 115 116 117
#else
      CHECK(tm_->addPassesToEmitFile(
          pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
          << "Cannot emit target CGFT_AssemblyFile";
#endif
118
      pass.run(*m);
119 120 121
    } else if (fmt == "ll") {
      mptr_->print(dest, nullptr);
    } else if (fmt == "bc") {
HungMingWu committed
122
#if TVM_LLVM_VERSION <= 60
123
      llvm::WriteBitcodeToFile(mptr_, dest);
HungMingWu committed
124 125 126
#else
      llvm::WriteBitcodeToFile(*mptr_, dest);
#endif
127 128 129 130 131 132 133
    } else {
      LOG(FATAL) << "Do not know how to save file "
                 << file_name << " with format=\'"<< format << "\'";
    }
    dest.close();
  }

134 135 136 137
  void SaveToBinary(dmlc::Stream* stream) final {
    LOG(FATAL) << "LLVMModule: SaveToBinary not supported";
  }

138
  std::string GetSource(const std::string& format) final {
139
    std::string fmt = runtime::GetFileFormat("", format);
140
    std::string type_str;
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
    llvm::SmallString<256> str;
    llvm::raw_svector_ostream rso(str);

    if (fmt == "s" || fmt == "asm") {
    #if TVM_LLVM_VERSION <= 60
          std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
    #else
          std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
    #endif
          llvm::legacy::PassManager pass;
          CHECK(tm_);
    #if TVM_LLVM_VERSION <= 60
          CHECK(tm_->addPassesToEmitFile(
              pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
              << "Cannot emit target CGFT_AssemblyFile";
    #else
          CHECK(tm_->addPassesToEmitFile(
              pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
              << "Cannot emit target CGFT_AssemblyFile";
    #endif
          pass.run(*m);
          return rso.str().str();
    } else if (fmt == "" || fmt == "ll") {
      std::string type_str;
      llvm::raw_string_ostream rso(type_str);
      CHECK(mptr_ != nullptr);
      mptr_->print(rso, nullptr);
      return rso.str();
    } else {
      LOG(FATAL) << "Do not know how to get source code with format: "
                 << format << "\'";
    }
    return "";
174 175 176 177
  }

  void Init(const Array<LoweredFunc>& funcs, std::string target) {
    InitializeLLVM();
178
    tm_ = GetLLVMTargetMachine(target);
179
    bool system_lib = (target.find("-system-lib") != std::string::npos);
180 181
    CHECK_NE(funcs.size(), 0U);
    ctx_ = std::make_shared<llvm::LLVMContext>();
182
    std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get());
183
    entry_func_ = funcs[0]->name;
184
    cg->Init(funcs[0]->name, tm_.get(), ctx_.get(), system_lib, system_lib);
185
    for (LoweredFunc f :  funcs) {
186
      cg->AddFunction(f);
187
    }
188 189
    cg->AddMainFunction(funcs[0]->name);
    module_ = cg->Finish();
190 191 192 193 194
    std::string verify_errors_storage;
    llvm::raw_string_ostream verify_errors(verify_errors_storage);
    LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors))
        << "LLVM module verification failed with the following errors: \n"
        << verify_errors.str();
195 196 197 198
    module_->addModuleFlag(
        llvm::Module::Warning, "tvm_target",
        llvm::MDString::get(*ctx_, target));
    target_ = target;
199 200 201
    mptr_ = module_.get();
  }

202 203 204 205 206
  void LoadIR(const std::string& file_name) {
    InitializeLLVM();
    ctx_ = std::make_shared<llvm::LLVMContext>();
    llvm::SMDiagnostic err;
    module_ = llvm::parseIRFile(file_name, err, *ctx_);
Tianqi Chen committed
207 208 209 210 211
    if (module_.get() == nullptr) {
      std::string msg = err.getMessage();
      LOG(FATAL) << "Fail to load ir file " << file_name << "\n"
                 << "line " << err.getLineNo() << ":" << msg;
    }
212 213 214 215 216 217 218 219 220 221 222
    std::string target_;
    llvm::Metadata* mtarget = module_->getModuleFlag("tvm_target");
    if (mtarget != nullptr) {
      llvm::MDString* pstr = llvm::dyn_cast<llvm::MDString>(mtarget);
      CHECK(pstr != nullptr);
      target_ = pstr->getString();
    } else {
      std::ostringstream os;
      os << "llvm -target " << module_->getTargetTriple();
      target_ = os.str();
    }
223
    mptr_ = module_.get();
224
    tm_ = GetLLVMTargetMachine(target_);
225 226
  }

227 228 229
 private:
  void LazyInitJIT() {
    std::lock_guard<std::mutex> lock(mutex_);
230 231 232
    if (ee_) {
      return;
    }
233
    llvm::EngineBuilder builder(std::move(module_));
234 235 236
    std::string triple, mcpu, mattr;
    llvm::TargetOptions opt;
    ParseLLVMTargetOptions(target_, &triple, &mcpu, &mattr, &opt);
237 238
    builder.setEngineKind(llvm::EngineKind::JIT);
    builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
239 240 241 242 243 244 245 246
    if (mcpu.length() != 0) {
      builder.setMCPU(mcpu);
    }
    if (mattr.length() != 0) {
      std::vector<std::string> mattrs{mattr};
      builder.setMAttrs(mattrs);
    }
    builder.setTargetOptions(opt);
247 248
    auto tm = std::unique_ptr<llvm::TargetMachine>(builder.selectTarget());
    std::unique_ptr<llvm::TargetMachine> tm_sys = GetLLVMTargetMachine("llvm");
249 250 251 252 253
    if (tm_sys->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) {
      LOG(FATAL) << "Cannot run module, architecture mismatch "
                 << " module=" << tm->getTargetTriple().str()
                 << " system=" << tm_sys->getTargetTriple().str();
    }
254 255 256 257 258 259
    llvm::DataLayout layout(tm->createDataLayout());
    CHECK(layout == mptr_->getDataLayout())
        << "Data layout mismatch between module("
        << mptr_->getDataLayout().getStringRepresentation() << ")"
        << " and ExecutionEngine ("
        << layout.getStringRepresentation() << ")";
260
    ee_ = builder.create(tm.release());
261
    CHECK(ee_ != nullptr)
262
        << "Failed to initialize git engine for " << mptr_->getTargetTriple();
263 264
    ee_->runStaticConstructorsDestructors(false);
    // setup context address.
265
    entry_func_ =
266
        reinterpret_cast<const char*>(GetGlobalAddr(runtime::symbol::tvm_module_main));
267
    if (void** ctx_addr = reinterpret_cast<void**>(
268
            GetGlobalAddr(runtime::symbol::tvm_module_ctx))) {
269 270
      *ctx_addr = this;
    }
271
    runtime::InitContextFunctions([this](const char *name) {
272
        return GetGlobalAddr(name);
273
      });
274
  }
275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292
  // Get global address from execution engine.
  uint64_t GetGlobalAddr(const std::string& name) {
    // first verifies if GV exists.
    if (mptr_->getGlobalVariable(name) != nullptr) {
      return ee_->getGlobalValueAddress(name);
    } else {
      return 0;
    }
  }
  uint64_t GetFunctionAddr(const std::string& name) {
    // first verifies if GV exists.
    if (mptr_->getFunction(name) != nullptr) {
      return ee_->getFunctionAddress(name);
    } else {
      return 0;
    }
  }

293
  // The target configuration string
294
  std::string target_;
295 296
  // Name of entry function.
  std::string entry_func_;
297 298 299 300 301 302 303
  // JIT lock
  std::mutex mutex_;
  // execution engine
  llvm::ExecutionEngine *ee_{nullptr};
  // The raw pointer to the module.
  llvm::Module* mptr_{nullptr};
  // The target machine
304
  std::unique_ptr<llvm::TargetMachine> tm_{nullptr};
305 306 307 308 309 310
  // The module, can be moved to ee if JIT is enabled.
  std::unique_ptr<llvm::Module> module_;
  // the context.
  std::shared_ptr<llvm::LLVMContext> ctx_;
};

311 312 313 314 315 316 317 318 319
/*!
 * \brief Map an intrinsic name to its LLVM intrinsic ID.
 * \param name The intrinsic name to look up.
 * \return The value of llvm::Function::lookupIntrinsicID for that name.
 */
unsigned LookupLLVMIntrinsic(const std::string& name) {
  const auto intrinsic_id = llvm::Function::lookupIntrinsicID(name);
  return static_cast<unsigned>(intrinsic_id);
}

// Packed function: args[0] is an intrinsic name; returns its ID as int64.
TVM_REGISTER_API("codegen.llvm_lookup_intrinsic_id")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    const unsigned intrinsic_id = LookupLLVMIntrinsic(args[0]);
    *rv = static_cast<int64_t>(intrinsic_id);
  });

320
// Packed function: build an LLVM module from (lowered functions, target).
TVM_REGISTER_API("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    auto node = std::make_shared<LLVMModuleNode>();
    node->Init(args[0], args[1]);
    *rv = runtime::Module(node);
  });
326

327 328 329 330 331 332 333
// Packed function: return the major version of the LLVM TVM was built with.
TVM_REGISTER_API("codegen.llvm_version_major")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // TVM_LLVM_VERSION presumably encodes major * 10 + minor (this file's
    // version checks compare against 60 for LLVM 6.0), so integer division
    // by 10 yields the major version.
    int major = TVM_LLVM_VERSION / 10;
    *rv = major;
  });

334 335 336 337 338 339
// Packed function: load a module from the LLVM IR file named by args[0].
TVM_REGISTER_API("module.loadfile_ll")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    auto node = std::make_shared<LLVMModuleNode>();
    node->LoadIR(args[0]);
    *rv = runtime::Module(node);
  });
340 341 342 343 344 345

// Packed function: report whether a target machine can be created for the
// target string in args[0].
TVM_REGISTER_API("codegen.llvm_target_enabled")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    InitializeLLVM();
    const bool enabled = GetLLVMTargetMachine(args[0], true) != nullptr;
    *rv = enabled;
  });
346 347 348
}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION