/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file llvm_module.cc
 * \brief LLVM runtime module for TVM
 */
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/packed_func.h>
#include <tvm/codegen.h>

#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <system_error>
#include <vector>

#include "llvm_common.h"
#include "codegen_llvm.h"

#include "../../runtime/file_util.h"
#include "../../runtime/module_util.h"

namespace tvm {
namespace codegen {

using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;

41
class LLVMModuleNode final : public runtime::ModuleNode {
42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
 public:
  ~LLVMModuleNode() {
    module_.reset();
    if (ee_ != nullptr) {
      ee_->runStaticConstructorsDestructors(true);
      delete ee_;
    }
  }

  const char* type_key() const {
    return "llvm";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
Tianqi Chen committed
58 59 60 61 62 63 64
    if (name == "__tvm_is_system_module") {
      bool flag =
          (mptr_->getFunction("__tvm_module_startup") != nullptr);
      return PackedFunc([flag](TVMArgs args, TVMRetValue *rv) {
          * rv = flag;
        });
    }
65 66
    if (ee_ == nullptr) LazyInitJIT();
    std::lock_guard<std::mutex> lock(mutex_);
67 68
    const std::string& fname = (name == runtime::symbol::tvm_module_main ?
                                entry_func_ : name);
69

70
    BackendPackedCFunc faddr =
71
        reinterpret_cast<BackendPackedCFunc>(GetFunctionAddr(fname));
72
    if (faddr == nullptr) return PackedFunc();
nhynes committed
73
    return WrapPackedFunc(faddr, sptr_to_self);
74 75 76 77 78 79 80 81 82 83
  }

  void SaveToFile(const std::string& file_name,
                  const std::string& format) final {
    std::string fmt = runtime::GetFileFormat(file_name, format);
    std::error_code ecode;
    llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None);
    CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name
                               << " " << ecode.message();
    if (fmt == "o" || fmt == "obj") {
HungMingWu committed
84
#if TVM_LLVM_VERSION <= 60
85
      std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
HungMingWu committed
86 87 88
#else
      std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
#endif
89 90
      llvm::legacy::PassManager pass;
      CHECK(tm_);
91
#if TVM_LLVM_VERSION <= 60
92 93 94
      CHECK(tm_->addPassesToEmitFile(
          pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0)
          << "Cannot emit target CGFT_ObjectFile";
95 96 97 98 99
#else
      CHECK(tm_->addPassesToEmitFile(
          pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0)
          << "Cannot emit target CGFT_ObjectFile";
#endif
100
      pass.run(*m);
101
    } else if (fmt == "s" || fmt == "asm") {
HungMingWu committed
102
#if TVM_LLVM_VERSION <= 60
103
      std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
HungMingWu committed
104 105 106
#else
      std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
#endif
107 108
      llvm::legacy::PassManager pass;
      CHECK(tm_);
109
#if TVM_LLVM_VERSION <= 60
110 111 112
      CHECK(tm_->addPassesToEmitFile(
          pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
          << "Cannot emit target CGFT_AssemblyFile";
113 114 115 116 117
#else
      CHECK(tm_->addPassesToEmitFile(
          pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
          << "Cannot emit target CGFT_AssemblyFile";
#endif
118
      pass.run(*m);
119 120 121
    } else if (fmt == "ll") {
      mptr_->print(dest, nullptr);
    } else if (fmt == "bc") {
HungMingWu committed
122
#if TVM_LLVM_VERSION <= 60
123
      llvm::WriteBitcodeToFile(mptr_, dest);
HungMingWu committed
124 125 126
#else
      llvm::WriteBitcodeToFile(*mptr_, dest);
#endif
127 128 129 130 131 132 133
    } else {
      LOG(FATAL) << "Do not know how to save file "
                 << file_name << " with format=\'"<< format << "\'";
    }
    dest.close();
  }

134 135 136 137
  void SaveToBinary(dmlc::Stream* stream) final {
    LOG(FATAL) << "LLVMModule: SaveToBinary not supported";
  }

138
  std::string GetSource(const std::string& format) final {
139
    std::string fmt = runtime::GetFileFormat("", format);
140
    std::string type_str;
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
    llvm::SmallString<256> str;
    llvm::raw_svector_ostream rso(str);

    if (fmt == "s" || fmt == "asm") {
    #if TVM_LLVM_VERSION <= 60
          std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
    #else
          std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
    #endif
          llvm::legacy::PassManager pass;
          CHECK(tm_);
    #if TVM_LLVM_VERSION <= 60
          CHECK(tm_->addPassesToEmitFile(
              pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
              << "Cannot emit target CGFT_AssemblyFile";
    #else
          CHECK(tm_->addPassesToEmitFile(
              pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
              << "Cannot emit target CGFT_AssemblyFile";
    #endif
          pass.run(*m);
          return rso.str().str();
    } else if (fmt == "" || fmt == "ll") {
      std::string type_str;
      llvm::raw_string_ostream rso(type_str);
      CHECK(mptr_ != nullptr);
      mptr_->print(rso, nullptr);
      return rso.str();
    } else {
      LOG(FATAL) << "Do not know how to get source code with format: "
                 << format << "\'";
    }
    return "";
174 175 176 177
  }

  void Init(const Array<LoweredFunc>& funcs, std::string target) {
    InitializeLLVM();
178
    tm_ = GetLLVMTargetMachine(target);
179
    bool system_lib = (target.find("-system-lib") != std::string::npos);
180 181
    CHECK_NE(funcs.size(), 0U);
    ctx_ = std::make_shared<llvm::LLVMContext>();
182
    std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get());
183
    entry_func_ = funcs[0]->name;
184
    cg->Init(funcs[0]->name, tm_.get(), ctx_.get(), system_lib, system_lib);
185
    for (LoweredFunc f :  funcs) {
186
      cg->AddFunction(f);
187
    }
188 189
    cg->AddMainFunction(funcs[0]->name);
    module_ = cg->Finish();
190 191 192 193 194
    std::string verify_errors_storage;
    llvm::raw_string_ostream verify_errors(verify_errors_storage);
    LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors))
        << "LLVM module verification failed with the following errors: \n"
        << verify_errors.str();
195 196 197 198
    module_->addModuleFlag(
        llvm::Module::Warning, "tvm_target",
        llvm::MDString::get(*ctx_, target));
    target_ = target;
199 200 201
    mptr_ = module_.get();
  }

202 203 204 205 206
  void LoadIR(const std::string& file_name) {
    InitializeLLVM();
    ctx_ = std::make_shared<llvm::LLVMContext>();
    llvm::SMDiagnostic err;
    module_ = llvm::parseIRFile(file_name, err, *ctx_);
Tianqi Chen committed
207 208 209 210 211
    if (module_.get() == nullptr) {
      std::string msg = err.getMessage();
      LOG(FATAL) << "Fail to load ir file " << file_name << "\n"
                 << "line " << err.getLineNo() << ":" << msg;
    }
212 213 214 215 216 217 218 219 220 221 222
    std::string target_;
    llvm::Metadata* mtarget = module_->getModuleFlag("tvm_target");
    if (mtarget != nullptr) {
      llvm::MDString* pstr = llvm::dyn_cast<llvm::MDString>(mtarget);
      CHECK(pstr != nullptr);
      target_ = pstr->getString();
    } else {
      std::ostringstream os;
      os << "llvm -target " << module_->getTargetTriple();
      target_ = os.str();
    }
223
    mptr_ = module_.get();
224
    tm_ = GetLLVMTargetMachine(target_);
225 226
  }

227 228 229 230 231
 private:
  void LazyInitJIT() {
    CHECK(ee_ == nullptr);
    std::lock_guard<std::mutex> lock(mutex_);
    llvm::EngineBuilder builder(std::move(module_));
232 233 234
    std::string triple, mcpu, mattr;
    llvm::TargetOptions opt;
    ParseLLVMTargetOptions(target_, &triple, &mcpu, &mattr, &opt);
235 236
    builder.setEngineKind(llvm::EngineKind::JIT);
    builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
237 238 239 240 241 242 243 244
    if (mcpu.length() != 0) {
      builder.setMCPU(mcpu);
    }
    if (mattr.length() != 0) {
      std::vector<std::string> mattrs{mattr};
      builder.setMAttrs(mattrs);
    }
    builder.setTargetOptions(opt);
245 246
    auto tm = std::unique_ptr<llvm::TargetMachine>(builder.selectTarget());
    std::unique_ptr<llvm::TargetMachine> tm_sys = GetLLVMTargetMachine("llvm");
247 248 249 250 251
    if (tm_sys->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) {
      LOG(FATAL) << "Cannot run module, architecture mismatch "
                 << " module=" << tm->getTargetTriple().str()
                 << " system=" << tm_sys->getTargetTriple().str();
    }
252 253 254 255 256 257
    llvm::DataLayout layout(tm->createDataLayout());
    CHECK(layout == mptr_->getDataLayout())
        << "Data layout mismatch between module("
        << mptr_->getDataLayout().getStringRepresentation() << ")"
        << " and ExecutionEngine ("
        << layout.getStringRepresentation() << ")";
258
    ee_ = builder.create(tm.release());
259
    CHECK(ee_ != nullptr)
260
        << "Failed to initialize git engine for " << mptr_->getTargetTriple();
261 262
    ee_->runStaticConstructorsDestructors(false);
    // setup context address.
263
    entry_func_ =
264
        reinterpret_cast<const char*>(GetGlobalAddr(runtime::symbol::tvm_module_main));
265
    if (void** ctx_addr = reinterpret_cast<void**>(
266
            GetGlobalAddr(runtime::symbol::tvm_module_ctx))) {
267 268
      *ctx_addr = this;
    }
269
    runtime::InitContextFunctions([this](const char *name) {
270
        return GetGlobalAddr(name);
271
      });
272
  }
273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290
  // Get global address from execution engine.
  uint64_t GetGlobalAddr(const std::string& name) {
    // first verifies if GV exists.
    if (mptr_->getGlobalVariable(name) != nullptr) {
      return ee_->getGlobalValueAddress(name);
    } else {
      return 0;
    }
  }
  uint64_t GetFunctionAddr(const std::string& name) {
    // first verifies if GV exists.
    if (mptr_->getFunction(name) != nullptr) {
      return ee_->getFunctionAddress(name);
    } else {
      return 0;
    }
  }

291
  // The target configuration string
292
  std::string target_;
293 294
  // Name of entry function.
  std::string entry_func_;
295 296 297 298 299 300 301
  // JIT lock
  std::mutex mutex_;
  // execution engine
  llvm::ExecutionEngine *ee_{nullptr};
  // The raw pointer to the module.
  llvm::Module* mptr_{nullptr};
  // The target machine
302
  std::unique_ptr<llvm::TargetMachine> tm_{nullptr};
303 304 305 306 307 308
  // The module, can be moved to ee if JIT is enabled.
  std::unique_ptr<llvm::Module> module_;
  // the context.
  std::shared_ptr<llvm::LLVMContext> ctx_;
};

309 310 311 312 313 314 315 316 317
unsigned LookupLLVMIntrinsic(const std::string& name) {
  return llvm::Function::lookupIntrinsicID(name);
}

TVM_REGISTER_API("codegen.llvm_lookup_intrinsic_id")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
  });

318
TVM_REGISTER_API("codegen.build_llvm")
319 320 321 322 323
.set_body([](TVMArgs args, TVMRetValue* rv) {
    std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
    n->Init(args[0], args[1]);
    *rv = runtime::Module(n);
  });
324

325 326 327 328 329 330 331
TVM_REGISTER_API("codegen.llvm_version_major")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    std::ostringstream os;
    int major = TVM_LLVM_VERSION / 10;
    *rv = major;
  });

332 333 334 335 336 337
TVM_REGISTER_API("module.loadfile_ll")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
    n->LoadIR(args[0]);
    *rv = runtime::Module(n);
  });
338 339 340 341 342 343

// Return whether a target machine can be created for the target string in
// args[0]. The second argument to GetLLVMTargetMachine presumably permits a
// null result instead of failing -- confirm against llvm_common.h.
TVM_REGISTER_API("codegen.llvm_target_enabled")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    InitializeLLVM();
    *rv = (GetLLVMTargetMachine(args[0], true) != nullptr);
  });
}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION