Commit 89b8456e by Tianqi Chen Committed by GitHub

[CODEGEN] enable static handle cache (#723)

parent 9e01367d
......@@ -257,6 +257,11 @@ constexpr const char* tvm_if_then_else = "tvm_if_then_else";
*/
constexpr const char* tvm_access_ptr = "tvm_access_ptr";
/*!
* \brief Create a function local static handle that iniitalizes to nullptr.
* can be used to cache function local static resources.
*/
constexpr const char* tvm_static_handle = "tvm_static_handle";
/*!
* \brief Return a unique context id, used for hint of workspace separation.
* Different context id ganrantees not having overlapping workspace.
*/
......
......@@ -80,6 +80,33 @@ def call_pure_intrin(dtype, func_name, *args):
dtype, func_name, convert(args), _Call.PureIntrinsic, None, 0)
def call_intrin(dtype, func_name, *args):
"""Build expression by calling an intrinsic function.
Intrinsics can be overloaded with multiple data types via
the intrinsic translation rule.
Parameters
----------
dtype : str
The data type of the result.
func_name: str
The intrinsic function name.
args : list
Positional arguments.
Returns
-------
call : Expr
The call expression.
"""
args = convert(args)
return _make.Call(
dtype, func_name, convert(args), _Call.Intrinsic, None, 0)
def call_pure_extern(dtype, func_name, *args):
"""Build expression by calling a pure extern function.
......
......@@ -419,6 +419,16 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
builder_->SetInsertPoint(par_launch_end);
}
llvm::Value* CodeGenCPU::CreateStaticHandle() {
llvm::GlobalVariable* gv = new llvm::GlobalVariable(
*module_, t_void_p_, false,
llvm::GlobalValue::PrivateLinkage, 0,
"__tvm_static_handle");
gv->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
gv->setInitializer(llvm::Constant::getNullValue(t_void_p_));
return gv;
}
void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& body) {
using llvm::BasicBlock;
// closure data
......@@ -426,12 +436,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod
ftype_tvm_static_init_callback_,
llvm::Function::PrivateLinkage,
"__tvm_static_init_lambda", module_.get());
llvm::GlobalVariable* gv = new llvm::GlobalVariable(
*module_, t_void_p_, false,
llvm::GlobalValue::PrivateLinkage, 0,
"__tvm_static_handle");
gv->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
gv->setInitializer(llvm::Constant::getNullValue(t_void_p_));
llvm::Value* gv = CreateStaticHandle();
llvm::Function* finit = module_->getFunction(init_fname);
if (finit == nullptr) {
finit = llvm::Function::Create(
......@@ -599,6 +604,8 @@ void CodeGenCPU::AddStartupFunction() {
llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
return CreateCallPacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_static_handle)) {
return CreateStaticHandle();
} else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
builder_->CreateRet(ConstInt32(-1));
return ConstInt32(-1);
......
......@@ -72,6 +72,7 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::Value* RuntimeTVMAPISetLastError();
llvm::Value* RuntimeTVMParallelLaunch();
llvm::Value* RuntimeTVMParallelBarrier();
llvm::Value* CreateStaticHandle();
llvm::Value* GetPackedFuncHandle(const std::string& str);
llvm::Value* PackClosureData(const Array<Var>& fields, uint64_t *num_bytes);
llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind);
......
import tvm
import ctypes
import numpy as np
def test_static_init():
def test_static_callback():
dtype = 'int64'
n = tvm.var('n')
Ab = tvm.decl_buffer((n, ), dtype)
......@@ -22,6 +23,29 @@ def test_static_init():
f(a)
np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
def test_static_init():
dtype = 'int64'
n = tvm.var('n')
Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i')
ib = tvm.ir_builder.create()
handle = tvm.call_intrin("handle", "tvm_static_handle")
ib.emit(
tvm.call_packed("test_static_callback", handle, Ab))
@tvm.register_func("test_static_callback")
def test_cb(sh, A):
assert isinstance(sh, ctypes.c_void_p)
return sh
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
f = tvm.codegen.build_module(fapi, "llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
if __name__ == "__main__":
test_static_callback()
test_static_init()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment