Commit 89b8456e by Tianqi Chen Committed by GitHub

[CODEGEN] enable static handle cache (#723)

parent 9e01367d
...@@ -257,6 +257,11 @@ constexpr const char* tvm_if_then_else = "tvm_if_then_else"; ...@@ -257,6 +257,11 @@ constexpr const char* tvm_if_then_else = "tvm_if_then_else";
*/ */
constexpr const char* tvm_access_ptr = "tvm_access_ptr"; constexpr const char* tvm_access_ptr = "tvm_access_ptr";
/*! /*!
* \brief Create a function local static handle that initializes to nullptr.
* It can be used to cache function local static resources.
*/
constexpr const char* tvm_static_handle = "tvm_static_handle";
/*!
* \brief Return a unique context id, used for hint of workspace separation. * \brief Return a unique context id, used for hint of workspace separation.
* Different context id guarantees not having overlapping workspace. * Different context id guarantees not having overlapping workspace.
*/ */
......
...@@ -80,6 +80,33 @@ def call_pure_intrin(dtype, func_name, *args): ...@@ -80,6 +80,33 @@ def call_pure_intrin(dtype, func_name, *args):
dtype, func_name, convert(args), _Call.PureIntrinsic, None, 0) dtype, func_name, convert(args), _Call.PureIntrinsic, None, 0)
def call_intrin(dtype, func_name, *args):
    """Build expression by calling an intrinsic function.

    Intrinsics can be overloaded with multiple data types via
    the intrinsic translation rule.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The intrinsic function name.

    args : list
        Positional arguments.

    Returns
    -------
    call : Expr
        The call expression.
    """
    # Convert the positional arguments exactly once, the same way
    # call_pure_intrin does; a second convert() pass is redundant.
    return _make.Call(
        dtype, func_name, convert(args), _Call.Intrinsic, None, 0)
def call_pure_extern(dtype, func_name, *args): def call_pure_extern(dtype, func_name, *args):
"""Build expression by calling a pure extern function. """Build expression by calling a pure extern function.
......
...@@ -419,6 +419,16 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { ...@@ -419,6 +419,16 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
builder_->SetInsertPoint(par_launch_end); builder_->SetInsertPoint(par_launch_end);
} }
/*!
 * \brief Create a private global variable of type t_void_p_ named
 *  "__tvm_static_handle" in the current module, initialized to null.
 * \return The created global variable.
 */
llvm::Value* CodeGenCPU::CreateStaticHandle() {
  // No initializer is supplied at construction; it is attached below.
  llvm::GlobalVariable* handle_var = new llvm::GlobalVariable(
      *module_,
      t_void_p_,
      false,
      llvm::GlobalValue::PrivateLinkage,
      nullptr,
      "__tvm_static_handle");
  // The handle starts out as a null pointer.
  handle_var->setInitializer(llvm::Constant::getNullValue(t_void_p_));
  // Align the variable to the allocation size of the pointer type.
  const auto ptr_alloc_size = data_layout_->getTypeAllocSize(t_void_p_);
  handle_var->setAlignment(ptr_alloc_size);
  return handle_var;
}
void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& body) { void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& body) {
using llvm::BasicBlock; using llvm::BasicBlock;
// closure data // closure data
...@@ -426,12 +436,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod ...@@ -426,12 +436,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod
ftype_tvm_static_init_callback_, ftype_tvm_static_init_callback_,
llvm::Function::PrivateLinkage, llvm::Function::PrivateLinkage,
"__tvm_static_init_lambda", module_.get()); "__tvm_static_init_lambda", module_.get());
llvm::GlobalVariable* gv = new llvm::GlobalVariable( llvm::Value* gv = CreateStaticHandle();
*module_, t_void_p_, false,
llvm::GlobalValue::PrivateLinkage, 0,
"__tvm_static_handle");
gv->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
gv->setInitializer(llvm::Constant::getNullValue(t_void_p_));
llvm::Function* finit = module_->getFunction(init_fname); llvm::Function* finit = module_->getFunction(init_fname);
if (finit == nullptr) { if (finit == nullptr) {
finit = llvm::Function::Create( finit = llvm::Function::Create(
...@@ -599,6 +604,8 @@ void CodeGenCPU::AddStartupFunction() { ...@@ -599,6 +604,8 @@ void CodeGenCPU::AddStartupFunction() {
llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) { llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
return CreateCallPacked(op); return CreateCallPacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_static_handle)) {
return CreateStaticHandle();
} else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
builder_->CreateRet(ConstInt32(-1)); builder_->CreateRet(ConstInt32(-1));
return ConstInt32(-1); return ConstInt32(-1);
......
...@@ -72,6 +72,7 @@ class CodeGenCPU : public CodeGenLLVM { ...@@ -72,6 +72,7 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::Value* RuntimeTVMAPISetLastError(); llvm::Value* RuntimeTVMAPISetLastError();
llvm::Value* RuntimeTVMParallelLaunch(); llvm::Value* RuntimeTVMParallelLaunch();
llvm::Value* RuntimeTVMParallelBarrier(); llvm::Value* RuntimeTVMParallelBarrier();
llvm::Value* CreateStaticHandle();
llvm::Value* GetPackedFuncHandle(const std::string& str); llvm::Value* GetPackedFuncHandle(const std::string& str);
llvm::Value* PackClosureData(const Array<Var>& fields, uint64_t *num_bytes); llvm::Value* PackClosureData(const Array<Var>& fields, uint64_t *num_bytes);
llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind); llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind);
......
import tvm import tvm
import ctypes
import numpy as np import numpy as np
def test_static_init(): def test_static_callback():
dtype = 'int64' dtype = 'int64'
n = tvm.var('n') n = tvm.var('n')
Ab = tvm.decl_buffer((n, ), dtype) Ab = tvm.decl_buffer((n, ), dtype)
...@@ -22,6 +23,29 @@ def test_static_init(): ...@@ -22,6 +23,29 @@ def test_static_init():
f(a) f(a)
np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0])) np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
def test_static_init():
    """Check that a ``tvm_static_handle`` value reaches a packed function.

    Builds a function whose body calls the packed function
    "test_static_callback" with a static handle and the buffer, compiles
    it for the "llvm" target and runs it once on a zero-filled array.
    """
    dtype = 'int64'
    n = tvm.var('n')
    Ab = tvm.decl_buffer((n, ), dtype)
    ib = tvm.ir_builder.create()
    handle = tvm.call_intrin("handle", "tvm_static_handle")
    ib.emit(
        tvm.call_packed("test_static_callback", handle, Ab))

    @tvm.register_func("test_static_callback")
    def test_cb(sh, A):
        # The static handle is expected to arrive as a raw void pointer.
        assert isinstance(sh, ctypes.c_void_p)
        return sh

    stmt = ib.get()
    fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
    fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
    f = tvm.codegen.build_module(fapi, "llvm")
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    f(a)
if __name__ == "__main__": if __name__ == "__main__":
test_static_callback()
test_static_init() test_static_init()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment