Commit 7c6a71ba by Tianqi Chen Committed by GitHub

[CODEGEN] Make codegen registerable (#193)

* [CODEGEN] Make codegen registerable

* fix llvm disbaled
parent 1400edac
......@@ -37,8 +37,7 @@ using DLTypeVector = std::vector<DLDataType>;
*/
using FTVMCompute = std::function<
Array<Tensor>
(const NodeAttrs& attrs,
const Array<Tensor>& inputs)>;
(const NodeAttrs& attrs, const Array<Tensor>& inputs)>;
/*!
* \brief Build the computation schedule for
......
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_arm.cc
* \brief ARM specific code generator
*/
#ifdef TVM_LLVM_VERSION
#include "./codegen_llvm.h"
namespace tvm {
namespace codegen {
// ARM specific code generator, this is used as an example on
// how to override behavior llvm code generator for specific target
class CodeGenARM final : public CodeGenLLVM {
public:
void InitTarget(llvm::TargetMachine* tm) final {
// set native vector bits.
native_vector_bits_ = 16 * 8;
CodeGenLLVM::InitTarget(tm);
}
};
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
CodeGenLLVM* cg = new CodeGenARM();
*rv = static_cast<void*>(cg);
});
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
......@@ -13,6 +13,18 @@
namespace tvm {
namespace codegen {
std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
std::string target = tm->getTarget().getName();
std::string factory_name = "tvm.codegen.llvm.target_" + target;
const PackedFunc* f = runtime::Registry::Get(factory_name);
if (f != nullptr) {
void* handle = (*f)();
return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
} else {
return std::unique_ptr<CodeGenLLVM>(new CodeGenLLVM());
}
}
void CodeGenLLVM::Init(const std::string& module_name,
llvm::TargetMachine* tm,
llvm::LLVMContext* ctx) {
......@@ -93,18 +105,18 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
data_layout_.reset(new llvm::DataLayout(module_.get()));
// initialize native vector bits
std::string target = tm->getTarget().getName();
if (target == "arm") {
native_vector_bits_ = 16 * 8;
} else if (target == "x86-64") {
if (target == "x86-64") {
// for avx512
native_vector_bits_ = 64 * 8;
} else if (target == "x86") {
native_vector_bits_ = 32 * 8;
} else {
if (native_vector_bits_ == 0) {
native_vector_bits_ = 32 * 8;
LOG(WARNING) << "set native vector to be " << native_vector_bits_ / 8
<< " for target " << target;
}
}
}
void CodeGenLLVM::InitGlobalContext() {
......
......@@ -29,6 +29,12 @@ class CodeGenLLVM :
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Create new code generator based on target machine.
* \param tm The target machine
* \return The created llvm generator.
*/
static std::unique_ptr<CodeGenLLVM> Create(llvm::TargetMachine* tm);
/*!
* \brief Initialize the code generator with given context
* \param module_name The name of the module.
* \param tm Target machine model
......@@ -136,6 +142,8 @@ class CodeGenLLVM :
// do a scalarize call with f
llvm::Value* CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
// Initialize target
virtual void InitTarget(llvm::TargetMachine* tm);
// apply optimization on the module.
virtual void Optimize();
// Get the maximim storage align bits of buffer pointer given storage scope.
......@@ -216,8 +224,6 @@ class CodeGenLLVM :
// if not directly finalize function and pass on return code.
// return the end block after the check
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Initialize target
void InitTarget(llvm::TargetMachine* tm);
// Add a function to set global module context
void InitGlobalContext();
// add alias information.
......
......@@ -36,6 +36,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::log>);
} // namespace llvm
} // namespace codegen
} // namespace tvm
......
......@@ -101,14 +101,14 @@ class LLVMModuleNode final : public runtime::ModuleNode {
tm_ = GetLLVMTargetMachine(target);
CHECK_NE(funcs.size(), 0U);
ctx_ = std::make_shared<llvm::LLVMContext>();
CodeGenLLVM cg;
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_);
entry_func_ = funcs[0]->name;
cg.Init(funcs[0]->name, tm_, ctx_.get());
cg->Init(funcs[0]->name, tm_, ctx_.get());
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
cg->AddFunction(f);
}
cg.AddMainFunction(funcs[0]->name);
module_ = cg.Finish();
cg->AddMainFunction(funcs[0]->name);
module_ = cg->Finish();
mptr_ = module_.get();
}
......
......@@ -26,11 +26,19 @@ class IntrinInjecter : public IRMutator {
Expr Mutate_(const Call* op, const Expr& e) final {
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
Expr r = ApplyPattern(op->name, e);
if (r.defined()) return r;
}
return IRMutator::Mutate_(op, e);
}
private:
Expr ApplyPattern(const std::string& name, const Expr& e) {
for (size_t i = 0; i < patterns_.size(); ++i) {
std::string& p = patterns_[i];
size_t psize = p.length();
p.resize(psize + op->name.length());
op->name.copy(&p[0] + psize, op->name.length());
p.resize(psize + name.length());
name.copy(&p[0] + psize, name.length());
const runtime::PackedFunc* f = runtime::Registry::Get(p);
p.resize(psize);
// if pattern exists.
......@@ -42,11 +50,9 @@ class IntrinInjecter : public IRMutator {
}
}
}
return Expr();
}
return IRMutator::Mutate_(op, e);
}
private:
// patterns
std::vector<std::string> patterns_;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment