Commit ee6f22ab by Tianqi Chen Committed by GitHub

[LLVM] Enable same target option in JITModule (#778)

* [LLVM] Enable same target option in JITModule

* not set mcpu explicitly
parent be457348
...@@ -36,9 +36,11 @@ void InitializeLLVM() { ...@@ -36,9 +36,11 @@ void InitializeLLVM() {
} }
} }
llvm::TargetMachine* void ParseLLVMTargetOptions(const std::string& target_str,
GetLLVMTargetMachine(const std::string& target_str, std::string* triple,
bool allow_null) { std::string* mcpu,
std::string* mattr,
llvm::TargetOptions* options) {
// setup target triple // setup target triple
size_t start = 0; size_t start = 0;
if (target_str.length() >= 4 && if (target_str.length() >= 4 &&
...@@ -46,9 +48,10 @@ GetLLVMTargetMachine(const std::string& target_str, ...@@ -46,9 +48,10 @@ GetLLVMTargetMachine(const std::string& target_str,
start = 4; start = 4;
} }
// simple parser // simple parser
std::string target_triple = ""; triple->resize(0);
std::string cpu = "generic"; mcpu->resize(0);
std::string attr = ""; mattr->resize(0);
bool soft_float_abi = false; bool soft_float_abi = false;
std::string key, value; std::string key, value;
std::istringstream is(target_str.substr(start, target_str.length() - start)); std::istringstream is(target_str.substr(start, target_str.length() - start));
...@@ -69,11 +72,11 @@ GetLLVMTargetMachine(const std::string& target_str, ...@@ -69,11 +72,11 @@ GetLLVMTargetMachine(const std::string& target_str,
} }
if (key == "-target" || if (key == "-target" ||
key == "-mtriple") { key == "-mtriple") {
target_triple = value; *triple = value;
} else if (key == "-mcpu") { } else if (key == "-mcpu") {
cpu = value; *mcpu = value;
} else if (key == "-mattr") { } else if (key == "-mattr") {
attr = value; *mattr = value;
} else if (key == "-mfloat-abi") { } else if (key == "-mfloat-abi") {
if (value == "hard") { if (value == "hard") {
soft_float_abi = false; soft_float_abi = false;
...@@ -89,19 +92,13 @@ GetLLVMTargetMachine(const std::string& target_str, ...@@ -89,19 +92,13 @@ GetLLVMTargetMachine(const std::string& target_str,
} }
} }
if (target_triple.length() == 0 || if (triple->length() == 0 ||
target_triple == "default") { *triple == "default") {
target_triple = llvm::sys::getDefaultTargetTriple(); *triple = llvm::sys::getDefaultTargetTriple();
}
std::string err;
const llvm::Target* target =
llvm::TargetRegistry::lookupTarget(target_triple, err);
if (target == nullptr) {
CHECK(allow_null) << err << " target_triple=" << target_triple;
return nullptr;
} }
// set target option // set target option
llvm::TargetOptions opt; llvm::TargetOptions& opt = *options;
opt = llvm::TargetOptions();
#if TVM_LLVM_VERSION < 50 #if TVM_LLVM_VERSION < 50
opt.LessPreciseFPMADOption = true; opt.LessPreciseFPMADOption = true;
#endif #endif
...@@ -114,8 +111,38 @@ GetLLVMTargetMachine(const std::string& target_str, ...@@ -114,8 +111,38 @@ GetLLVMTargetMachine(const std::string& target_str,
} else { } else {
opt.FloatABIType = llvm::FloatABI::Hard; opt.FloatABIType = llvm::FloatABI::Hard;
} }
}
llvm::TargetMachine*
GetLLVMTargetMachine(const std::string& target_str,
bool allow_null) {
std::string target_triple, mcpu, mattr;
llvm::TargetOptions opt;
ParseLLVMTargetOptions(target_str,
&target_triple,
&mcpu,
&mattr,
&opt);
if (target_triple.length() == 0 ||
target_triple == "default") {
target_triple = llvm::sys::getDefaultTargetTriple();
}
if (mcpu.length() == 0) {
mcpu = "generic";
}
std::string err;
const llvm::Target* target =
llvm::TargetRegistry::lookupTarget(target_triple, err);
if (target == nullptr) {
CHECK(allow_null) << err << " target_triple=" << target_triple;
return nullptr;
}
llvm::TargetMachine* tm = target->createTargetMachine( llvm::TargetMachine* tm = target->createTargetMachine(
target_triple, cpu, attr, opt, llvm::Reloc::PIC_); target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_);
return tm; return tm;
} }
......
...@@ -58,6 +58,20 @@ namespace codegen { ...@@ -58,6 +58,20 @@ namespace codegen {
void InitializeLLVM(); void InitializeLLVM();
/*! /*!
* \brief Parse target options
* \param target_str Target string, in format "llvm -target=xxx -mcpu=xxx"
* \param triple Target triple
* \param mcpu cpu info
* \param options the options
* \param mattr The attributes
*/
void ParseLLVMTargetOptions(const std::string& target_str,
std::string* triple,
std::string* mcpu,
std::string* mattr,
llvm::TargetOptions* options);
/*!
* \brief Get target machine from target_str string. * \brief Get target machine from target_str string.
* \param target_str Target string, in format "llvm -target=xxx -mcpu=xxx" * \param target_str Target string, in format "llvm -target=xxx -mcpu=xxx"
* \param allow_null Whether allow null to be returned. * \param allow_null Whether allow null to be returned.
......
...@@ -120,6 +120,10 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -120,6 +120,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
} }
cg->AddMainFunction(funcs[0]->name); cg->AddMainFunction(funcs[0]->name);
module_ = cg->Finish(); module_ = cg->Finish();
module_->addModuleFlag(
llvm::Module::Warning, "tvm_target",
llvm::MDString::get(*ctx_, target));
target_ = target;
mptr_ = module_.get(); mptr_ = module_.get();
} }
...@@ -133,11 +137,19 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -133,11 +137,19 @@ class LLVMModuleNode final : public runtime::ModuleNode {
LOG(FATAL) << "Fail to load ir file " << file_name << "\n" LOG(FATAL) << "Fail to load ir file " << file_name << "\n"
<< "line " << err.getLineNo() << ":" << msg; << "line " << err.getLineNo() << ":" << msg;
} }
std::string target = module_->getTargetTriple(); std::string target_;
llvm::Metadata* mtarget = module_->getModuleFlag("tvm_target");
if (mtarget != nullptr) {
llvm::MDString* pstr = llvm::dyn_cast<llvm::MDString>(mtarget);
CHECK(pstr != nullptr);
target_ = pstr->getString();
} else {
std::ostringstream os;
os << "llvm -target " << module_->getTargetTriple();
target_ = os.str();
}
mptr_ = module_.get(); mptr_ = module_.get();
std::ostringstream os; tm_ = GetLLVMTargetMachine(target_);
os << "llvm -target " << target;
tm_ = GetLLVMTargetMachine(os.str());
} }
private: private:
...@@ -145,8 +157,19 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -145,8 +157,19 @@ class LLVMModuleNode final : public runtime::ModuleNode {
CHECK(ee_ == nullptr); CHECK(ee_ == nullptr);
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
llvm::EngineBuilder builder(std::move(module_)); llvm::EngineBuilder builder(std::move(module_));
std::string triple, mcpu, mattr;
llvm::TargetOptions opt;
ParseLLVMTargetOptions(target_, &triple, &mcpu, &mattr, &opt);
builder.setEngineKind(llvm::EngineKind::JIT); builder.setEngineKind(llvm::EngineKind::JIT);
builder.setOptLevel(llvm::CodeGenOpt::Aggressive); builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
if (mcpu.length() != 0) {
builder.setMCPU(mcpu);
}
if (mattr.length() != 0) {
std::vector<std::string> mattrs{mattr};
builder.setMAttrs(mattrs);
}
builder.setTargetOptions(opt);
llvm::TargetMachine *tm = builder.selectTarget(); llvm::TargetMachine *tm = builder.selectTarget();
llvm::TargetMachine *tm_sys = GetLLVMTargetMachine("llvm"); llvm::TargetMachine *tm_sys = GetLLVMTargetMachine("llvm");
if (tm_sys->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) { if (tm_sys->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) {
......
...@@ -38,9 +38,9 @@ def verify_binary_dense(batch, in_dim, out_dim): ...@@ -38,9 +38,9 @@ def verify_binary_dense(batch, in_dim, out_dim):
bnn_a = tvm.nd.array(np.zeros(get_const_tuple(bnn_A.shape), dtype=bnn_A.dtype), ctx) bnn_a = tvm.nd.array(np.zeros(get_const_tuple(bnn_A.shape), dtype=bnn_A.dtype), ctx)
bnn_b = tvm.nd.array(np.zeros(get_const_tuple(bnn_B.shape), dtype=bnn_B.dtype), ctx) bnn_b = tvm.nd.array(np.zeros(get_const_tuple(bnn_B.shape), dtype=bnn_B.dtype), ctx)
bnn_c = tvm.nd.array(np.zeros(get_const_tuple(bnn_C.shape), dtype=bnn_C.dtype), ctx) bnn_c = tvm.nd.array(np.zeros(get_const_tuple(bnn_C.shape), dtype=bnn_C.dtype), ctx)
f1 = tvm.build(s1, [A, bnn_A], 'llvm -mcpu=core-avx2') f1 = tvm.build(s1, [A, bnn_A], 'llvm')
f2 = tvm.build(s2, [B, bnn_B], 'llvm -mcpu=core-avx2') f2 = tvm.build(s2, [B, bnn_B], 'llvm')
f3 = tvm.build(s3, [bnn_A1, bnn_B1, bnn_C], 'llvm -mcpu=core-avx2') f3 = tvm.build(s3, [bnn_A1, bnn_B1, bnn_C], 'llvm')
f1(a, bnn_a) f1(a, bnn_a)
f2(b, bnn_b) f2(b, bnn_b)
f3(bnn_a, bnn_b, bnn_c) f3(bnn_a, bnn_b, bnn_c)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment