Commit 7c6a71ba by Tianqi Chen Committed by GitHub

[CODEGEN] Make codegen registerable (#193)

* [CODEGEN] Make codegen registerable

* fix llvm disbaled
parent 1400edac
...@@ -37,8 +37,7 @@ using DLTypeVector = std::vector<DLDataType>; ...@@ -37,8 +37,7 @@ using DLTypeVector = std::vector<DLDataType>;
*/ */
using FTVMCompute = std::function< using FTVMCompute = std::function<
Array<Tensor> Array<Tensor>
(const NodeAttrs& attrs, (const NodeAttrs& attrs, const Array<Tensor>& inputs)>;
const Array<Tensor>& inputs)>;
/*! /*!
* \brief Build the computation schedule for * \brief Build the computation schedule for
......
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_arm.cc
* \brief ARM specific code generator
*/
#ifdef TVM_LLVM_VERSION
#include "./codegen_llvm.h"
namespace tvm {
namespace codegen {
// ARM specific code generator, this is used as an example on
// how to override behavior llvm code generator for specific target
class CodeGenARM final : public CodeGenLLVM {
public:
void InitTarget(llvm::TargetMachine* tm) final {
// set native vector bits.
native_vector_bits_ = 16 * 8;
CodeGenLLVM::InitTarget(tm);
}
};
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
CodeGenLLVM* cg = new CodeGenARM();
*rv = static_cast<void*>(cg);
});
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
...@@ -13,6 +13,18 @@ ...@@ -13,6 +13,18 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
std::string target = tm->getTarget().getName();
std::string factory_name = "tvm.codegen.llvm.target_" + target;
const PackedFunc* f = runtime::Registry::Get(factory_name);
if (f != nullptr) {
void* handle = (*f)();
return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
} else {
return std::unique_ptr<CodeGenLLVM>(new CodeGenLLVM());
}
}
void CodeGenLLVM::Init(const std::string& module_name, void CodeGenLLVM::Init(const std::string& module_name,
llvm::TargetMachine* tm, llvm::TargetMachine* tm,
llvm::LLVMContext* ctx) { llvm::LLVMContext* ctx) {
...@@ -93,17 +105,17 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { ...@@ -93,17 +105,17 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
data_layout_.reset(new llvm::DataLayout(module_.get())); data_layout_.reset(new llvm::DataLayout(module_.get()));
// initialize native vector bits // initialize native vector bits
std::string target = tm->getTarget().getName(); std::string target = tm->getTarget().getName();
if (target == "arm") { if (target == "x86-64") {
native_vector_bits_ = 16 * 8;
} else if (target == "x86-64") {
// for avx512 // for avx512
native_vector_bits_ = 64 * 8; native_vector_bits_ = 64 * 8;
} else if (target == "x86") { } else if (target == "x86") {
native_vector_bits_ = 32 * 8; native_vector_bits_ = 32 * 8;
} else { } else {
native_vector_bits_ = 32 * 8; if (native_vector_bits_ == 0) {
LOG(WARNING) << "set native vector to be " << native_vector_bits_ / 8 native_vector_bits_ = 32 * 8;
<< " for target " << target; LOG(WARNING) << "set native vector to be " << native_vector_bits_ / 8
<< " for target " << target;
}
} }
} }
......
...@@ -29,6 +29,12 @@ class CodeGenLLVM : ...@@ -29,6 +29,12 @@ class CodeGenLLVM :
public StmtFunctor<void(const Stmt&)> { public StmtFunctor<void(const Stmt&)> {
public: public:
/*! /*!
* \brief Create new code generator based on target machine.
* \param tm The target machine
* \return The created llvm generator.
*/
static std::unique_ptr<CodeGenLLVM> Create(llvm::TargetMachine* tm);
/*!
* \brief Initialize the code generator with given context * \brief Initialize the code generator with given context
* \param module_name The name of the module. * \param module_name The name of the module.
* \param tm Target machine model * \param tm Target machine model
...@@ -136,6 +142,8 @@ class CodeGenLLVM : ...@@ -136,6 +142,8 @@ class CodeGenLLVM :
// do a scalarize call with f // do a scalarize call with f
llvm::Value* CreateScalarizedCall( llvm::Value* CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args); const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
// Initialize target
virtual void InitTarget(llvm::TargetMachine* tm);
// apply optimization on the module. // apply optimization on the module.
virtual void Optimize(); virtual void Optimize();
// Get the maximim storage align bits of buffer pointer given storage scope. // Get the maximim storage align bits of buffer pointer given storage scope.
...@@ -216,8 +224,6 @@ class CodeGenLLVM : ...@@ -216,8 +224,6 @@ class CodeGenLLVM :
// if not directly finalize function and pass on return code. // if not directly finalize function and pass on return code.
// return the end block after the check // return the end block after the check
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode); llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Initialize target
void InitTarget(llvm::TargetMachine* tm);
// Add a function to set global module context // Add a function to set global module context
void InitGlobalContext(); void InitGlobalContext();
// add alias information. // add alias information.
......
...@@ -36,6 +36,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") ...@@ -36,6 +36,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::log>); .set_body(DispatchLLVMIntrin<::llvm::Intrinsic::log>);
} // namespace llvm } // namespace llvm
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
...@@ -101,14 +101,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -101,14 +101,14 @@ class LLVMModuleNode final : public runtime::ModuleNode {
tm_ = GetLLVMTargetMachine(target); tm_ = GetLLVMTargetMachine(target);
CHECK_NE(funcs.size(), 0U); CHECK_NE(funcs.size(), 0U);
ctx_ = std::make_shared<llvm::LLVMContext>(); ctx_ = std::make_shared<llvm::LLVMContext>();
CodeGenLLVM cg; std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_);
entry_func_ = funcs[0]->name; entry_func_ = funcs[0]->name;
cg.Init(funcs[0]->name, tm_, ctx_.get()); cg->Init(funcs[0]->name, tm_, ctx_.get());
for (LoweredFunc f : funcs) { for (LoweredFunc f : funcs) {
cg.AddFunction(f); cg->AddFunction(f);
} }
cg.AddMainFunction(funcs[0]->name); cg->AddMainFunction(funcs[0]->name);
module_ = cg.Finish(); module_ = cg->Finish();
mptr_ = module_.get(); mptr_ = module_.get();
} }
......
...@@ -26,27 +26,33 @@ class IntrinInjecter : public IRMutator { ...@@ -26,27 +26,33 @@ class IntrinInjecter : public IRMutator {
Expr Mutate_(const Call* op, const Expr& e) final { Expr Mutate_(const Call* op, const Expr& e) final {
if (op->call_type == Call::Intrinsic || if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) { op->call_type == Call::PureIntrinsic) {
for (size_t i = 0; i < patterns_.size(); ++i) { Expr r = ApplyPattern(op->name, e);
std::string& p = patterns_[i]; if (r.defined()) return r;
size_t psize = p.length();
p.resize(psize + op->name.length());
op->name.copy(&p[0] + psize, op->name.length());
const runtime::PackedFunc* f = runtime::Registry::Get(p);
p.resize(psize);
// if pattern exists.
if (f != nullptr) {
Expr r = (*f)(e);
CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
if (!r.same_as(e)) {
return this->Mutate(r);
}
}
}
} }
return IRMutator::Mutate_(op, e); return IRMutator::Mutate_(op, e);
} }
private: private:
Expr ApplyPattern(const std::string& name, const Expr& e) {
for (size_t i = 0; i < patterns_.size(); ++i) {
std::string& p = patterns_[i];
size_t psize = p.length();
p.resize(psize + name.length());
name.copy(&p[0] + psize, name.length());
const runtime::PackedFunc* f = runtime::Registry::Get(p);
p.resize(psize);
// if pattern exists.
if (f != nullptr) {
Expr r = (*f)(e);
CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
if (!r.same_as(e)) {
return this->Mutate(r);
}
}
}
return Expr();
}
// patterns
std::vector<std::string> patterns_; std::vector<std::string> patterns_;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment