/*!
 *  Copyright (c) 2017 by Contributors
 * \file llvm_module.cc
 * \brief LLVM runtime module for TVM
 */
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/packed_func.h>
#include <tvm/codegen.h>
#include <mutex>
#include "./llvm_common.h"
#include "./codegen_llvm.h"
#include "../../runtime/file_util.h"
#include "../../runtime/module_util.h"

namespace tvm {
namespace codegen {

using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;

class LLVMModuleNode final : public runtime::ModuleNode {
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
 public:
  ~LLVMModuleNode() {
    module_.reset();
    if (ee_ != nullptr) {
      ee_->runStaticConstructorsDestructors(true);
      delete ee_;
    }
  }

  const char* type_key() const {
    return "llvm";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
Tianqi Chen committed
39 40 41 42 43 44 45
    if (name == "__tvm_is_system_module") {
      bool flag =
          (mptr_->getFunction("__tvm_module_startup") != nullptr);
      return PackedFunc([flag](TVMArgs args, TVMRetValue *rv) {
          * rv = flag;
        });
    }
46 47
    if (ee_ == nullptr) LazyInitJIT();
    std::lock_guard<std::mutex> lock(mutex_);
48 49
    const std::string& fname = (name == runtime::symbol::tvm_module_main ?
                                entry_func_ : name);
50

51
    BackendPackedCFunc faddr =
52
        reinterpret_cast<BackendPackedCFunc>(GetFunctionAddr(fname));
53
    if (faddr == nullptr) return PackedFunc();
nhynes committed
54
    return WrapPackedFunc(faddr, sptr_to_self);
55 56 57 58 59 60 61 62 63 64
  }

  void SaveToFile(const std::string& file_name,
                  const std::string& format) final {
    std::string fmt = runtime::GetFileFormat(file_name, format);
    std::error_code ecode;
    llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None);
    CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name
                               << " " << ecode.message();
    if (fmt == "o" || fmt == "obj") {
HungMingWu committed
65
#if TVM_LLVM_VERSION <= 60
66
      std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
HungMingWu committed
67 68 69
#else
      std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
#endif
70 71 72 73 74
      llvm::legacy::PassManager pass;
      CHECK(tm_);
      CHECK(tm_->addPassesToEmitFile(
          pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0)
          << "Cannot emit target CGFT_ObjectFile";
75
      pass.run(*m);
76
    } else if (fmt == "s" || fmt == "asm") {
HungMingWu committed
77
#if TVM_LLVM_VERSION <= 60
78
      std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
HungMingWu committed
79 80 81
#else
      std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
#endif
82 83 84 85 86
      llvm::legacy::PassManager pass;
      CHECK(tm_);
      CHECK(tm_->addPassesToEmitFile(
          pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
          << "Cannot emit target CGFT_AssemblyFile";
87
      pass.run(*m);
88 89 90
    } else if (fmt == "ll") {
      mptr_->print(dest, nullptr);
    } else if (fmt == "bc") {
HungMingWu committed
91
#if TVM_LLVM_VERSION <= 60
92
      llvm::WriteBitcodeToFile(mptr_, dest);
HungMingWu committed
93 94 95
#else
      llvm::WriteBitcodeToFile(*mptr_, dest);
#endif
96 97 98 99 100 101 102
    } else {
      LOG(FATAL) << "Do not know how to save file "
                 << file_name << " with format=\'"<< format << "\'";
    }
    dest.close();
  }

103 104 105 106
  void SaveToBinary(dmlc::Stream* stream) final {
    LOG(FATAL) << "LLVMModule: SaveToBinary not supported";
  }

107 108 109 110 111 112 113 114 115 116
  std::string GetSource(const std::string& format) final {
    std::string type_str;
    llvm::raw_string_ostream rso(type_str);
    CHECK(mptr_ != nullptr);
    mptr_->print(rso, nullptr);
    return rso.str();
  }

  void Init(const Array<LoweredFunc>& funcs, std::string target) {
    InitializeLLVM();
117
    tm_ = GetLLVMTargetMachine(target);
118
    bool system_lib = (target.find("-system-lib") != std::string::npos);
119 120
    CHECK_NE(funcs.size(), 0U);
    ctx_ = std::make_shared<llvm::LLVMContext>();
121
    std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_);
122
    entry_func_ = funcs[0]->name;
123
    cg->Init(funcs[0]->name, tm_, ctx_.get(), system_lib, system_lib);
124
    for (LoweredFunc f :  funcs) {
125
      cg->AddFunction(f);
126
    }
127 128
    cg->AddMainFunction(funcs[0]->name);
    module_ = cg->Finish();
129 130 131 132
    module_->addModuleFlag(
        llvm::Module::Warning, "tvm_target",
        llvm::MDString::get(*ctx_, target));
    target_ = target;
133 134 135
    mptr_ = module_.get();
  }

136 137 138 139 140
  void LoadIR(const std::string& file_name) {
    InitializeLLVM();
    ctx_ = std::make_shared<llvm::LLVMContext>();
    llvm::SMDiagnostic err;
    module_ = llvm::parseIRFile(file_name, err, *ctx_);
Tianqi Chen committed
141 142 143 144 145
    if (module_.get() == nullptr) {
      std::string msg = err.getMessage();
      LOG(FATAL) << "Fail to load ir file " << file_name << "\n"
                 << "line " << err.getLineNo() << ":" << msg;
    }
146 147 148 149 150 151 152 153 154 155 156
    std::string target_;
    llvm::Metadata* mtarget = module_->getModuleFlag("tvm_target");
    if (mtarget != nullptr) {
      llvm::MDString* pstr = llvm::dyn_cast<llvm::MDString>(mtarget);
      CHECK(pstr != nullptr);
      target_ = pstr->getString();
    } else {
      std::ostringstream os;
      os << "llvm -target " << module_->getTargetTriple();
      target_ = os.str();
    }
157
    mptr_ = module_.get();
158
    tm_ = GetLLVMTargetMachine(target_);
159 160
  }

161 162 163 164 165
 private:
  void LazyInitJIT() {
    CHECK(ee_ == nullptr);
    std::lock_guard<std::mutex> lock(mutex_);
    llvm::EngineBuilder builder(std::move(module_));
166 167 168
    std::string triple, mcpu, mattr;
    llvm::TargetOptions opt;
    ParseLLVMTargetOptions(target_, &triple, &mcpu, &mattr, &opt);
169 170
    builder.setEngineKind(llvm::EngineKind::JIT);
    builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
171 172 173 174 175 176 177 178
    if (mcpu.length() != 0) {
      builder.setMCPU(mcpu);
    }
    if (mattr.length() != 0) {
      std::vector<std::string> mattrs{mattr};
      builder.setMAttrs(mattrs);
    }
    builder.setTargetOptions(opt);
179
    llvm::TargetMachine *tm = builder.selectTarget();
180 181 182 183 184 185
    llvm::TargetMachine *tm_sys = GetLLVMTargetMachine("llvm");
    if (tm_sys->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) {
      LOG(FATAL) << "Cannot run module, architecture mismatch "
                 << " module=" << tm->getTargetTriple().str()
                 << " system=" << tm_sys->getTargetTriple().str();
    }
186 187 188 189 190 191 192 193
    llvm::DataLayout layout(tm->createDataLayout());
    CHECK(layout == mptr_->getDataLayout())
        << "Data layout mismatch between module("
        << mptr_->getDataLayout().getStringRepresentation() << ")"
        << " and ExecutionEngine ("
        << layout.getStringRepresentation() << ")";
    ee_ = builder.create(tm);
    CHECK(ee_ != nullptr)
194
        << "Failed to initialize git engine for " << mptr_->getTargetTriple();
195 196
    ee_->runStaticConstructorsDestructors(false);
    // setup context address.
197
    entry_func_ =
198
        reinterpret_cast<const char*>(GetGlobalAddr(runtime::symbol::tvm_module_main));
199
    if (void** ctx_addr = reinterpret_cast<void**>(
200
            GetGlobalAddr(runtime::symbol::tvm_module_ctx))) {
201 202
      *ctx_addr = this;
    }
203
    runtime::InitContextFunctions([this](const char *name) {
204
        return GetGlobalAddr(name);
205
      });
206
  }
207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
  // Get global address from execution engine.
  uint64_t GetGlobalAddr(const std::string& name) {
    // first verifies if GV exists.
    if (mptr_->getGlobalVariable(name) != nullptr) {
      return ee_->getGlobalValueAddress(name);
    } else {
      return 0;
    }
  }
  uint64_t GetFunctionAddr(const std::string& name) {
    // first verifies if GV exists.
    if (mptr_->getFunction(name) != nullptr) {
      return ee_->getFunctionAddress(name);
    } else {
      return 0;
    }
  }

225
  // The target configuration string
226
  std::string target_;
227 228
  // Name of entry function.
  std::string entry_func_;
229 230 231 232 233 234 235 236 237 238 239 240 241 242
  // JIT lock
  std::mutex mutex_;
  // execution engine
  llvm::ExecutionEngine *ee_{nullptr};
  // The raw pointer to the module.
  llvm::Module* mptr_{nullptr};
  // The target machine
  llvm::TargetMachine* tm_{nullptr};
  // The module, can be moved to ee if JIT is enabled.
  std::unique_ptr<llvm::Module> module_;
  // the context.
  std::shared_ptr<llvm::LLVMContext> ctx_;
};

243
TVM_REGISTER_API("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // args[0]: lowered functions to compile; args[1]: LLVM target string.
    auto node = std::make_shared<LLVMModuleNode>();
    node->Init(args[0], args[1]);
    *rv = runtime::Module(node);
  });

TVM_REGISTER_API("module.loadfile_ll")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // args[0]: path of the LLVM IR file to parse.
    auto node = std::make_shared<LLVMModuleNode>();
    node->LoadIR(args[0]);
    *rv = runtime::Module(node);
  });

TVM_REGISTER_API("codegen.llvm_target_enabled")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    InitializeLLVM();
    // The second argument presumably makes the lookup non-fatal so that an
    // unsupported target yields nullptr -- confirm in llvm_common.
    llvm::TargetMachine* tm = GetLLVMTargetMachine(args[0], true);
    *rv = tm != nullptr;
  });

}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION