Commit 9d84cb07 by Tianqi Chen Committed by GitHub

[RUNTIME] Add workspace pool (#229)

* [RUNTIME] Add workspace pool

* fix doc

* fix the free list

* avoid zero size
parent 5cdc8604
...@@ -239,12 +239,19 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null"; ...@@ -239,12 +239,19 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*! /*!
* \brief See pesudo code * \brief See pesudo code
* *
* void tvm_throw_last_error() {
* throw TVMGetLastError();
* }
*/
constexpr const char* tvm_throw_last_error = "tvm_throw_last_error";
/*!
* \brief See pesudo code
*
* dtype in {shape, array, arg_value, arg_tcode} * dtype in {shape, array, arg_value, arg_tcode}
* *
* Handle tvm_stack_alloca(string dtype, int num) { * Handle tvm_stack_alloca(string dtype, int num) {
* return new on stack dtype[num]; * return new on stack dtype[num];
* } * }
* \sa TVMStructFieldKind
*/ */
constexpr const char* tvm_stack_alloca = "tvm_stack_alloca"; constexpr const char* tvm_stack_alloca = "tvm_stack_alloca";
/*! /*!
......
...@@ -267,6 +267,14 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -267,6 +267,14 @@ LoweredFunc MakeAPI(Stmt body,
bool is_restricted); bool is_restricted);
/*! /*!
* \brief Bind the device type of host function to be device_type.
* \param func The function to be binded.
* \param device_type The device type to be binded.
* \return The binded function.
*/
LoweredFunc BindDeviceType(LoweredFunc func,
int device_type);
/*!
* \brief Find undefined vars in the statment. * \brief Find undefined vars in the statment.
* \param stmt The function to be checked. * \param stmt The function to be checked.
* \param defs The vars that is defined. * \param defs The vars that is defined.
......
/*!
* Copyright (c) 2017 by Contributors
* \file c_backend_api.h
* \brief TVM runtime backend API.
*
* The functions defined in this header are intended to be
* used by compiled tvm operators, usually user do not need to use these
* function directly.
*/
#ifndef TVM_RUNTIME_C_BACKEND_API_H_
#define TVM_RUNTIME_C_BACKEND_API_H_
#include "./c_runtime_api.h"
#ifdef __cplusplus
TVM_EXTERN_C {
#endif
// Backend related functions.
/*!
* \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function).
* The user do should not call TVMFuncFree on func.
*
* \param mod_node The module handle.
* \param func_name The name of the function.
* \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *out);
/*!
* \brief Backend function to register system-wide library symbol.
*
* \param name The name of the symbol
* \param ptr The symbol address.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
/*!
* \brief Backend function to allocate temporal workspace.
*
* \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
*
* \param size The size of the space requested.
* \param device_type The device type which the space will be allocated.
* \param device_id The device id which the space will be allocated.
* \return nullptr when error is thrown, a valid ptr if success
*/
TVM_DLL void* TVMBackendAllocWorkspace(int device_type,
int device_id,
uint64_t size);
/*!
* \brief Backend function to free temporal workspace.
*
* \param ptr The result allocated space pointer.
* \param device_type The device type which the space will be allocated.
* \param device_id The device id which the space will be allocated.
* \return 0 when no error is thrown, -1 when failure happens
*
* \sa TVMBackendAllocWorkspace
*/
TVM_DLL int TVMBackendFreeWorkspace(int device_type,
int device_id,
void* ptr);
/*!
* \brief Backend function for running parallel for loop.
*
* \param begin The start of iteration.
* \param end The end of iteration.
* \param lambda The lambda function to be executed.
* \param env The environment of lambda function.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendParallelFor(
int64_t begin,
int64_t end,
int (*lambda)(int64_t begin, int64_t end, void* env),
void* env);
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
#endif // TVM_RUNTIME_C_BACKEND_API_H_
...@@ -332,55 +332,6 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); ...@@ -332,55 +332,6 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
TVM_DLL int TVMFuncListGlobalNames(int *out_size, TVM_DLL int TVMFuncListGlobalNames(int *out_size,
const char*** out_array); const char*** out_array);
// Backend related functions.
/*!
* \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function).
*
* The user do should not call TVMFuncFree on func.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param mod_node The module handle.
* \param func_name The name of the function.
* \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *out);
/*!
* \brief Backend function to register system-wide library symbol.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param name The name of the symbol
* \param ptr The symbol address.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
/*!
* \brief Backend function for running parallel for loop.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param begin The start of iteration.
* \param end The end of iteration.
* \param lambda The lambda function to be executed.
* \param env The environment of lambda function.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendParallelFor(
int64_t begin,
int64_t end,
int (*lambda)(int64_t begin, int64_t end, void* env),
void* env);
// Array related apis for quick proptyping // Array related apis for quick proptyping
/*! /*!
* \brief Allocate a nd-array's memory, * \brief Allocate a nd-array's memory,
...@@ -458,6 +409,7 @@ TVM_DLL int TVMSetStream(TVMContext ctx, TVMStreamHandle handle); ...@@ -458,6 +409,7 @@ TVM_DLL int TVMSetStream(TVMContext ctx, TVMStreamHandle handle);
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream); TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
#ifdef __cplusplus #ifdef __cplusplus
} // TVM_EXTERN_C } // TVM_EXTERN_C
#endif #endif
......
...@@ -27,6 +27,9 @@ constexpr int kAllocAlignment = 64; ...@@ -27,6 +27,9 @@ constexpr int kAllocAlignment = 64;
/*! \brief Number of bytes each allocation must align to in temporary allocation */ /*! \brief Number of bytes each allocation must align to in temporary allocation */
constexpr int kTempAllocaAlignment = 64; constexpr int kTempAllocaAlignment = 64;
/*! \brief Maximum size that can be allocated on stack */
constexpr int kMaxStackAlloca = 1024;
/*! /*!
* \brief TVM Runtime Device API, abstracts the device * \brief TVM Runtime Device API, abstracts the device
* specific interface for memory management. * specific interface for memory management.
...@@ -96,6 +99,28 @@ class DeviceAPI { ...@@ -96,6 +99,28 @@ class DeviceAPI {
*/ */
virtual void SetStream(TVMContext ctx, TVMStreamHandle stream) {} virtual void SetStream(TVMContext ctx, TVMStreamHandle stream) {}
/*! /*!
* \brief Allocate temporal workspace for backend execution.
*
* \note We have the following assumption about backend temporal
* workspace allocation, and backend will optimize for such assumption:
*
* - Only a few allocation will happen, and space will be released after use.
* - The release order is usually in reverse order of allocate (stack style).
* - Repeative pattern of same allocations over different runs.
* - Workspace should not overlap between different threads(i.e. be threadlocal)
*
* \param ctx The context of allocation.
* \param size The size to be allocated.
*/
virtual void* AllocWorkspace(TVMContext ctx, size_t size);
/*!
* \brief Free temporal workspace in backend execution.
*
* \param ctx The context of allocation.
* \param ptr The pointer to be freed.
*/
virtual void FreeWorkspace(TVMContext ctx, void* ptr);
/*!
* \brief Get device API base don context. * \brief Get device API base don context.
* \param ctx The context * \param ctx The context
* \param allow_missing Whether allow missing * \param allow_missing Whether allow missing
......
...@@ -12,6 +12,7 @@ from . import ir_pass ...@@ -12,6 +12,7 @@ from . import ir_pass
from . import collections from . import collections
from . import module from . import module
from . import codegen from . import codegen
from . import ndarray
class BuildConfig(object): class BuildConfig(object):
"""Configuration scope to set a build config option. """Configuration scope to set a build config option.
...@@ -311,11 +312,16 @@ def build(sch, ...@@ -311,11 +312,16 @@ def build(sch,
fdevice.append(func) fdevice.append(func)
else: else:
raise ValueError("unknown function type %d" % func.func_type) raise ValueError("unknown function type %d" % func.func_type)
fhost = [ir_pass.LowerPackedCall(x) for x in fhost]
if not target.startswith("llvm") and target != "stackvm" and not fdevice: if not target.startswith("llvm") and target != "stackvm" and not fdevice:
raise ValueError( raise ValueError(
"Specified target %s, but cannot find device code, did you do bind?" % target) "Specified target %s, but cannot find device code, did you do bind?" % target)
device = "cpu" if target.startswith("llvm") or target == "stackvm" else target
device_type = ndarray.context(device, 0).device_type
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerPackedCall(x) for x in fhost]
if fdevice: if fdevice:
if not target_host: if not target_host:
target_host = "llvm" if module.enabled("llvm") else "stackvm" target_host = "llvm" if module.enabled("llvm") else "stackvm"
......
...@@ -91,6 +91,7 @@ REGISTER_PASS1(VectorizeLoop); ...@@ -91,6 +91,7 @@ REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS4(UnrollLoop); REGISTER_PASS4(UnrollLoop);
REGISTER_PASS2(StorageSync); REGISTER_PASS2(StorageSync);
REGISTER_PASS5(MakeAPI); REGISTER_PASS5(MakeAPI);
REGISTER_PASS2(BindDeviceType);
REGISTER_PASS1(SplitHostDevice); REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(StorageRewrite); REGISTER_PASS1(StorageRewrite);
REGISTER_PASS1(InjectVirtualThread); REGISTER_PASS1(InjectVirtualThread);
......
...@@ -577,11 +577,16 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) { ...@@ -577,11 +577,16 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
} }
if (op->type.is_scalar()) { if (op->type.is_scalar()) {
llvm::Function* f = module_->getFunction(op->name); llvm::Function* f = module_->getFunction(op->name);
if (f) { if (f == nullptr) {
return builder_->CreateCall(f, arg_values); std::vector<llvm::Type*> arg_types;
} else { for (llvm::Value* v : arg_values) {
LOG(FATAL) << "cannot find function " << op->name; arg_types.push_back(v->getType());
}
f = llvm::Function::Create(
llvm::FunctionType::get(LLVMType(op->type), arg_types, false),
llvm::Function::ExternalLinkage, op->name, module_.get());
} }
return builder_->CreateCall(f, arg_values);
} else { } else {
llvm::Function* f = module_->getFunction(op->name); llvm::Function* f = module_->getFunction(op->name);
if (f) { if (f) {
...@@ -774,6 +779,9 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { ...@@ -774,6 +779,9 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
return builder_->CreateLShr( return builder_->CreateLShr(
MakeValue(op->args[0]), MakeValue(op->args[1])); MakeValue(op->args[0]), MakeValue(op->args[1]));
} }
} else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
builder_->CreateRet(llvm::ConstantInt::getSigned(t_int32_, -1));
return ConstInt32(-1);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) { } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>(); const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l); CHECK(op->args.size() == 1 && l);
......
...@@ -196,6 +196,20 @@ void CodeGenStackVM::VisitExpr_(const Call* op) { ...@@ -196,6 +196,20 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
// add stack size to be safe. // add stack size to be safe.
vm_.stack_size += size; vm_.stack_size += size;
this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size)); this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
} else if (op->name == "TVMBackendAllocWorkspace") {
CHECK_EQ(op->args.size(), 3U);
this->Push(op->args[0]);
this->Push(op->args[1]);
this->Push(op->args[2]);
this->PushOp(StackVM::TVM_DEVICE_ALLOCA);
} else if (op->name == "TVMBackendFreeWorkspace") {
CHECK_EQ(op->args.size(), 3U);
this->Push(op->args[0]);
this->Push(op->args[1]);
this->Push(op->args[2]);
this->PushOp(StackVM::TVM_DEVICE_FREE);
} else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
this->PushOp(StackVM::TVM_THROW_LAST_ERROR);
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U); CHECK_EQ(op->args.size(), 1U);
this->Push(op->args[0]); this->Push(op->args[0]);
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/runtime/c_backend_api.h>
#include "./stack_vm.h" #include "./stack_vm.h"
namespace tvm { namespace tvm {
...@@ -136,6 +137,9 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const { ...@@ -136,6 +137,9 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
STACK_VM_PRINT_CODE2(TVM_STRUCT_SET); STACK_VM_PRINT_CODE2(TVM_STRUCT_SET);
// Allocate data by 8 bytes. // Allocate data by 8 bytes.
STACK_VM_PRINT_CODE1(TVM_STACK_ALLOCA_BY_8BYTE); STACK_VM_PRINT_CODE1(TVM_STACK_ALLOCA_BY_8BYTE);
STACK_VM_PRINT_CODE0(TVM_DEVICE_ALLOCA);
STACK_VM_PRINT_CODE0(TVM_DEVICE_FREE);
STACK_VM_PRINT_CODE0(TVM_THROW_LAST_ERROR);
// packed function. // packed function.
case CALL_PACKED_LOWERED: { case CALL_PACKED_LOWERED: {
int call_fid = code[pc + 1].v_int; int call_fid = code[pc + 1].v_int;
...@@ -450,6 +454,30 @@ void StackVM::Run(State* s) const { ...@@ -450,6 +454,30 @@ void StackVM::Run(State* s) const {
pc = pc + 2; pc = pc + 2;
break; break;
} }
case TVM_DEVICE_ALLOCA: {
int device_type = static_cast<int>(stack[sp - 2].v_int64);
int device_id = static_cast<int>(stack[sp - 1].v_int64);
size_t nbytes = static_cast<size_t>(stack[sp].v_int64);
void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes);
stack[sp - 2].v_handle = ptr;
sp = sp - 2;
pc = pc + 1;
break;
}
case TVM_DEVICE_FREE: {
int device_type = static_cast<int>(stack[sp - 2].v_int64);
int device_id = static_cast<int>(stack[sp - 1].v_int64);
void* ptr = stack[sp].v_handle;
int ret = TVMBackendFreeWorkspace(device_type, device_id, ptr);
stack[sp - 2].v_int64 = ret;
sp = sp - 2;
pc = pc + 1;
break;
}
case TVM_THROW_LAST_ERROR: {
LOG(FATAL) << TVMGetLastError();
break;
}
} }
CHECK_GE(sp, alloca_sp) << "touch allocated space"; CHECK_GE(sp, alloca_sp) << "touch allocated space";
CHECK_LT(sp, stack_cap) << "Stack overflow"; CHECK_LT(sp, stack_cap) << "Stack overflow";
......
...@@ -216,6 +216,34 @@ class StackVM { ...@@ -216,6 +216,34 @@ class StackVM {
*/ */
TVM_STACK_ALLOCA_BY_8BYTE, TVM_STACK_ALLOCA_BY_8BYTE,
/*! /*!
* \brief allocate data from device.
* \code
* device_type = stack[sp - 2].v_int64;
* device_id = stack[sp - 1].v_int64;
* nbytes = stack[sp].v_int64;
* stack[sp - 2].v_handle = device_alloca(device_type, device_id, nbytes);
* sp = sp - 2;
* pc = pc + 1;
* \endcode
*/
TVM_DEVICE_ALLOCA,
/*!
* \brief free data into device.
* \code
* device_type = stack[sp - 2].v_int64;
* device_id = stack[sp - 1].v_int64;
* ptr = stack[sp].v_handle;
* stack[sp - 2].v_int64 = device_free(device_type, device_id, ptr);
* sp = sp - 2;
* pc = pc + 1;
* \endcode
*/
TVM_DEVICE_FREE,
/*!
* \brief throw last error
*/
TVM_THROW_LAST_ERROR,
/*!
* \brief get data from structure. * \brief get data from structure.
* \code * \code
* index = code[pc + 1].v_int; * index = code[pc + 1].v_int;
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <unordered_set> #include <unordered_set>
#include "./ir_util.h" #include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -59,6 +60,57 @@ class PackedCallBuilder : public IRMutator { ...@@ -59,6 +60,57 @@ class PackedCallBuilder : public IRMutator {
} }
return stmt; return stmt;
} }
Stmt Mutate_(const Allocate* op, const Stmt& s) {
// Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
// Get constant allocation bound.
int64_t dev_type;
int64_t nbytes = GetVectorBytes(op->type);
if (device_type_.defined()) {
if (arith::GetConst(device_type_, &dev_type)) {
if (dev_type == kCPU) {
int32_t constant_size = op->constant_allocation_size();
if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
return stmt;
}
}
}
}
Expr total_bytes = make_const(op->extents[0].type(), nbytes);
for (size_t i = 0; i < op->extents.size(); ++i) {
total_bytes = total_bytes * op->extents[i];
}
Stmt throw_last_error = Evaluate::make(Call::make(Int(32),
intrinsic::tvm_throw_last_error, {},
Call::Intrinsic));
Stmt body = Block::make(
IfThenElse::make(Call::make(Bool(1),
intrinsic::tvm_handle_is_null,
{op->buffer_var}, Call::PureIntrinsic),
throw_last_error),
op->body);
Stmt alloca = LetStmt::make(op->buffer_var,
Call::make(op->buffer_var.type(),
"TVMBackendAllocWorkspace",
{cast(Int(32), device_type_),
cast(Int(32), device_id_),
cast(UInt(64), total_bytes)},
Call::Extern),
body);
Expr free_op = Call::make(Int(32),
"TVMBackendFreeWorkspace",
{cast(Int(32), device_type_),
cast(Int(32), device_id_),
op->buffer_var},
Call::Extern);
Stmt free_stmt = IfThenElse::make(free_op != make_zero(Int(32)), throw_last_error);
return Block::make(alloca, free_stmt);
}
Stmt Mutate_(const AttrStmt* op, const Stmt &s) final { Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
if (op->attr_key == attr::device_context_id) { if (op->attr_key == attr::device_context_id) {
CHECK(!device_id_.defined()); CHECK(!device_id_.defined());
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/buffer.h> #include <tvm/buffer.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <vector> #include <vector>
...@@ -164,5 +165,37 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -164,5 +165,37 @@ LoweredFunc MakeAPI(Stmt body,
} }
return f; return f;
} }
class DeviceTypeBinder: public IRMutator {
public:
explicit DeviceTypeBinder(int device_type)
: device_type_(device_type) {}
Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
if (op->attr_key == attr::device_context_type) {
if (const Variable* var = op->value.as<Variable>()) {
std::unordered_map<const Variable*, Expr> dmap;
Expr value = make_const(op->value.type(), device_type_);
dmap[var] = value;
Stmt body = Substitute(s, dmap);
std::ostringstream os;
os << "device_type need to be " << device_type_;
return AssertStmt::make(op->value == value, os.str(), body);
}
}
return IRMutator::Mutate_(op, s);
}
public:
int device_type_;
};
LoweredFunc BindDeviceType(LoweredFunc f,
int device_type) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = DeviceTypeBinder(device_type).Mutate(n->body);
return LoweredFunc(n);
}
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
...@@ -93,6 +94,14 @@ DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) { ...@@ -93,6 +94,14 @@ DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
static_cast<int>(ctx.device_type), allow_missing); static_cast<int>(ctx.device_type), allow_missing);
} }
void* DeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) {
return AllocDataSpace(ctx, size, kTempAllocaAlignment);
}
void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) {
FreeDataSpace(ctx, ptr);
}
inline TVMArray* TVMArrayCreate_() { inline TVMArray* TVMArrayCreate_() {
TVMArray* arr = new TVMArray(); TVMArray* arr = new TVMArray();
arr->shape = nullptr; arr->shape = nullptr;
...@@ -225,6 +234,25 @@ int TVMBackendGetFuncFromEnv(void* mod_node, ...@@ -225,6 +234,25 @@ int TVMBackendGetFuncFromEnv(void* mod_node,
API_END(); API_END();
} }
void* TVMBackendAllocWorkspace(int device_type,
int device_id,
uint64_t size) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, size);
}
int TVMBackendFreeWorkspace(int device_type,
int device_id,
void* ptr) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
return 0;
}
int TVMBackendParallelFor( int TVMBackendParallelFor(
int64_t begin, int64_t begin,
int64_t end, int64_t end,
......
...@@ -3,14 +3,15 @@ ...@@ -3,14 +3,15 @@
* \file cpu_device_api.cc * \file cpu_device_api.cc
*/ */
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include "./workspace_pool.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
class CPUDeviceAPI final : public DeviceAPI { class CPUDeviceAPI final : public DeviceAPI {
public: public:
void SetDevice(TVMContext ctx) final {} void SetDevice(TVMContext ctx) final {}
...@@ -54,12 +55,34 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -54,12 +55,34 @@ class CPUDeviceAPI final : public DeviceAPI {
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
} }
void* AllocWorkspace(TVMContext ctx, size_t size) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
static const std::shared_ptr<CPUDeviceAPI>& Global() {
static std::shared_ptr<CPUDeviceAPI> inst =
std::make_shared<CPUDeviceAPI>();
return inst;
}
}; };
struct CPUWorkspacePool : public WorkspacePool {
CPUWorkspacePool() :
WorkspacePool(kCPU, CPUDeviceAPI::Global()) {}
};
void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) {
return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()
->AllocWorkspace(ctx, size);
}
void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data);
}
TVM_REGISTER_GLOBAL("device_api.cpu") TVM_REGISTER_GLOBAL("device_api.cpu")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
static CPUDeviceAPI inst; DeviceAPI* ptr = CPUDeviceAPI::Global().get();
DeviceAPI* ptr = &inst;
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
}); });
} // namespace runtime } // namespace runtime
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#if TVM_CUDA_RUNTIME #if TVM_CUDA_RUNTIME
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "../workspace_pool.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -39,6 +40,10 @@ class CUDAThreadEntry { ...@@ -39,6 +40,10 @@ class CUDAThreadEntry {
public: public:
/*! \brief The cuda stream */ /*! \brief The cuda stream */
cudaStream_t stream{nullptr}; cudaStream_t stream{nullptr};
/*! \brief thread local pool*/
WorkspacePool pool;
/*! \brief constructor */
CUDAThreadEntry();
// get the threadlocal workspace // get the threadlocal workspace
static CUDAThreadEntry* ThreadLocal(); static CUDAThreadEntry* ThreadLocal();
}; };
......
...@@ -98,6 +98,20 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -98,6 +98,20 @@ class CUDADeviceAPI final : public DeviceAPI {
->stream = static_cast<cudaStream_t>(stream); ->stream = static_cast<cudaStream_t>(stream);
} }
void* AllocWorkspace(TVMContext ctx, size_t size) final {
return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
void FreeWorkspace(TVMContext ctx, void* data) final {
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}
static const std::shared_ptr<CUDADeviceAPI>& Global() {
static std::shared_ptr<CUDADeviceAPI> inst =
std::make_shared<CUDADeviceAPI>();
return inst;
}
private: private:
static void GPUCopy(const void* from, static void GPUCopy(const void* from,
void* to, void* to,
...@@ -114,14 +128,17 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -114,14 +128,17 @@ class CUDADeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore; typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
CUDAThreadEntry::CUDAThreadEntry()
: pool(kGPU, CUDADeviceAPI::Global()) {
}
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
return CUDAThreadStore::Get(); return CUDAThreadStore::Get();
} }
TVM_REGISTER_GLOBAL("device_api.gpu") TVM_REGISTER_GLOBAL("device_api.gpu")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
static CUDADeviceAPI inst; DeviceAPI* ptr = CUDADeviceAPI::Global().get();
DeviceAPI* ptr = &inst;
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
}); });
......
...@@ -73,8 +73,10 @@ class MetalWorkspace final : public DeviceAPI { ...@@ -73,8 +73,10 @@ class MetalWorkspace final : public DeviceAPI {
TVMContext ctx_to, TVMContext ctx_to,
TVMStreamHandle stream) final; TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
// get the global workspace // get the global workspace
static MetalWorkspace* Global(); static const std::shared_ptr<MetalWorkspace>& Global();
}; };
/*! \brief Thread local workspace */ /*! \brief Thread local workspace */
...@@ -84,8 +86,11 @@ class MetalThreadEntry { ...@@ -84,8 +86,11 @@ class MetalThreadEntry {
TVMContext context; TVMContext context;
/*! \brief The shared buffer used for copy. */ /*! \brief The shared buffer used for copy. */
std::vector<id<MTLBuffer> > temp_buffer_; std::vector<id<MTLBuffer> > temp_buffer_;
/*! \brief workspace pool */
MetalThreadEntry() { WorkspacePool pool;
// constructor
MetalThreadEntry()
: pool(static_cast<DLDeviceType>(kMetal), MetalWorkspace::Global()) {
context.device_id = 0; context.device_id = 0;
context.device_type = static_cast<DLDeviceType>(kMetal); context.device_type = static_cast<DLDeviceType>(kMetal);
} }
......
...@@ -215,6 +215,14 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) { ...@@ -215,6 +215,14 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
[cb waitUntilCompleted]; [cb waitUntilCompleted];
} }
void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
void MetalWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
MetalThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}
MetalThreadEntry::~MetalThreadEntry() { MetalThreadEntry::~MetalThreadEntry() {
for (auto x : temp_buffer_) { for (auto x : temp_buffer_) {
if (x != nil) [x release]; if (x != nil) [x release];
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <vector> #include <vector>
#include "../workspace_pool.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -152,6 +153,8 @@ class OpenCLWorkspace final : public DeviceAPI { ...@@ -152,6 +153,8 @@ class OpenCLWorkspace final : public DeviceAPI {
TVMContext ctx_to, TVMContext ctx_to,
TVMStreamHandle stream) final; TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
// get the global workspace // get the global workspace
static const std::shared_ptr<OpenCLWorkspace>& Global(); static const std::shared_ptr<OpenCLWorkspace>& Global();
}; };
...@@ -171,8 +174,11 @@ class OpenCLThreadEntry { ...@@ -171,8 +174,11 @@ class OpenCLThreadEntry {
TVMContext context; TVMContext context;
/*! \brief The thread-local kernel table */ /*! \brief The thread-local kernel table */
std::vector<KTEntry> kernel_table; std::vector<KTEntry> kernel_table;
/*! \brief workspace pool */
OpenCLThreadEntry() { WorkspacePool pool;
// constructor
OpenCLThreadEntry()
: pool(kOpenCL, OpenCLWorkspace::Global()) {
context.device_id = 0; context.device_id = 0;
context.device_type = kOpenCL; context.device_type = kOpenCL;
} }
......
...@@ -107,6 +107,14 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) { ...@@ -107,6 +107,14 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
OPENCL_CALL(clFinish(this->GetQueue(ctx))); OPENCL_CALL(clFinish(this->GetQueue(ctx)));
} }
void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
return OpenCLThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
void OpenCLWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
OpenCLThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}
typedef dmlc::ThreadLocalStore<OpenCLThreadEntry> OpenCLThreadStore; typedef dmlc::ThreadLocalStore<OpenCLThreadEntry> OpenCLThreadStore;
OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() { OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() {
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief SystemLib module. * \brief SystemLib module.
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/c_backend_api.h>
#include <mutex> #include <mutex>
#include "./module_util.h" #include "./module_util.h"
......
/*!
* Copyright (c) 2017 by Contributors
* \file workspace_pool.h
* \brief Workspace pool utility.
*/
#include "./workspace_pool.h"
namespace tvm {
namespace runtime {
// page size.
constexpr size_t kWorkspacePageSize = 4 << 10;
class WorkspacePool::Pool {
public:
// constructor
Pool() {
// safe guard header on each list.
Entry e;
e.data = nullptr;
e.size = 0;
free_list_.push_back(e);
allocated_.push_back(e);
}
// allocate from pool
void* Alloc(TVMContext ctx, DeviceAPI* device, size_t size) {
// Allocate align to page.
size = (size + (kWorkspacePageSize - 1)) / kWorkspacePageSize * kWorkspacePageSize;
if (size == 0) size = kWorkspacePageSize;
Entry e;
if (free_list_.size() == 2) {
e = free_list_.back();
free_list_.pop_back();
if (e.size < size) {
// resize the page
device->FreeDataSpace(ctx, e.data);
e.data = device->AllocDataSpace(ctx, size, kTempAllocaAlignment);
e.size = size;
}
} else if (free_list_.size() == 1) {
e.data = device->AllocDataSpace(ctx, size, kTempAllocaAlignment);
e.size = size;
} else {
if (free_list_.back().size >= size) {
// find smallest fit
auto it = free_list_.end() - 2;
for (; it->size >= size; --it) {}
e = *(it + 1);
free_list_.erase(it + 1);
} else {
// resize the page
e = free_list_.back();
free_list_.pop_back();
device->FreeDataSpace(ctx, e.data);
e.data = device->AllocDataSpace(ctx, size, kTempAllocaAlignment);
e.size = size;
}
}
allocated_.push_back(e);
return e.data;
}
// free resource back to pool
void Free(void* data) {
Entry e;
if (allocated_.back().data == data) {
// quick path, last allocated.
e = allocated_.back();
allocated_.pop_back();
} else {
int index = static_cast<int>(allocated_.size()) - 2;
for (; index > 0 && allocated_[index].data != data; --index) {}
CHECK_GT(index, 0) << "trying to free things that has not been allocated";
e = allocated_[index];
allocated_.erase(allocated_.begin() + index);
}
if (free_list_.back().size < e.size) {
free_list_.push_back(e);
} else if (free_list_.size() == 2) {
free_list_.push_back(free_list_.back());
free_list_[1] = e;
} else {
size_t i = free_list_.size() - 1;
free_list_.resize(free_list_.size() + 1);
for (; e.size < free_list_[i].size; --i) {
free_list_[i + 1] = free_list_[i];
}
free_list_[i + 1] = e;
}
}
// Release all resources
void Release(TVMContext ctx, DeviceAPI* device) {
CHECK_EQ(allocated_.size(), 1);
for (size_t i = 1; i < free_list_.size(); ++i) {
device->FreeDataSpace(ctx, free_list_[i].data);
}
free_list_.clear();
}
private:
/*! \brief a single entry in the pool */
struct Entry {
void* data;
size_t size;
};
/*! \brief List of free items, sorted from small to big size */
std::vector<Entry> free_list_;
/*! \brief List of allocated items */
std::vector<Entry> allocated_;
};
WorkspacePool::WorkspacePool(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device)
: device_type_(device_type), device_(device) {
}
WorkspacePool::~WorkspacePool() {
for (size_t i = 0; i < array_.size(); ++i) {
if (array_[i] != nullptr) {
TVMContext ctx;
ctx.device_type = device_type_;
ctx.device_id = static_cast<int>(i);
array_[i]->Release(ctx, device_.get());
delete array_[i];
}
}
}
void* WorkspacePool::AllocWorkspace(TVMContext ctx, size_t size) {
if (static_cast<size_t>(ctx.device_id) >= array_.size()) {
array_.resize(ctx.device_id + 1, nullptr);
}
if (array_[ctx.device_id] == nullptr) {
array_[ctx.device_id] = new Pool();
}
return array_[ctx.device_id]->Alloc(ctx, device_.get(), size);
}
void WorkspacePool::FreeWorkspace(TVMContext ctx, void* ptr) {
CHECK(static_cast<size_t>(ctx.device_id) < array_.size() &&
array_[ctx.device_id] != nullptr);
array_[ctx.device_id]->Free(ptr);
}
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file workspace_pool.h
* \brief Workspace pool utility.
*/
#ifndef TVM_RUNTIME_WORKSPACE_POOL_H_
#define TVM_RUNTIME_WORKSPACE_POOL_H_
#include <tvm/runtime/device_api.h>
#include <vector>
namespace tvm {
namespace runtime {
/*!
* \brief A workspace pool to manage
*
* \note We have the following assumption about backend temporal
* workspace allocation, and will optimize for such assumption,
* some of these assumptions can be enforced by the compiler.
*
* - Only a few allocation will happen, and space will be released after use.
* - The release order is usually in reverse order of allocate
* - Repeative pattern of same allocations over different runs.
*/
class WorkspacePool {
public:
/*!
* \brief Create pool with specific device type and device.
* \param device_type The device type.
* \param device The device API.
*/
WorkspacePool(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device);
/*! \brief destructor */
~WorkspacePool();
/*!
* \brief Allocate temporal workspace.
* \param ctx The context of allocation.
* \param The size to be allocated.
*/
void* AllocWorkspace(TVMContext ctx, size_t size);
/*!
* \brief Free temporal workspace in backend execution.
*
* \param ctx The context of allocation.
* \param ptr The pointer to be freed.
*/
void FreeWorkspace(TVMContext ctx, void* ptr);
private:
class Pool;
/*! \brief pool of device local array */
std::vector<Pool*> array_;
/*! \brief device type this pool support */
DLDeviceType device_type_;
/*! \brief The device API */
std::shared_ptr<DeviceAPI> device_;
};
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_WORKSPACE_POOL_H_
...@@ -7,7 +7,8 @@ def test_add_pipeline(): ...@@ -7,7 +7,8 @@ def test_add_pipeline():
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.create_schedule(C.op) D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='C')
s = tvm.create_schedule(D.op)
# GPU schedule have to split by gridIdx and threadIdx # GPU schedule have to split by gridIdx and threadIdx
num_thread = 256 num_thread = 256
...@@ -15,6 +16,10 @@ def test_add_pipeline(): ...@@ -15,6 +16,10 @@ def test_add_pipeline():
s[C].bind(xo, tvm.thread_axis("threadIdx.x")) s[C].bind(xo, tvm.thread_axis("threadIdx.x"))
s[C].bind(xi, tvm.thread_axis("blockIdx.x")) s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
s[D].bind(xo, tvm.thread_axis("threadIdx.x"))
s[D].bind(xi, tvm.thread_axis("blockIdx.x"))
# compile to IR # compile to IR
s = s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
......
...@@ -6,7 +6,9 @@ def test_llvm_add_pipeline(): ...@@ -6,7 +6,9 @@ def test_llvm_add_pipeline():
n = tvm.convert(nn) n = tvm.convert(nn)
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
T = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='T') AA = tvm.compute((n,), lambda *i: A(*i), name='A')
BB = tvm.compute((n,), lambda *i: B(*i), name='B')
T = tvm.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T')
C = tvm.compute(A.shape, lambda *i: T(*i), name='C') C = tvm.compute(A.shape, lambda *i: T(*i), name='C')
s = tvm.create_schedule(C.op) s = tvm.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=4) xo, xi = s[C].split(C.op.axis[0], factor=4)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment