Commit 944de73b by Zhixun Tan, committed by Tianqi Chen

Add type code and bits to AllocWorkspace. (#831)

parent eb8077ff
@@ -44,14 +44,20 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
  *
  * \note The resulting allocated space is ensured to be aligned to kTempAllocaAlignment.
  *
- * \param size The size of the space requested.
+ * \param nbytes The size of the space requested.
  * \param device_type The device type on which the space will be allocated.
  * \param device_id The device id on which the space will be allocated.
+ * \param dtype_code_hint The type code of the array elements. Only used in
+ *        certain backends such as OpenGL.
+ * \param dtype_bits_hint The type bits of the array elements. Only used in
+ *        certain backends such as OpenGL.
  * \return nullptr when an error is thrown, a valid pointer on success.
  */
 TVM_DLL void* TVMBackendAllocWorkspace(int device_type,
                                        int device_id,
-                                       uint64_t size);
+                                       uint64_t nbytes,
+                                       int dtype_code_hint,
+                                       int dtype_bits_hint);
 /*!
  * \brief Backend function to free temporal workspace.

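For illustration, a hedged sketch of a direct call under the extended signature. The buffer size and element type below are invented for the example; kDLCPU and kDLFloat are the dlpack constants the runtime already uses.

#include <tvm/runtime/c_backend_api.h>

/* Sketch only: request 1024 float32 elements of scratch space on CPU.
 * The two trailing hints are consulted only by backends such as OpenGL;
 * byte-oriented backends ignore them. */
void* AllocExampleWorkspace() {
  void* ws = TVMBackendAllocWorkspace(kDLCPU,               /* device_type */
                                      0,                    /* device_id */
                                      1024 * sizeof(float), /* nbytes */
                                      kDLFloat,             /* dtype_code_hint */
                                      32);                  /* dtype_bits_hint */
  /* nullptr signals an error, per the doc comment above. */
  return ws;
}
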
@@ -114,9 +114,13 @@ class DeviceAPI {
   * - Workspace should not overlap between different threads (i.e., it should be thread-local).
   *
   * \param ctx The context of allocation.
-  * \param size The size to be allocated.
+  * \param nbytes The size to be allocated.
+  * \param type_hint The type of elements. Only needed by certain backends such
+  *        as OpenGL, as nbytes is sufficient for most backends.
   */
-  TVM_DLL virtual void* AllocWorkspace(TVMContext ctx, size_t size);
+  TVM_DLL virtual void* AllocWorkspace(TVMContext ctx,
+                                       size_t nbytes,
+                                       TVMType type_hint = {});
  /*!
   * \brief Free temporal workspace in backend execution.
   *

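A brief sketch of how a caller might fill the new hint (the helper name and values here are illustrative, not part of the diff). A zero-initialized TVMType, which the `= {}` default supplies, simply means "no hint", and byte-oriented backends can ignore it:

#include <tvm/runtime/device_api.h>

// Sketch: allocate num_elems float32 elements through the new overload.
void* AllocFloatScratch(tvm::runtime::DeviceAPI* api,
                        TVMContext ctx,
                        size_t num_elems) {
  TVMType f32;
  f32.code = kDLFloat;  // element type code, from dlpack
  f32.bits = 32;        // bits per element
  f32.lanes = 1;        // scalar elements
  return api->AllocWorkspace(ctx, num_elems * sizeof(float), f32);
}
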
@@ -24,6 +24,8 @@ void CodeGenOpenGL::InitFuncState(LoweredFunc f) {
   inputs_.clear();
   output_iter_var_ = nullptr;
   thread_extent_var_ = "";
+  this->decl_stream.str("");
+  this->stream.str("");
 }

 void CodeGenOpenGL::AddFunction(LoweredFunc f) {

@@ -197,10 +197,12 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
     vm_.stack_size += size;
     this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
   } else if (op->name == "TVMBackendAllocWorkspace") {
-    CHECK_EQ(op->args.size(), 3U);
+    CHECK_EQ(op->args.size(), 5U);
     this->Push(op->args[0]);
     this->Push(op->args[1]);
     this->Push(op->args[2]);
+    this->Push(op->args[3]);
+    this->Push(op->args[4]);
     this->PushOp(StackVM::TVM_DEVICE_ALLOCA);
   } else if (op->name == "TVMBackendFreeWorkspace") {
     CHECK_EQ(op->args.size(), 3U);

@@ -455,12 +455,15 @@ void StackVM::Run(State* s) const {
         break;
       }
       case TVM_DEVICE_ALLOCA: {
-        int device_type = static_cast<int>(stack[sp - 2].v_int64);
-        int device_id = static_cast<int>(stack[sp - 1].v_int64);
-        size_t nbytes = static_cast<size_t>(stack[sp].v_int64);
-        void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes);
-        stack[sp - 2].v_handle = ptr;
-        sp = sp - 2;
+        int device_type = static_cast<int>(stack[sp - 4].v_int64);
+        int device_id = static_cast<int>(stack[sp - 3].v_int64);
+        size_t nbytes = static_cast<size_t>(stack[sp - 2].v_int64);
+        int dtype_code_hint = static_cast<int>(stack[sp - 1].v_int64);
+        int dtype_bits_hint = static_cast<int>(stack[sp].v_int64);
+        void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes,
+                                             dtype_code_hint, dtype_bits_hint);
+        stack[sp - 4].v_handle = ptr;
+        sp = sp - 4;
         pc = pc + 1;
         break;
       }

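To make the index arithmetic easier to follow, here is the operand layout the interpreter now expects when it reaches TVM_DEVICE_ALLOCA (a sketch matching the casts above; sp points at the most recently pushed slot):

// stack[sp - 4]  device_type       (v_int64)
// stack[sp - 3]  device_id         (v_int64)
// stack[sp - 2]  nbytes            (v_int64)
// stack[sp - 1]  dtype_code_hint   (v_int64)
// stack[sp]      dtype_bits_hint   (v_int64)
//
// After the call returns, the result pointer overwrites stack[sp - 4] and
// sp drops by 4, leaving the handle on top of the stack.
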
@@ -96,14 +96,18 @@ class BuiltinLower : public IRMutator {
                    {op->buffer_var}, Call::PureIntrinsic),
                    throw_last_error),
         op->body);
-    Stmt alloca = LetStmt::make(op->buffer_var,
-                                Call::make(op->buffer_var.type(),
-                                           "TVMBackendAllocWorkspace",
-                                           {cast(Int(32), device_type_),
-                                            cast(Int(32), device_id_),
-                                            cast(UInt(64), total_bytes)},
-                                           Call::Extern),
-                                body);
+    Stmt alloca = LetStmt::make(
+        op->buffer_var,
+        Call::make(op->buffer_var.type(),
+                   "TVMBackendAllocWorkspace",
+                   {cast(Int(32), device_type_),
+                    cast(Int(32), device_id_),
+                    cast(UInt(64), total_bytes),
+                    IntImm::make(Int(32), op->type.code()),
+                    IntImm::make(Int(32), op->type.bits())},
+                   Call::Extern),
+        body);
     Expr free_op = Call::make(Int(32),
                               "TVMBackendFreeWorkspace",

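As a worked illustration (the concrete values are invented), an Allocate of 256 float32 elements would now lower to an extern call of roughly this shape; the last two arguments are the new dtype hints taken from op->type:

// Pseudo-IR sketch, not literal pass output:
//   let buf = TVMBackendAllocWorkspace(
//       int32(device_type),  // e.g. the OpenGL device type code
//       int32(device_id),
//       uint64(1024),        // total_bytes = 256 elements * 4 bytes
//       int32(2),            // op->type.code(), kDLFloat
//       int32(32))           // op->type.bits()
//   in <body>
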
@@ -146,6 +146,11 @@ class IRUseDefAnalysis : public IRMutator {
 class HostDeviceSplitter : public IRMutator {
  public:
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    handle_data_type_[op->buffer_var.get()] = make_const(op->type, 0);
+    return IRMutator::Mutate_(op, s);
+  }
+
   Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
     if (op->attr_key == attr::thread_extent ||
         op->attr_key == attr::pipeline_exec_scope) {

@@ -95,8 +95,9 @@ DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
       static_cast<int>(ctx.device_type), allow_missing);
 }

-void* DeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) {
-  TVMType type_hint{kDLUInt, 8, 1};
+void* DeviceAPI::AllocWorkspace(TVMContext ctx,
+                                size_t size,
+                                TVMType type_hint) {
   return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
 }

@@ -220,12 +221,22 @@ int TVMBackendGetFuncFromEnv(void* mod_node,
 }

 void* TVMBackendAllocWorkspace(int device_type,
-                               int device_id,
-                               uint64_t size) {
+                               int device_id,
+                               uint64_t size,
+                               int dtype_code_hint,
+                               int dtype_bits_hint) {
   TVMContext ctx;
   ctx.device_type = static_cast<DLDeviceType>(device_type);
   ctx.device_id = device_id;
-  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, static_cast<size_t>(size));
+
+  TVMType type_hint;
+  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
+  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
+  type_hint.lanes = 1;
+
+  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
+                                                    static_cast<size_t>(size),
+                                                    type_hint);
 }

 int TVMBackendFreeWorkspace(int device_type,

@@ -59,7 +59,7 @@ class CPUDeviceAPI final : public DeviceAPI {
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
   }

-  void* AllocWorkspace(TVMContext ctx, size_t size) final;
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
   void FreeWorkspace(TVMContext ctx, void* data) final;

   static const std::shared_ptr<CPUDeviceAPI>& Global() {

@@ -74,7 +74,9 @@ struct CPUWorkspacePool : public WorkspacePool {
       WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
 };

-void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) {
+void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx,
+                                   size_t size,
+                                   TVMType type_hint) {
   return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()
       ->AllocWorkspace(ctx, size);
 }

@@ -112,7 +112,7 @@ class CUDADeviceAPI final : public DeviceAPI {
         ->stream = static_cast<cudaStream_t>(stream);
   }

-  void* AllocWorkspace(TVMContext ctx, size_t size) final {
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
     return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
   }

@@ -77,7 +77,7 @@ class MetalWorkspace final : public DeviceAPI {
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final;
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
-  void* AllocWorkspace(TVMContext ctx, size_t size) final;
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
   void FreeWorkspace(TVMContext ctx, void* data) final;
   // get the global workspace
   static const std::shared_ptr<MetalWorkspace>& Global();

@@ -228,7 +228,9 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
   [cb waitUntilCompleted];
 }

-void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
+void* MetalWorkspace::AllocWorkspace(TVMContext ctx,
+                                     size_t size,
+                                     TVMType type_hint) {
   return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
 }

@@ -156,7 +156,7 @@ class OpenCLWorkspace final : public DeviceAPI {
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final;
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
-  void* AllocWorkspace(TVMContext ctx, size_t size) final;
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
   void FreeWorkspace(TVMContext ctx, void* data) final;
   // get the global workspace
   static const std::shared_ptr<OpenCLWorkspace>& Global();

@@ -108,7 +108,9 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
   OPENCL_CALL(clFinish(this->GetQueue(ctx)));
 }

-void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
+void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx,
+                                      size_t size,
+                                      TVMType type_hint) {
   return OpenCLThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
 }

@@ -175,8 +175,6 @@ class OpenGLWorkspace final : public DeviceAPI {
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final;
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
-  void* AllocWorkspace(TVMContext ctx, size_t size) final;
-  void FreeWorkspace(TVMContext ctx, void* data) final;
   /*!
    * \brief Get the global OpenGL workspace.

@@ -156,15 +156,6 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from,
 void OpenGLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {}

-void* OpenGLWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
-  LOG(FATAL) << "Cannot allocate OpenGL workspace.";
-  return nullptr;
-}
-
-void OpenGLWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
-  LOG(FATAL) << "Cannot free OpenGL workspace.";
-}
-
 OpenGLWorkspace::OpenGLWorkspace() {
   // Set an error handler.
   // This can be called before glfwInit().

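With its overrides deleted, OpenGL now inherits the base-class path shown in the c_runtime_api.cc hunk above; roughly (a sketch, not literal call traces):

// DeviceAPI::AllocWorkspace(ctx, nbytes, type_hint)
//   -> AllocDataSpace(ctx, nbytes, kTempAllocaAlignment, type_hint)
// so the OpenGL AllocDataSpace can consult type_hint when sizing its
// texture-backed storage, instead of failing with LOG(FATAL) as before.
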
@@ -110,7 +110,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
         ->stream = static_cast<hipStream_t>(stream);
   }

-  void* AllocWorkspace(TVMContext ctx, size_t size) final {
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
     return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
   }

import tvm
import numpy as np


def test_local_multi_stage():
    # Exercises the intermediate-buffer allocation path on the OpenGL backend.
    if not tvm.module.enabled("opengl"):
        return
    if not tvm.module.enabled("llvm"):
        return

    n = tvm.var("n")
    A = tvm.placeholder((n,), name='A', dtype="int32")
    B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
    C = tvm.compute((n,), lambda i: B[i] * 2, name="C")

    s = tvm.create_schedule(C.op)
    s[B].opengl()
    s[C].opengl()
    f = tvm.build(s, [A, C], "opengl", name="multi_stage")

    ctx = tvm.opengl(0)
    n = 10
    a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
    c = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx)
    f(a, c)

    np.testing.assert_allclose(c.asnumpy(), (a.asnumpy() + 1) * 2)


if __name__ == "__main__":
    test_local_multi_stage()