Commit 944de73b by Zhixun Tan Committed by Tianqi Chen

Add type code and bits to AllocWorkspace. (#831)

parent eb8077ff
...@@ -44,14 +44,20 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); ...@@ -44,14 +44,20 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
* *
* \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment. * \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
* *
* \param size The size of the space requested. * \param nbytes The size of the space requested.
* \param device_type The device type which the space will be allocated. * \param device_type The device type which the space will be allocated.
* \param device_id The device id which the space will be allocated. * \param device_id The device id which the space will be allocated.
* \param dtype_code_hint The type code of the array elements. Only used in
* certain backends such as OpenGL.
* \param dtype_bits_hint The type bits of the array elements. Only used in
* certain backends such as OpenGL.
* \return nullptr when error is thrown, a valid ptr if success * \return nullptr when error is thrown, a valid ptr if success
*/ */
TVM_DLL void* TVMBackendAllocWorkspace(int device_type, TVM_DLL void* TVMBackendAllocWorkspace(int device_type,
int device_id, int device_id,
uint64_t size); uint64_t nbytes,
int dtype_code_hint,
int dtype_bits_hint);
/*! /*!
* \brief Backend function to free temporal workspace. * \brief Backend function to free temporal workspace.
......
...@@ -114,9 +114,13 @@ class DeviceAPI { ...@@ -114,9 +114,13 @@ class DeviceAPI {
* - Workspace should not overlap between different threads(i.e. be threadlocal) * - Workspace should not overlap between different threads(i.e. be threadlocal)
* *
* \param ctx The context of allocation. * \param ctx The context of allocation.
* \param size The size to be allocated. * \param nbytes The size to be allocated.
* \param type_hint The type of elements. Only needed by certain backends such
* as OpenGL, as nbytes is sufficient for most backends.
*/ */
TVM_DLL virtual void* AllocWorkspace(TVMContext ctx, size_t size); TVM_DLL virtual void* AllocWorkspace(TVMContext ctx,
size_t nbytes,
TVMType type_hint = {});
/*! /*!
* \brief Free temporal workspace in backend execution. * \brief Free temporal workspace in backend execution.
* *
......
...@@ -24,6 +24,8 @@ void CodeGenOpenGL::InitFuncState(LoweredFunc f) { ...@@ -24,6 +24,8 @@ void CodeGenOpenGL::InitFuncState(LoweredFunc f) {
inputs_.clear(); inputs_.clear();
output_iter_var_ = nullptr; output_iter_var_ = nullptr;
thread_extent_var_ = ""; thread_extent_var_ = "";
this->decl_stream.str("");
this->stream.str("");
} }
void CodeGenOpenGL::AddFunction(LoweredFunc f) { void CodeGenOpenGL::AddFunction(LoweredFunc f) {
......
...@@ -197,10 +197,12 @@ void CodeGenStackVM::VisitExpr_(const Call* op) { ...@@ -197,10 +197,12 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
vm_.stack_size += size; vm_.stack_size += size;
this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size)); this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
} else if (op->name == "TVMBackendAllocWorkspace") { } else if (op->name == "TVMBackendAllocWorkspace") {
CHECK_EQ(op->args.size(), 3U); CHECK_EQ(op->args.size(), 5U);
this->Push(op->args[0]); this->Push(op->args[0]);
this->Push(op->args[1]); this->Push(op->args[1]);
this->Push(op->args[2]); this->Push(op->args[2]);
this->Push(op->args[3]);
this->Push(op->args[4]);
this->PushOp(StackVM::TVM_DEVICE_ALLOCA); this->PushOp(StackVM::TVM_DEVICE_ALLOCA);
} else if (op->name == "TVMBackendFreeWorkspace") { } else if (op->name == "TVMBackendFreeWorkspace") {
CHECK_EQ(op->args.size(), 3U); CHECK_EQ(op->args.size(), 3U);
......
...@@ -455,12 +455,15 @@ void StackVM::Run(State* s) const { ...@@ -455,12 +455,15 @@ void StackVM::Run(State* s) const {
break; break;
} }
case TVM_DEVICE_ALLOCA: { case TVM_DEVICE_ALLOCA: {
int device_type = static_cast<int>(stack[sp - 2].v_int64); int device_type = static_cast<int>(stack[sp - 4].v_int64);
int device_id = static_cast<int>(stack[sp - 1].v_int64); int device_id = static_cast<int>(stack[sp - 3].v_int64);
size_t nbytes = static_cast<size_t>(stack[sp].v_int64); size_t nbytes = static_cast<size_t>(stack[sp - 2].v_int64);
void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes); int dtype_code_hint = static_cast<int>(stack[sp - 1].v_int64);
stack[sp - 2].v_handle = ptr; int dtype_bits_hint = static_cast<int>(stack[sp].v_int64);
sp = sp - 2; void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes,
dtype_code_hint, dtype_bits_hint);
stack[sp - 4].v_handle = ptr;
sp = sp - 4;
pc = pc + 1; pc = pc + 1;
break; break;
} }
......
...@@ -96,14 +96,18 @@ class BuiltinLower : public IRMutator { ...@@ -96,14 +96,18 @@ class BuiltinLower : public IRMutator {
{op->buffer_var}, Call::PureIntrinsic), {op->buffer_var}, Call::PureIntrinsic),
throw_last_error), throw_last_error),
op->body); op->body);
Stmt alloca = LetStmt::make(op->buffer_var,
Call::make(op->buffer_var.type(), Stmt alloca = LetStmt::make(
"TVMBackendAllocWorkspace", op->buffer_var,
{cast(Int(32), device_type_), Call::make(op->buffer_var.type(),
cast(Int(32), device_id_), "TVMBackendAllocWorkspace",
cast(UInt(64), total_bytes)}, {cast(Int(32), device_type_),
Call::Extern), cast(Int(32), device_id_),
body); cast(UInt(64), total_bytes),
IntImm::make(Int(32), op->type.code()),
IntImm::make(Int(32), op->type.bits())},
Call::Extern),
body);
Expr free_op = Call::make(Int(32), Expr free_op = Call::make(Int(32),
"TVMBackendFreeWorkspace", "TVMBackendFreeWorkspace",
......
...@@ -146,6 +146,11 @@ class IRUseDefAnalysis : public IRMutator { ...@@ -146,6 +146,11 @@ class IRUseDefAnalysis : public IRMutator {
class HostDeviceSplitter : public IRMutator { class HostDeviceSplitter : public IRMutator {
public: public:
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
handle_data_type_[op->buffer_var.get()] = make_const(op->type, 0);
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pipeline_exec_scope) { op->attr_key == attr::pipeline_exec_scope) {
......
...@@ -95,8 +95,9 @@ DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) { ...@@ -95,8 +95,9 @@ DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
static_cast<int>(ctx.device_type), allow_missing); static_cast<int>(ctx.device_type), allow_missing);
} }
void* DeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) { void* DeviceAPI::AllocWorkspace(TVMContext ctx,
TVMType type_hint{kDLUInt, 8, 1}; size_t size,
TVMType type_hint) {
return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint); return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
} }
...@@ -220,12 +221,22 @@ int TVMBackendGetFuncFromEnv(void* mod_node, ...@@ -220,12 +221,22 @@ int TVMBackendGetFuncFromEnv(void* mod_node,
} }
void* TVMBackendAllocWorkspace(int device_type, void* TVMBackendAllocWorkspace(int device_type,
int device_id, int device_id,
uint64_t size) { uint64_t size,
int dtype_code_hint,
int dtype_bits_hint) {
TVMContext ctx; TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, static_cast<size_t>(size));
TVMType type_hint;
type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
type_hint.lanes = 1;
return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
static_cast<size_t>(size),
type_hint);
} }
int TVMBackendFreeWorkspace(int device_type, int TVMBackendFreeWorkspace(int device_type,
......
...@@ -59,7 +59,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -59,7 +59,7 @@ class CPUDeviceAPI final : public DeviceAPI {
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
} }
void* AllocWorkspace(TVMContext ctx, size_t size) final; void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final; void FreeWorkspace(TVMContext ctx, void* data) final;
static const std::shared_ptr<CPUDeviceAPI>& Global() { static const std::shared_ptr<CPUDeviceAPI>& Global() {
...@@ -74,7 +74,9 @@ struct CPUWorkspacePool : public WorkspacePool { ...@@ -74,7 +74,9 @@ struct CPUWorkspacePool : public WorkspacePool {
WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {} WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
}; };
void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) { void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx,
size_t size,
TVMType type_hint) {
return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get() return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()
->AllocWorkspace(ctx, size); ->AllocWorkspace(ctx, size);
} }
......
...@@ -112,7 +112,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -112,7 +112,7 @@ class CUDADeviceAPI final : public DeviceAPI {
->stream = static_cast<cudaStream_t>(stream); ->stream = static_cast<cudaStream_t>(stream);
} }
void* AllocWorkspace(TVMContext ctx, size_t size) final { void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
} }
......
...@@ -77,7 +77,7 @@ class MetalWorkspace final : public DeviceAPI { ...@@ -77,7 +77,7 @@ class MetalWorkspace final : public DeviceAPI {
TVMContext ctx_to, TVMContext ctx_to,
TVMStreamHandle stream) final; TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size) final; void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final; void FreeWorkspace(TVMContext ctx, void* data) final;
// get the global workspace // get the global workspace
static const std::shared_ptr<MetalWorkspace>& Global(); static const std::shared_ptr<MetalWorkspace>& Global();
......
...@@ -228,7 +228,9 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) { ...@@ -228,7 +228,9 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
[cb waitUntilCompleted]; [cb waitUntilCompleted];
} }
void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size) { void* MetalWorkspace::AllocWorkspace(TVMContext ctx,
size_t size,
TVMType type_hint) {
return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
} }
......
...@@ -156,7 +156,7 @@ class OpenCLWorkspace final : public DeviceAPI { ...@@ -156,7 +156,7 @@ class OpenCLWorkspace final : public DeviceAPI {
TVMContext ctx_to, TVMContext ctx_to,
TVMStreamHandle stream) final; TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size) final; void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final; void FreeWorkspace(TVMContext ctx, void* data) final;
// get the global workspace // get the global workspace
static const std::shared_ptr<OpenCLWorkspace>& Global(); static const std::shared_ptr<OpenCLWorkspace>& Global();
......
...@@ -108,7 +108,9 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) { ...@@ -108,7 +108,9 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
OPENCL_CALL(clFinish(this->GetQueue(ctx))); OPENCL_CALL(clFinish(this->GetQueue(ctx)));
} }
void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, size_t size) { void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx,
size_t size,
TVMType type_hint) {
return OpenCLThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); return OpenCLThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
} }
......
...@@ -175,8 +175,6 @@ class OpenGLWorkspace final : public DeviceAPI { ...@@ -175,8 +175,6 @@ class OpenGLWorkspace final : public DeviceAPI {
TVMContext ctx_to, TVMContext ctx_to,
TVMStreamHandle stream) final; TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
/*! /*!
* \brief Get the global OpenGL workspace. * \brief Get the global OpenGL workspace.
......
...@@ -156,15 +156,6 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from, ...@@ -156,15 +156,6 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from,
void OpenGLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {} void OpenGLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {}
void* OpenGLWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
LOG(FATAL) << "Cannot allocate OpenGL workspace.";
return nullptr;
}
void OpenGLWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
LOG(FATAL) << "Cannot free OpenGL workspace.";
}
OpenGLWorkspace::OpenGLWorkspace() { OpenGLWorkspace::OpenGLWorkspace() {
// Set an error handler. // Set an error handler.
// This can be called before glfwInit(). // This can be called before glfwInit().
......
...@@ -110,7 +110,7 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -110,7 +110,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
->stream = static_cast<hipStream_t>(stream); ->stream = static_cast<hipStream_t>(stream);
} }
void* AllocWorkspace(TVMContext ctx, size_t size) final { void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
} }
......
import tvm
import numpy as np
def test_local_multi_stage():
if not tvm.module.enabled("opengl"):
return
if not tvm.module.enabled("llvm"):
return
n = tvm.var("n")
A = tvm.placeholder((n,), name='A', dtype="int32")
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
C = tvm.compute((n,), lambda i: B[i] * 2, name="C")
s = tvm.create_schedule(C.op)
s[B].opengl()
s[C].opengl()
f = tvm.build(s, [A, C], "opengl", name="multi_stage")
ctx = tvm.opengl(0)
n = 10
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
c = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx)
f(a, c)
np.testing.assert_allclose(c.asnumpy(), (a.asnumpy() + 1) * 2)
if __name__ == "__main__":
test_local_multi_stage()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment