Commit 0560e156 by Tianqi Chen Committed by GitHub

[CODEGEN] NVPTX backend. (#392)

* [CODEGEN] NVPTX backend.

* Fix pylint

* use fix
parent efafa1a0
...@@ -55,7 +55,10 @@ def context(dev_type, dev_id=0): ...@@ -55,7 +55,10 @@ def context(dev_type, dev_id=0):
assert tvm.context("cuda", 0) == tvm.gpu(0) assert tvm.context("cuda", 0) == tvm.gpu(0)
""" """
if isinstance(dev_type, string_types): if isinstance(dev_type, string_types):
if not dev_type in TVMContext.STR2MASK: if dev_type not in TVMContext.STR2MASK:
if dev_type.find("nvptx") != -1:
dev_type = "cuda"
if dev_type not in TVMContext.STR2MASK:
raise ValueError("Unknown device type %s" % dev_type) raise ValueError("Unknown device type %s" % dev_type)
dev_type = TVMContext.STR2MASK[dev_type] dev_type = TVMContext.STR2MASK[dev_type]
return TVMContext(dev_type, dev_id) return TVMContext(dev_type, dev_id)
......
...@@ -62,6 +62,7 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { ...@@ -62,6 +62,7 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
module_->setTargetTriple(tm->getTargetTriple().str()); module_->setTargetTriple(tm->getTargetTriple().str());
module_->setDataLayout(tm->createDataLayout()); module_->setDataLayout(tm->createDataLayout());
data_layout_.reset(new llvm::DataLayout(module_.get())); data_layout_.reset(new llvm::DataLayout(module_.get()));
target_machine_ = tm;
// initialize native vector bits // initialize native vector bits
std::string target = tm->getTarget().getName(); std::string target = tm->getTarget().getName();
if (target == "x86-64") { if (target == "x86-64") {
...@@ -86,6 +87,10 @@ void CodeGenLLVM::InitFuncState() { ...@@ -86,6 +87,10 @@ void CodeGenLLVM::InitFuncState() {
} }
void CodeGenLLVM::AddFunction(const LoweredFunc& f) { void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
AddFunctionInternal(f, false);
}
void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
this->InitFuncState(); this->InitFuncState();
is_restricted_ = f->is_restricted; is_restricted_ = f->is_restricted;
CHECK(!module_->getFunction(f->name)) CHECK(!module_->getFunction(f->name))
...@@ -103,7 +108,9 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) { ...@@ -103,7 +108,9 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
arg_type.push_back(LLVMType(t)); arg_type.push_back(LLVMType(t));
} }
} }
llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, arg_type, false);
llvm::FunctionType* ftype = llvm::FunctionType::get(
ret_void ? t_void_ : t_int_, arg_type, false);
// setup the function. // setup the function.
function_ = llvm::cast<llvm::Function>(module_->getOrInsertFunction(f->name, ftype)); function_ = llvm::cast<llvm::Function>(module_->getOrInsertFunction(f->name, ftype));
function_->setCallingConv(llvm::CallingConv::C); function_->setCallingConv(llvm::CallingConv::C);
...@@ -129,8 +136,13 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) { ...@@ -129,8 +136,13 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_); llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(block); builder_->SetInsertPoint(block);
this->VisitStmt(f->body); this->VisitStmt(f->body);
if (ret_void) {
builder_->CreateRetVoid();
} else {
builder_->CreateRet(ConstInt32(0)); builder_->CreateRet(ConstInt32(0));
}
} }
void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) { void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
...@@ -155,6 +167,9 @@ class MPassManager : public llvm::legacy::PassManager { ...@@ -155,6 +167,9 @@ class MPassManager : public llvm::legacy::PassManager {
} }
}; };
void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {
}
void CodeGenLLVM::Optimize() { void CodeGenLLVM::Optimize() {
// place optimization pass // place optimization pass
llvm::PassManagerBuilder builder; llvm::PassManagerBuilder builder;
...@@ -167,6 +182,12 @@ void CodeGenLLVM::Optimize() { ...@@ -167,6 +182,12 @@ void CodeGenLLVM::Optimize() {
#endif #endif
builder.LoopVectorize = true; builder.LoopVectorize = true;
builder.SLPVectorize = true; builder.SLPVectorize = true;
this->InitPassManagerBuilder(&builder);
#if TVM_LLVM_VERSION >= 50
target_machine_->adjustPassManager(builder);
#endif
// pass manager // pass manager
FPassManager fpass(module_.get()); FPassManager fpass(module_.get());
MPassManager mpass; MPassManager mpass;
...@@ -313,25 +334,31 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) { ...@@ -313,25 +334,31 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
} }
} }
llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) { llvm::CallInst* CodeGenLLVM::CreateCallExtern(
std::vector<llvm::Value*> arg_values(op->args.size()); llvm::Type* ret,
for (size_t i = 0; i < op->args.size(); ++i) { const std::string& name,
arg_values[i] = MakeValue(op->args[i]); const std::vector<llvm::Value*>& arg_values) {
}
std::vector<llvm::Type*> arg_types; std::vector<llvm::Type*> arg_types;
for (llvm::Value* v : arg_values) { for (llvm::Value* v : arg_values) {
arg_types.push_back(v->getType()); arg_types.push_back(v->getType());
} }
llvm::FunctionType* ftype = llvm::FunctionType::get( llvm::FunctionType* ftype = llvm::FunctionType::get(ret, arg_types, false);
LLVMType(op->type), arg_types, false); llvm::Function* f = module_->getFunction(name);
llvm::Function* f = module_->getFunction(op->name);
if (f == nullptr) { if (f == nullptr) {
f = llvm::Function::Create( f = llvm::Function::Create(
ftype, llvm::Function::ExternalLinkage, op->name, module_.get()); ftype, llvm::Function::ExternalLinkage, name, module_.get());
} }
return builder_->CreateCall(f, arg_values); return builder_->CreateCall(f, arg_values);
} }
llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
std::vector<llvm::Value*> arg_values(op->args.size());
for (size_t i = 0; i < op->args.size(); ++i) {
arg_values[i] = MakeValue(op->args[i]);
}
return CreateCallExtern(LLVMType(op->type), op->name, arg_values);
}
llvm::Value* CodeGenLLVM::CreateScalarizedCall( llvm::Value* CodeGenLLVM::CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args) { const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args) {
llvm::Value* value = llvm::UndefValue::get(LLVMType(op->type)); llvm::Value* value = llvm::UndefValue::get(LLVMType(op->type));
...@@ -437,6 +464,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { ...@@ -437,6 +464,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
auto id = static_cast<llvm::Intrinsic::ID>(op->args[0].as<UIntImm>()->value); auto id = static_cast<llvm::Intrinsic::ID>(op->args[0].as<UIntImm>()->value);
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id); llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id);
return builder_->CreateCall(f, arg_values); return builder_->CreateCall(f, arg_values);
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return CreateStorageSync(op);
} else if (op->is_intrinsic(Call::bitwise_and)) { } else if (op->is_intrinsic(Call::bitwise_and)) {
CHECK_EQ(op->args.size(), 2U); CHECK_EQ(op->args.size(), 2U);
return builder_->CreateAnd( return builder_->CreateAnd(
...@@ -510,7 +539,18 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { ...@@ -510,7 +539,18 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
return nullptr; return nullptr;
} }
int CodeGenLLVM::NativeVectorBits(const std::string& storage_scope) const { // Get the corresponding thread index
llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
LOG(FATAL) << "Donot support threading " << iv;
return nullptr;
}
llvm::Value* CodeGenLLVM::CreateStorageSync(const Call* op) {
LOG(FATAL) << "Donot support storage sync in CPU mode";
return nullptr;
}
int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
// By default, we ask the buffer to be aligned to 64 bytes // By default, we ask the buffer to be aligned to 64 bytes
return native_vector_bits_; return native_vector_bits_;
} }
...@@ -855,7 +895,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) { ...@@ -855,7 +895,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
ramp->base, make_const(ramp->base.type(), offset)); ramp->base, make_const(ramp->base.type(), offset));
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base)); llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base));
llvm::Type* vtype = llvm::VectorType::get( llvm::Type* vtype = llvm::VectorType::get(
LLVMType(t.element_of()), lanes)->getPointerTo(); LLVMType(t.element_of()), lanes)->getPointerTo(
ptr->getType()->getPointerAddressSpace());
llvm::LoadInst* inst = builder_->CreateAlignedLoad( llvm::LoadInst* inst = builder_->CreateAlignedLoad(
builder_->CreatePointerCast(ptr, vtype), alignment); builder_->CreatePointerCast(ptr, vtype), alignment);
AddAliasInfo(inst, op->buffer_var.get(), AddAliasInfo(inst, op->buffer_var.get(),
...@@ -971,7 +1012,8 @@ void CodeGenLLVM::VisitStmt_(const Store* op) { ...@@ -971,7 +1012,8 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
ramp->base, make_const(ramp->base.type(), offset)); ramp->base, make_const(ramp->base.type(), offset));
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base)); llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base));
llvm::Type* vtype = llvm::VectorType::get( llvm::Type* vtype = llvm::VectorType::get(
LLVMType(t.element_of()), lanes)->getPointerTo(); LLVMType(t.element_of()), lanes)->getPointerTo(
ptr->getType()->getPointerAddressSpace());
llvm::StoreInst* inst = builder_->CreateAlignedStore( llvm::StoreInst* inst = builder_->CreateAlignedStore(
CreateVecSlice(value, offset, lanes), CreateVecSlice(value, offset, lanes),
builder_->CreatePointerCast(ptr, vtype), alignment); builder_->CreatePointerCast(ptr, vtype), alignment);
...@@ -1069,17 +1111,28 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) { ...@@ -1069,17 +1111,28 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
} }
info.alignment = alloca->getAlignment(); info.alignment = alloca->getAlignment();
} }
buf = builder_->CreatePointerCast(buf, LLVMType(op->type)->getPointerTo()); buf = builder_->CreatePointerCast(
buf, LLVMType(op->type)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get())); CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf; var_map_[op->buffer_var.get()] = buf;
this->VisitStmt(op->body); this->VisitStmt(op->body);
} }
void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == ir::attr::storage_scope) { if (op->attr_key == ir::attr::thread_extent) {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
if (!var_map_.count(iv->var.get())) {
var_map_[iv->var.get()] = GetThreadIndex(iv);
}
}
this->VisitStmt(op->body);
} else if (op->attr_key == ir::attr::storage_scope) {
const Variable* v = op->node.as<Variable>(); const Variable* v = op->node.as<Variable>();
CHECK(v); CHECK(v);
alloc_storage_info_[v].scope = op->value.as<StringImm>()->value; alloc_storage_info_[v].scope = runtime::StorageScope::make(
op->value.as<StringImm>()->value);
this->VisitStmt(op->body); this->VisitStmt(op->body);
} else if (op->attr_key == ir::attr::storage_alignment) { } else if (op->attr_key == ir::attr::storage_alignment) {
const Variable* v = op->node.as<Variable>(); const Variable* v = op->node.as<Variable>();
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include "./llvm_common.h" #include "./llvm_common.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -116,22 +117,29 @@ class CodeGenLLVM : ...@@ -116,22 +117,29 @@ class CodeGenLLVM :
void VisitStmt_(const Block* op) override; void VisitStmt_(const Block* op) override;
void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override; void VisitStmt_(const ProducerConsumer* op) override;
// create intrinstic given call
virtual llvm::Value* CreateIntrinsic(const Call* op);
// create extern function call
virtual llvm::Value* CreateCallExtern(const Call* op);
// Scalarize e by iterating elements of e.
// f is a callback that takes index and v.
virtual void Scalarize(const Expr& e,
std::function<void(int i, llvm::Value* v)> f);
protected: protected:
/*! \brief The storage information */ /*! \brief The storage information */
struct StorageInfo { struct StorageInfo {
/*! \brief The storage scope */ /*! \brief The storage scope */
std::string scope; runtime::StorageScope scope;
/*! \brief The alignment of allocation */ /*! \brief The alignment of allocation */
int alignment{0}; int alignment{0};
}; };
// create intrinstic given call
virtual llvm::Value* CreateIntrinsic(const Call* op);
// create extern function call
virtual llvm::Value* CreateCallExtern(const Call* op);
// Get the corresponding thread index
virtual llvm::Value* GetThreadIndex(const IterVar& iv);
// Get the corresponding thread index
virtual llvm::Value* CreateStorageSync(const Call* op);
// apply optimization on the module.
virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
// Scalarize by iterating elements of e.
// f is a callback that takes index and v.
virtual void Scalarize(const Expr& e,
std::function<void(int i, llvm::Value* v)> f);
// Initialize target // Initialize target
virtual void InitTarget(llvm::TargetMachine* tm); virtual void InitTarget(llvm::TargetMachine* tm);
// Add module startup function if needed. // Add module startup function if needed.
...@@ -139,7 +147,12 @@ class CodeGenLLVM : ...@@ -139,7 +147,12 @@ class CodeGenLLVM :
// apply optimization on the module. // apply optimization on the module.
virtual void Optimize(); virtual void Optimize();
// Get the maximim storage align bits of buffer pointer given storage scope. // Get the maximim storage align bits of buffer pointer given storage scope.
virtual int NativeVectorBits(const std::string& storage_scope) const; virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
// Create extern call
llvm::CallInst* CreateCallExtern(llvm::Type* ret,
const std::string& name,
const std::vector<llvm::Value*>& value);
/*! /*!
* \param t The original type. * \param t The original type.
* \return LLVM type of t * \return LLVM type of t
...@@ -192,6 +205,8 @@ class CodeGenLLVM : ...@@ -192,6 +205,8 @@ class CodeGenLLVM :
std::unique_ptr<llvm::DataLayout> data_layout_; std::unique_ptr<llvm::DataLayout> data_layout_;
// Internal metabuilder // Internal metabuilder
std::unique_ptr<llvm::MDBuilder> md_builder_; std::unique_ptr<llvm::MDBuilder> md_builder_;
// llvm target machine
llvm::TargetMachine* target_machine_{nullptr};
// llvm context // llvm context
llvm::LLVMContext* ctx_{nullptr}; llvm::LLVMContext* ctx_{nullptr};
// helpful data types // helpful data types
......
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_nvptx.cc
* \brief NVPTX code generator.
*/
#ifdef TVM_LLVM_VERSION
#if TVM_CUDA_RUNTIME
#include <tvm/runtime/device_api.h>
#include "./codegen_llvm.h"
#include "../build_common.h"
#include "../../pass/ir_util.h"
#include "../../runtime/cuda/cuda_module.h"
namespace tvm {
namespace codegen {
// NVPTX code generator.
class CodeGenNVPTX : public CodeGenLLVM {
public:
void AddFunction(const LoweredFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(f, true);
// annotate as kernel function
module_->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(llvm::MDNode::get(*ctx_, {
llvm::ValueAsMetadata::get(function_),
llvm::MDString::get(*ctx_, "kernel"),
llvm::ValueAsMetadata::get(ConstInt32(1)) }));
}
void VisitStmt_(const Allocate* op) final {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
CHECK_EQ(op->free_function, "nop");
buf = MakeValue(op->new_expr);
} else {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation in GPU";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = GetTempAllocaAlignment(op->type, constant_size);
}
// maximum necessary alignment in the NV devices
if (info.alignment > 16) {
info.alignment = 16;
}
if (info.scope.rank == 2) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = builder_->CreateAlloca(
LLVMType(op->type), ConstInt32(constant_size));
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
alloca->setAlignment(info.alignment);
}
buf = alloca;
} else {
CHECK_EQ(info.scope.rank, 1)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size);
// Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
global->setAlignment(info.alignment);
buf = global;
}
}
buf = builder_->CreatePointerCast(
buf, LLVMType(op->type)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
this->VisitStmt(op->body);
}
// Return the thread index via intrinsics.
llvm::Value* GetThreadIndex(const IterVar& iv) final {
runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
if (ts.rank == 1) {
switch (ts.dim_index) {
case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; break;
case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y; break;
case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z; break;
default: LOG(FATAL) << "unknown thread idx";
}
} else {
CHECK_EQ(ts.rank, 0);
switch (ts.dim_index) {
case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; break;
case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y; break;
case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z; break;
default: LOG(FATAL) << "unknown thread idx";
}
}
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
return builder_->CreateCall(f, {});
}
llvm::Value* CreateStorageSync(const Call* op) final {
const std::string& sync = op->args[0].as<StringImm>()->value;
if (sync == "warp") {
// TODO(tqchen) warp sync in CUDA9
return nullptr;
} else if (sync == "shared") {
llvm::Function* f = llvm::Intrinsic::getDeclaration(
module_.get(),
::llvm::Intrinsic::nvvm_barrier0);
return builder_->CreateCall(f, {});
} else {
LOG(FATAL) << "Do not support sync " << sync;
return nullptr;
}
}
void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
// Additional optimization hook to tweak the builder.
}
protected:
void InitTarget(llvm::TargetMachine* tm) final {
// Maximum vector lane = float4
native_vector_bits_ = 4 * 32;
CodeGenLLVM::InitTarget(tm);
}
};
runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
CHECK(target.length(
) >= 5 &&
target.substr(0, 5) == "nvptx");
llvm::TargetMachine* tm = GetLLVMTargetMachine(
"-mtriple=nvptx64-nvidia-cuda -mcpu=sm_20" +
target.substr(5, target.length() - 5));
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
for (LoweredFunc f : funcs) {
cg->AddFunction(f);
}
std::unique_ptr<llvm::Module> module = cg->Finish();
llvm::SmallString<8> data;
llvm::raw_svector_ostream dest(data);
dest.SetUnbuffered();
llvm::legacy::PassManager pass;
CHECK(tm->addPassesToEmitFile(
pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_ObjectFile";
pass.run(*module);
std::string ptx(data.begin(), data.end());
return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), "");
}
TVM_REGISTER_API("codegen.build_nvptx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildNVPTX(args[0], args[1]);
});
} // namespace codegen
} // namespace tvm
#endif // TVM_CUDA_RUNTIME
#endif // TVM_LLVM_VERSION
...@@ -37,19 +37,22 @@ void InitializeLLVM() { ...@@ -37,19 +37,22 @@ void InitializeLLVM() {
} }
llvm::TargetMachine* llvm::TargetMachine*
GetLLVMTargetMachine(const std::string& target_str, bool allow_null) { GetLLVMTargetMachine(const std::string& target_str,
bool allow_null) {
// setup target triple // setup target triple
CHECK(target_str.length() >= 4 && size_t start = 0;
target_str.substr(0, 4) == "llvm") if (target_str.length() >= 4 &&
<< "llvm target must starts with llvm"; target_str.substr(0, 4) == "llvm") {
start = 4;
}
// simple parser // simple parser
std::string target_triple = ""; std::string target_triple = "";
std::string cpu = "generic"; std::string cpu = "generic";
std::string attr = ""; std::string attr = "";
bool soft_float_abi = false; bool soft_float_abi = false;
std::string key, value; std::string key, value;
if (target_str.length() > 5) { std::istringstream is(target_str.substr(start, target_str.length() - start));
std::istringstream is(target_str.substr(5, target_str.length() - 5));
while (is >> key) { while (is >> key) {
if (key == "--system-lib" || key == "-system-lib") { if (key == "--system-lib" || key == "-system-lib") {
continue; continue;
...@@ -83,7 +86,7 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) { ...@@ -83,7 +86,7 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) {
LOG(FATAL) << "unknown option " << key; LOG(FATAL) << "unknown option " << key;
} }
} }
}
if (target_triple.length() == 0 || if (target_triple.length() == 0 ||
target_triple == "default") { target_triple == "default") {
target_triple = llvm::sys::getDefaultTargetTriple(); target_triple = llvm::sys::getDefaultTargetTriple();
...@@ -109,9 +112,8 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) { ...@@ -109,9 +112,8 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) {
} else { } else {
opt.FloatABIType = llvm::FloatABI::Hard; opt.FloatABIType = llvm::FloatABI::Hard;
} }
auto rmodel = llvm::Reloc::PIC_; llvm::TargetMachine* tm = target->createTargetMachine(
llvm::TargetMachine* tm = target_triple, cpu, attr, opt, llvm::Reloc::PIC_);
target->createTargetMachine(target_triple, cpu, attr, opt, rmodel);
return tm; return tm;
} }
......
...@@ -112,6 +112,8 @@ bool RuntimeEnabled(const std::string& target) { ...@@ -112,6 +112,8 @@ bool RuntimeEnabled(const std::string& target) {
f_name = "device_api.rpc"; f_name = "device_api.rpc";
} else if (target == "vpi" || target == "verilog") { } else if (target == "vpi" || target == "verilog") {
f_name = "device_api.vpi"; f_name = "device_api.vpi";
} else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
f_name = "codegen.build_nvptx";
} else if (target.length() >= 4 && target.substr(0, 4) == "llvm") { } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled"); const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
if (pf == nullptr) return false; if (pf == nullptr) return false;
......
...@@ -84,6 +84,7 @@ def test_gemm(): ...@@ -84,6 +84,7 @@ def test_gemm():
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("nvptx -mcpu=sm_20")
check_device("metal") check_device("metal")
check_device("opencl") check_device("opencl")
check_device("cuda") check_device("cuda")
......
...@@ -81,7 +81,7 @@ def test_add_pipeline(): ...@@ -81,7 +81,7 @@ def test_add_pipeline():
check_target("cuda", host="stackvm") check_target("cuda", host="stackvm")
check_target("cuda", host="llvm") check_target("cuda", host="llvm")
check_module_save("cuda", host="stackvm") check_module_save("cuda", host="stackvm")
check_target("nvptx", host="llvm")
if __name__ == "__main__": if __name__ == "__main__":
test_add_pipeline() test_add_pipeline()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment