Commit 891e226b by Aditya Atluri, committed by Tianqi Chen

[BACKEND] initial llvm codegen for amdgpu (#402)

* added initial llvm codegen for amdgpu

* fixed whitespace

* fixed hsaco gen from ir

* fixed targetmachine for rocm and added GetSource for rocm

* fixed whitespace issues

* changed statement to use less than 100 lines

* added intrinsics for workgroup - rocm

* whitespace - newline error fix

* fixed error msg for workitem-workgroup intrinsics

* added llvm ir dump for rocm codegen

* [ROCM] changed codegen to emit proper amdgpu kernel header

* fixed whitespace error

* fixed whitespace error- 2

* fixed AddFunction not to use an extra arg

1. Changed AddFunctionInternal not to take an extra arg for the target type
2. Use Target from CodeGenLLVM to check for the AMDGPU target

* fixed whitespaces

* fixed whitespaces 2

* fixed codegen for AMDGPU - now generating valid IR

* fixed codegen based on code review

* reviewed alignment for amd devices

* added code to dump code object to file

* fixed cpplint errors

* print out IR after pass manager

* added code to dump asm, obj to file and std string

* fixed whitespaces

* Update codegen_amdgpu.cc

* used registry for amdgpu llvm

* Fixed whitespaces

* added code for calling linker

* fixed formatting errors

* added rocm link python interface

* fixed pylint issues and added more body to the function

* added doc string

* added doc string for module

* fixed python code after review, fixed llvm object codegen

* fixed linker to generate code object

* removed dumping to output file and debugging log out

* fixed lint for python code

* added fault check after running linker

* removed print statement in rocm.py

* changed rocm lld linker to raise RuntimeError rather than emitting an error log to stderr

* changed the way the linker command line is passed to subprocess.Popen

* removed redundant code and reuse tvm utils

* removed commented out code

* removed cloning of unused modules, and put IR into string
parent 5061a6da
Subproject commit updated: 46886a6b47f660cda581e497378204ccc029a01e -> a527100d7d5001efc4954848a2fc6027e48c05f4
@@ -29,3 +29,4 @@ from .ndarray import register_extension
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
+from .contrib import rocm as _rocm
@@ -59,6 +59,8 @@ def context(dev_type, dev_id=0):
    if dev_type not in TVMContext.STR2MASK:
        if dev_type.find("nvptx") != -1:
            dev_type = "cuda"
+        if dev_type.find("rocm") != -1:
+            dev_type = "rocm"
        if dev_type not in TVMContext.STR2MASK:
            raise ValueError("Unknown device type %s" % dev_type)
        dev_type = TVMContext.STR2MASK[dev_type]
...
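For illustration only (not part of this diff): with the lookup above, a "rocm" device string now resolves like the existing ones. A minimal sketch, assuming context() is exposed as tvm.context as in the tests of this era:

import tvm
# "rocm" is now recognized by context(), so a ROCm device can be requested
# by name just like "cuda" or "nvptx".
ctx = tvm.context("rocm", 0)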
"""Utility for ROCm backend"""
import subprocess
from . import util
from ..api import register_func
def rocm_link(in_file, out_file):
"""Link relocatable ELF object to shared ELF object using lld
Parameters
----------
in_file : str
Input file name (relocatable ELF object file)
out_file : str
Output file name (shared ELF object file)
"""
args = ["ld.lld", "-shared", in_file, "-o", out_file]
proc = subprocess.Popen(
args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Linking error using ld.lld:\n"
msg += str(out)
raise RuntimeError(msg)
@register_func("tvm_callback_rocm_link")
def callback_rocm_link(obj_bin):
"""Links object file generated from LLVM to HSA Code Object
Parameters
----------
obj_bin : bytearray
The object file
Return
------
cobj_bin : bytearray
The HSA Code Object
"""
tmp_dir = util.tempdir()
tmp_obj = tmp_dir.relpath("rocm_kernel.o")
tmp_cobj = tmp_dir.relpath("rocm_kernel.co")
with open(tmp_obj, "wb") as out_file:
out_file.write(bytes(obj_bin))
rocm_link(tmp_obj, tmp_cobj)
cobj_bin = bytearray(open(tmp_cobj, "rb").read())
return cobj_bin
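A minimal usage sketch (not part of this diff) of the helpers above, assuming ld.lld is on PATH and that kernel.o is an existing relocatable AMDGPU ELF object (the file names are illustrative):

from tvm.contrib import rocm

# Link a relocatable object into a shared object / HSA code object on disk.
rocm.rocm_link("kernel.o", "kernel.co")

# The registered callback does the same through a temporary directory,
# taking and returning bytes instead of file names.
with open("kernel.o", "rb") as f:
    obj_bin = bytearray(f.read())
cobj_bin = rocm.callback_rocm_link(obj_bin)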
/*!
 * Copyright (c) 2017 by Contributors
 * \file codegen_amdgpu.cc
 * \brief AMDGPU code generator.
 */
#ifdef TVM_LLVM_VERSION
#if TVM_ROCM_RUNTIME

#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include "./codegen_llvm.h"
#include "../build_common.h"
#include "../../pass/ir_util.h"
#include "../../runtime/rocm/rocm_module.h"

namespace tvm {
namespace codegen {

// AMDGPU code generator.
class CodeGenAMDGPU : public CodeGenLLVM {
 public:
  void AddFunction(const LoweredFunc& f) final {
    // add function as void return value
    CodeGenLLVM::AddFunctionInternal(f, true);
    function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
  }
  void VisitStmt_(const Allocate* op) final {
    CHECK(!is_zero(op->condition));
    llvm::Value* buf = nullptr;
    if (op->new_expr.defined()) {
      CHECK_EQ(op->free_function, "nop");
      buf = MakeValue(op->new_expr);
    } else {
      int32_t constant_size = op->constant_allocation_size();
      CHECK_GT(constant_size, 0)
          << "Can only handle constant size stack allocation in GPU";
      StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
      if (constant_size % 4 == 0 && info.alignment == 0) {
        info.alignment = GetTempAllocaAlignment(op->type, constant_size);
      }
      // maximum necessary alignment in the AMD devices
      if (info.alignment > 16) {
        info.alignment = 16;
      }
      if (info.scope.rank == 2) {
        // const int local_address_space = 5;
        // TODO(tqchen): for higher version of LLVM, local address space can be set.
        llvm::AllocaInst* alloca = builder_->CreateAlloca(
            LLVMType(op->type), ConstInt32(constant_size));
        if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
          alloca->setAlignment(info.alignment);
        }
        buf = alloca;
      } else {
        CHECK_EQ(info.scope.rank, 1)
            << "Can only allocate shared or local memory inside kernel";
        // Shared memory: address space == 3
        const unsigned shared_address_space = 3;
        llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size);
        // Allocate shared memory in global, address_space = 3
        llvm::GlobalVariable *global = new llvm::GlobalVariable(
            *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
            nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
        global->setAlignment(info.alignment);
        buf = global;
      }
    }
    buf = builder_->CreatePointerCast(
        buf, LLVMType(op->type)->getPointerTo(
            buf->getType()->getPointerAddressSpace()));
    CHECK(!var_map_.count(op->buffer_var.get()));
    var_map_[op->buffer_var.get()] = buf;
    this->VisitStmt(op->body);
  }
  // Return the thread index via intrinsics.
  llvm::Value* GetThreadIndex(const IterVar& iv) final {
    runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
    llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x;
    if (ts.rank == 1) {
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; break;
        default: LOG(FATAL) << "unknown workitem idx";
      }
    } else {
      CHECK_EQ(ts.rank, 0);
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; break;
        default: LOG(FATAL) << "unknown workgroup idx";
      }
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
    return builder_->CreateCall(f, {});
  }
  llvm::Value* CreateStorageSync(const Call* op) final {
    const std::string& sync = op->args[0].as<StringImm>()->value;
    if (sync == "warp") {
      // TODO(tqchen) warp sync in CUDA9
      return nullptr;
    } else if (sync == "shared") {
      llvm::Function* f = llvm::Intrinsic::getDeclaration(
          module_.get(),
          ::llvm::Intrinsic::amdgcn_s_barrier);
      return builder_->CreateCall(f, {});
    } else {
      LOG(FATAL) << "Do not support sync " << sync;
      return nullptr;
    }
  }

  void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
    // Additional optimization hook to tweak the builder.
  }

  unsigned GetGlobalAddressSpace() {
    return 1;
  }

 protected:
  void InitTarget(llvm::TargetMachine* tm) final {
    // Maximum vector lane = float4
    native_vector_bits_ = 4 * 32;
    CodeGenLLVM::InitTarget(tm);
  }
};
runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
  CHECK(target.length() >= 4 &&
        target.substr(0, 4) == "rocm");
  llvm::TargetMachine* tm = GetLLVMTargetMachine(
      "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx900" +
      target.substr(4, target.length() - 4));
  std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
  cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
  for (LoweredFunc f : funcs) {
    cg->AddFunction(f);
  }
  std::unique_ptr<llvm::Module> module = cg->Finish();
  llvm::SmallString<8> dataObj, data_ll, dataAsm;
  llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm);
  destObj.SetUnbuffered();
  dest_ll.SetUnbuffered();
  destAsm.SetUnbuffered();
  module->print(dest_ll, nullptr);
  std::unique_ptr<llvm::Module> mAsm = llvm::CloneModule(module.get());
  std::unique_ptr<llvm::Module> mObj = llvm::CloneModule(module.get());
  llvm::legacy::PassManager pass;
  CHECK(tm->addPassesToEmitFile(
      pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0)
      << "Cannot emit target CGFT_ObjectFile";
  pass.run(*mObj);
  std::string obj(dataObj.begin(), dataObj.end());
  const auto* f = tvm::runtime::Registry::Get("tvm_callback_rocm_link");
  CHECK(f != nullptr) << "Require tvm_callback_rocm_link to exist, do import tvm.contrib.rocm";
  TVMByteArray arr;
  arr.data = &obj[0];
  arr.size = obj.length();
  std::string hsaco = (*f)(arr);
  std::string ll(data_ll.begin(), data_ll.end());
  return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll);
}

TVM_REGISTER_API("codegen.build_rocm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = BuildAMDGPU(args[0], args[1]);
  });

}  // namespace codegen
}  // namespace tvm

#endif  // TVM_ROCM_RUNTIME
#endif  // TVM_LLVM_VERSION
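For illustration only (not part of this diff), an end-to-end sketch of driving the new backend from Python, assuming a working ROCm install, lld on PATH, and a gfx900-class GPU; the schedule and the name fadd are illustrative:

import numpy as np
import tvm

n = 1024
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

# "rocm" routes build() through the codegen.build_rocm registration above;
# tvm.contrib.rocm (imported by tvm/__init__.py) supplies tvm_callback_rocm_link.
fadd = tvm.build(s, [A, B], "rocm", target_host="llvm", name="fadd")

ctx = tvm.context("rocm", 0)
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=A.dtype), ctx)
fadd(a, b)
np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)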
@@ -100,7 +100,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
    Type t = arg.type();
    if (t.is_handle() && f->handle_data_type.count(arg)) {
      arg_type.push_back(
-          LLVMType(f->handle_data_type[arg].type())->getPointerTo());
+          LLVMType(f->handle_data_type[arg].type())->getPointerTo(GetGlobalAddressSpace()));
      if (!is_restricted_) {
        alias_var_set_.insert(arg.get());
      }
@@ -555,6 +555,10 @@ int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) co
  return native_vector_bits_;
}

+unsigned CodeGenLLVM::GetGlobalAddressSpace() {
+  return 0;
+}
+
void CodeGenLLVM::GetAlignment(
    Type t, const Variable* buf_var, const Expr& index,
    int* p_alignment, int* p_native_bits) {
...
@@ -23,6 +23,7 @@ namespace codegen {
using namespace ir;

/*!
 * \brief A base class to generate LLVM code.
 */
@@ -148,6 +149,9 @@ class CodeGenLLVM :
  virtual void Optimize();
  // Get the maximum storage align bits of buffer pointer given storage scope.
  virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
+  // Get correct address space depending on the backend
+  virtual unsigned GetGlobalAddressSpace();
  void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
  // Create extern call
  llvm::CallInst* CreateCallExtern(llvm::Type* ret,
...
@@ -125,6 +125,8 @@ bool RuntimeEnabled(const std::string& target) {
    f_name = "device_api.vpi";
  } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
    f_name = "codegen.build_nvptx";
+  } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
+    f_name = "codegen.build_rocm";
  } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
    const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
    if (pf == nullptr) return false;
...
@@ -59,10 +59,17 @@ class ROCMModuleNode : public runtime::ModuleNode {
    stream->Write(data_);
  }

+  std::string GetSource(const std::string& format) final {
+    if (format == fmt_) { return data_; }
+    if (fmt_ == "hsaco") { return data_; }
+    return "";
+  }
+
  // get a hipFunction_t from primary context in device_id
  hipFunction_t GetFunc(int device_id, const std::string& func_name) {
    std::lock_guard<std::mutex> lock(mutex_);
    // must recheck under the lock scope
    if (module_[device_id] == nullptr) {
      ROCM_DRIVER_CALL(hipModuleLoadData(&(module_[device_id]), data_.c_str()));
    }
@@ -140,7 +147,9 @@ class ROCMWrappedFunc {
    if (fcache_[device_id] == nullptr) {
      fcache_[device_id] = m_->GetFunc(device_id, func_name_);
    }
    hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
    void* config[] = {
      HIP_LAUNCH_PARAM_BUFFER_POINTER, &packed_args,
@@ -181,7 +190,6 @@ PackedFunc ROCMModuleNode::GetFunction(
  CHECK_EQ(sptr_to_self.get(), this);
  CHECK_NE(name, symbol::tvm_module_main)
      << "Device function do not have main";
  auto it = fmap_.find(name);
  if (it == fmap_.end()) return PackedFunc();
  const FunctionInfo& info = it->second;
...
@@ -85,6 +85,8 @@ def test_gemm():
        np.testing.assert_allclose(
            c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)

+    check_device("nvptx -mcpu=sm_20")
+    check_device("rocm")
    check_device("metal")
    check_device("opencl")
    check_device("cuda")
...
@@ -82,6 +82,7 @@ def test_add_pipeline():
    check_target("cuda", host="llvm")
    check_module_save("cuda", host="stackvm")
    check_target("nvptx", host="llvm")
+    check_target("rocm", host="llvm")

if __name__ == "__main__":
    test_add_pipeline()