Commit 41768cf9 by Tianqi Chen Committed by GitHub

[SCHEDULE][RUNIME] Introduce pragma for additional extension hint, threadpool runtime. (#299)

parent fd96d285
......@@ -32,23 +32,35 @@ def test_rpc_module():
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
temp = util.tempdir()
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
temp = util.tempdir()
# Build the dynamic lib.
# If we don't want to do metal and only use cpu, just set target to be target
f = tvm.build(s, [A, B], "metal", target_host=target, name="myadd")
path_dso = temp.relpath("dev_lib.dylib")
f.export_library(path_dso, xcode.create_dylib,
path_dso1 = temp.relpath("dev_lib.dylib")
f.export_library(path_dso1, xcode.create_dylib,
arch=arch, sdk=sdk)
xcode.codesign(path_dso)
xcode.codesign(path_dso1)
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].parallel(xi)
s[B].pragma(xo, "parallel_launch_point")
s[B].pragma(xi, "parallel_barrier_when_finish")
f = tvm.build(s, [A, B], target, name="myadd_cpu")
path_dso2 = temp.relpath("cpu_lib.dylib")
f.export_library(path_dso2, xcode.create_dylib,
arch=arch, sdk=sdk)
xcode.codesign(path_dso2)
# Start RPC test server that contains the compiled library.
server = xcode.popen_test_rpc(proxy_host, proxy_port, key,
destination=destination,
libs=[path_dso],
options=["-quiet"])
libs=[path_dso1, path_dso2])
# connect to the proxy
remote = rpc.connect(proxy_host, proxy_port, key=key)
ctx = remote.metal(0)
......@@ -60,5 +72,15 @@ def test_rpc_module():
cost = time_f(a, b).mean
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
# CPU
ctx = remote.cpu(0)
f2 = remote.load_module("cpu_lib.dylib")
a_np = np.random.uniform(size=1024).astype(A.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f2.time_evaluator(f1.entry_name, ctx, number=10)
cost = time_f(a, b).mean
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
test_rpc_module()
......@@ -7,6 +7,7 @@
#include "../../src/runtime/c_runtime_api.cc"
#include "../../src/runtime/cpu_device_api.cc"
#include "../../src/runtime/workspace_pool.cc"
#include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/module_util.cc"
#include "../../src/runtime/system_lib_module.cc"
#include "../../src/runtime/module.cc"
......
......@@ -45,7 +45,7 @@ tvm.ir_pass
tvm.ir_pass.StorageFlatten
tvm.ir_pass.VectorizeLoop
tvm.ir_pass.UnrollLoop
tvm.ir_pass.StorageSync
tvm.ir_pass.ThreadSync
tvm.ir_pass.StorageRewrite
tvm.ir_pass.MakeAPI
tvm.ir_pass.SplitHostDevice
......
......@@ -166,6 +166,8 @@ constexpr const char* device_context_type = "device_context_type";
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope";
/*! \brief Mark region is guarded by the pragma */
constexpr const char* pragma_scope = "pragma_scope";
/*!
* \brief Mark of prefetch scope, value=offset,
* run prefetch of Tensor on the current loop scope
......
......@@ -66,21 +66,49 @@ TVM_DLL void* TVMBackendAllocWorkspace(int device_type,
TVM_DLL int TVMBackendFreeWorkspace(int device_type,
int device_id,
void* ptr);
/*!
* \brief Environment for TVM parallel task.
*/
typedef struct {
/*!
* \brief Auxiliary used for synchronization
*/
void* sync_handle;
/*! \brief total amount of task */
int32_t num_task;
} TVMParallelGroupEnv;
/*!
* \brief Backend function for running parallel for loop.
* \brief The callback function to execute a parallel lambda
* \param task_id the task id of the function.
* \param penv The parallel environment backs the execution.
* \param cdata The supporting closure data.
*/
typedef int (*FTVMParallelLambda)(
int task_id, TVMParallelGroupEnv* penv, void* cdata);
/*!
* \brief Backend function for running parallel jobs.
*
* \param begin The start of iteration.
* \param end The end of iteration.
* \param lambda The lambda function to be executed.
* \param env The environment of lambda function.
* \param flambda The parallel function to be launched.
* \param cdata The closure data.
* \param num_task Number of tasks to launch, can be 0, means launch
* with all available threads.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendParallelFor(
int64_t begin,
int64_t end,
int (*lambda)(int64_t begin, int64_t end, void* env),
void* env);
TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
void* cdata,
int num_task);
/*!
* \brief BSP barrrier between parallel threads
* \param task_id the task id of the function.
* \param penv The parallel environment backs the execution.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv);
#ifdef __cplusplus
} // TVM_EXTERN_C
......
......@@ -181,6 +181,15 @@ class Stage : public NodeRef {
*/
Stage& parallel(IterVar var); // NOLINT(*)
/*!
* \brief Annotate the iteration with pragma
*
* \param var The axis to be parallelized.
* \param pragma_type The pragma type.
*
* \return reference to self.
*/
Stage& pragma(IterVar var, const std::string& pragma_type); // NOLINT(*)
/*!
* \brief Fetch data in advance.
* \param domain the tensor to be prefetched
* \param var the iteration point at which to apply prefetching
......@@ -487,6 +496,10 @@ class IterVarAttrNode : public Node {
* when the axis is marked as Tensorized
*/
TensorIntrin tensor_intrin;
/*!
* \brief Additional pragmas, array of StringImm
*/
Array<Expr> pragmas;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter_type", &iter_type);
......@@ -494,6 +507,7 @@ class IterVarAttrNode : public Node {
v->Visit("prefetch_data", &prefetch_data);
v->Visit("prefetch_offset", &prefetch_offset);
v->Visit("tensor_intrin", &tensor_intrin);
v->Visit("pragmas", &pragmas);
}
static constexpr const char* _type_key = "IterVarAttr";
......
......@@ -78,7 +78,7 @@ class Module(ModuleBase):
file_name : str
The name of the shared library.
fcompile : function(target, file_list, **kwargs), optional
fcompile : function(target, file_list, kwargs), optional
Compilation function to use create dynamic library.
kwargs : dict, optiona;
......
......@@ -26,7 +26,7 @@ class Buffer(NodeBase):
WRITE = 2
def access_ptr(self, access_mask, ptr_type="handle"):
"""Get an access pointer to the head of buffer
"""Get an access pointer to the head of buffer.
This is the recommended method to get buffer data
ptress when interacting with external functions.
......@@ -37,7 +37,6 @@ class Buffer(NodeBase):
The access pattern MASK. Indicate whether the
access will read or write to the data content.
ptr_type : str, optional
The data type of the result pointer. Do not specify
unless we want to cast pointer to specific type.
......@@ -45,8 +44,8 @@ class Buffer(NodeBase):
Examples
--------
.. code-block:: python
import tvm.schedule.Buffer
import tvm.schedule.Buffer
# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
......@@ -465,6 +464,48 @@ class Stage(NodeBase):
"""
_api_internal._StageParallel(self, var)
def pragma(self, var, pragma_type):
"""Annotate the iteration with pragma
This will translate to a pragma_scope surrounding
the corresponding loop generated.
Useful to support experimental features and extensions.
Parameters
----------
var : IterVar
The iteration to be anotated
pragma_type : str
The pragma string to be annotated
Note
----
Most pragmas are advanced/experimental features
and may subject to change. List of supported pragmas:
- **parallel_launch_point**
Specify to launch parallel threads outside the
specified iteration loop. By default the threads
launch at the point of parallel construct.
This pragma moves the launching point to even outer scope.
The threads are launched once and reused across multiple
parallel constructs as BSP style program.
- **parallel_barrier_when_finish**
Insert a synchronization barrier between working threads
after the specified loop iteration finishes.
- **parallel_stride_pattern**
Hint parallel loop to execute in strided pattern.
:code:`for (int i = task_id; i < end; i += num_task)`
"""
_api_internal._StagePragma(self, var, pragma_type)
def prefetch(self, tensor, var, offset):
"""Prefetch the specified variable
......
......@@ -364,6 +364,12 @@ TVM_REGISTER_API("_StageParallel")
.parallel(args[1]);
});
TVM_REGISTER_API("_StagePragma")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.pragma(args[1], args[2]);
});
TVM_REGISTER_API("_StagePrefetch")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Stage()
......
......@@ -63,8 +63,14 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_tvm_shape_index_->getPointerTo(),
t_int64_});
t_tvm_value_ = llvm::StructType::create({t_float64_});
ftype_tvm_par_for_lambda_ = llvm::FunctionType::get(
t_int_, {t_int64_, t_int64_, t_void_p_}, false);
t_tvm_parallel_group_env_ = llvm::StructType::create({
t_int32_->getPointerTo(),
t_int32_});
ftype_tvm_parallel_lambda_ = llvm::FunctionType::get(
t_int_,
{t_int_,
t_tvm_parallel_group_env_->getPointerTo(),
t_void_p_}, false);
md_builder_.reset(new llvm::MDBuilder(*ctx));
md_very_likely_branch_ =
md_builder_->createBranchWeights(1 << 30, 0);
......@@ -90,9 +96,13 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_tvm_func_handle_->getPointerTo()}, false);
ftype_tvm_api_set_last_error_ = llvm::FunctionType::get(
t_void_, {t_char_->getPointerTo()}, false);
ftype_tvm_parallel_for_ =
ftype_tvm_parallel_launch_ =
llvm::FunctionType::get(t_int_, {
t_int64_, t_int64_, ftype_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_}
, false);
ftype_tvm_parallel_barrier_ =
llvm::FunctionType::get(t_int_, {
t_int_, t_tvm_parallel_group_env_->getPointerTo()}
, false);
// initialize TVM runtime API
if (system_lib) {
......@@ -113,9 +123,12 @@ void CodeGenLLVM::Init(const std::string& module_name,
f_tvm_api_set_last_error_ = llvm::Function::Create(
ftype_tvm_api_set_last_error_,
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
f_tvm_parallel_for_ = llvm::Function::Create(
ftype_tvm_parallel_for_,
llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
f_tvm_parallel_launch_ = llvm::Function::Create(
ftype_tvm_parallel_launch_,
llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get());
f_tvm_parallel_barrier_ = llvm::Function::Create(
ftype_tvm_parallel_barrier_,
llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get());
}
this->InitTarget(tm);
// initialize builder
......@@ -179,8 +192,10 @@ void CodeGenLLVM::InitGlobalContext(bool dynamic_lookup) {
ftype_tvm_get_func_from_env_->getPointerTo(), "__TVMBackendGetFuncFromEnv");
gv_tvm_api_set_last_error_ = InitContextPtr(
ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError");
gv_tvm_parallel_for_ = InitContextPtr(
ftype_tvm_parallel_for_->getPointerTo(), "__TVMBackendParallelFor");
gv_tvm_parallel_launch_ = InitContextPtr(
ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch");
gv_tvm_parallel_barrier_ = InitContextPtr(
ftype_tvm_parallel_barrier_->getPointerTo(), "__TVMBackendParallelBarrier");
// Mark as context functions
gv_func_map_["TVMBackendAllocWorkspace"] = nullptr;
gv_func_map_["TVMBackendFreeWorkspace"] = nullptr;
......@@ -702,9 +717,14 @@ llvm::Value* CodeGenLLVM::RuntimeTVMAPISetLastError() {
if (f_tvm_api_set_last_error_ != nullptr) return f_tvm_api_set_last_error_;
return GetContextPtr(gv_tvm_api_set_last_error_);
}
llvm::Value* CodeGenLLVM::RuntimeTVMParallelFor() {
if (f_tvm_parallel_for_ != nullptr) return f_tvm_parallel_for_;
return GetContextPtr(gv_tvm_parallel_for_);
llvm::Value* CodeGenLLVM::RuntimeTVMParallelLaunch() {
if (f_tvm_parallel_launch_ != nullptr) return f_tvm_parallel_launch_;
return GetContextPtr(gv_tvm_parallel_launch_);
}
llvm::Value* CodeGenLLVM::RuntimeTVMParallelBarrier() {
if (f_tvm_parallel_barrier_ != nullptr) return f_tvm_parallel_barrier_;
return GetContextPtr(gv_tvm_parallel_barrier_);
}
llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
......@@ -782,15 +802,9 @@ void CodeGenLLVM::CreateComputeScope(const AttrStmt* op) {
builder_->SetInsertPoint(compute_call_end);
}
void CodeGenLLVM::CreateParallelFor(const For* op) {
void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
using llvm::BasicBlock;
llvm::Value* min = MakeValue(op->min);
llvm::Value* extent = MakeValue(op->extent);
min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int());
extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int());
// fields to be packed into closure.
Var loop_var(op->loop_var.node_);
Array<Var> vfields = ir::UndefinedVars(op->body, {loop_var});
Array<Var> vfields = ir::UndefinedVars(body, {});
std::vector<llvm::Type*> fields;
for (Var v : vfields) {
auto it = var_map_.find(v.get());
......@@ -800,9 +814,9 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
// closure data
llvm::StructType* tcdata = llvm::StructType::create(fields);
llvm::Function* f = llvm::Function::Create(
ftype_tvm_par_for_lambda_,
ftype_tvm_parallel_lambda_,
llvm::Function::PrivateLinkage,
"__tvm_par_for_lambda", module_.get());
"__tvm_parallel_lambda", module_.get());
// allocate and setup the closure, call the closure.
llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
llvm::Value* zero = ConstInt32(0);
......@@ -812,19 +826,17 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
}
BasicBlock* par_for_end = CheckCallSuccess(
BasicBlock* par_launch_end = CheckCallSuccess(
builder_->CreateCall(
RuntimeTVMParallelFor(),
{min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
RuntimeTVMParallelLaunch(),
{f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
llvm::Value* begin = &(*it++);
llvm::Value* end = &(*it++);
llvm::Value* task_id = &(*it++);
llvm::Value* penv = &(*it++);
cdata = &(*it++);
begin = CreateCast(Int(64), op->loop_var.type(), begin);
end = CreateCast(Int(64), op->loop_var.type(), end);
cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
......@@ -833,17 +845,32 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
builder_->CreateLoad(builder_->CreateInBoundsGEP(
cdata, {zero, ConstInt32(i)}));
}
// setup parallel env
ParallelEnv par_env;
par_env.task_id = Var("task_id", Int(32));
par_env.num_task = Var("num_task", Int(32));
new_vmap[par_env.task_id.get()] = task_id;
new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
builder_->CreateInBoundsGEP(
penv, {zero, ConstInt32(1)}));
par_env.penv = penv;
std::swap(function_, f);
std::swap(new_vmap, var_map_);
CreateSerialFor(begin, end, op->loop_var, op->body);
std::swap(parallel_env_, par_env);
std::swap(var_map_, new_vmap);
this->VisitStmt(body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(new_vmap, var_map_);
std::swap(var_map_, new_vmap);
std::swap(parallel_env_, par_env);
std::swap(function_, f);
builder_->SetInsertPoint(par_for_end);
CHECK(par_env.hit_parallel_loop)
<< "Cannot find parallel loop within parallel launch";
builder_->SetInsertPoint(par_launch_end);
}
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
llvm::Value* end,
llvm::Value* stride,
const VarExpr& loop_var, const Stmt& body) {
using llvm::BasicBlock;
Type t = loop_var.type();
......@@ -864,7 +891,7 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
builder_->SetInsertPoint(for_body);
var_map_[loop_var.get()] = index;
this->VisitStmt(body);
llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
llvm::Value* next_index = CreateAdd(t, index, stride);
index->addIncoming(next_index, builder_->GetInsertBlock());
builder_->CreateBr(for_head);
// end of for
......@@ -1481,10 +1508,45 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
void CodeGenLLVM::VisitStmt_(const For* op) {
CHECK(is_zero(op->min));
if (op->for_type == ForType::Serial) {
CreateSerialFor(ConstInt32(0), MakeValue(op->extent),
op->loop_var, op->body);
CreateSerialFor(ConstInt32(0),
MakeValue(op->extent),
ConstInt32(1),
op->loop_var,
op->body);
} else if (op->for_type == ForType::Parallel) {
CreateParallelFor(op);
if (parallel_env_.penv == nullptr) {
CreateParallelLaunch(
For::make(
op->loop_var, op->min, op->extent,
op->for_type, op->device_api, op->body), 0);
} else {
// already in parallel env.
CHECK(parallel_env_.task_id.defined());
CHECK(parallel_env_.num_task.defined());
CHECK(parallel_env_.penv != nullptr);
Type t = op->extent.type();
Expr num_task = cast(t, parallel_env_.num_task);
Expr task_id = cast(t, parallel_env_.task_id);
CHECK(!parallel_env_.hit_parallel_loop)
<< "Nested parallel loop is not supported by threadpool, try fuse them instead";
parallel_env_.hit_parallel_loop = true;
if (parallel_env_.stride_pattern) {
CreateSerialFor(MakeValue(task_id),
MakeValue(op->extent),
MakeValue(num_task),
op->loop_var,
op->body);
} else {
Expr step = (op->extent + num_task - make_const(t, 1)) / num_task;
Expr begin = Min::make(task_id * step, op->extent);
Expr end = Min::make((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin),
MakeValue(end),
ConstInt32(1),
op->loop_var,
op->body);
}
}
} else {
LOG(FATAL) << "cannot handle for type " << op->for_type;
}
......@@ -1566,6 +1628,29 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
this->VisitStmt(op->body);
} else if (op->attr_key == ir::attr::compute_scope) {
this->CreateComputeScope(op);
} else if (op->attr_key == ir::attr::pragma_scope) {
const std::string& pname = op->value.as<StringImm>()->value;
if (pname == "parallel_stride_pattern") {
CHECK(parallel_env_.penv != nullptr)
<< "Pragma parallel_stride_pattern only valid in parallel launch";
parallel_env_.stride_pattern = true;
this->VisitStmt(op->body);
} else if (pname == "parallel_launch_point") {
CreateParallelLaunch(op->body, 0);
} else if (pname == "parallel_barrier_when_finish") {
CHECK(parallel_env_.penv != nullptr)
<< "Cannot run barrier without parallel environment";
CHECK(!parallel_env_.hit_parallel_loop)
<< "Cannot not place within parallel loop as the workload may differ, "
<< " place it between parallel and parallel_launch_point";
this->VisitStmt(op->body);
builder_->CreateCall(
RuntimeTVMParallelBarrier(),
{MakeValue(parallel_env_.task_id), parallel_env_.penv});
} else {
LOG(WARNING) << "Unknown pragma " << pname;
this->VisitStmt(op->body);
}
} else {
this->VisitStmt(op->body);
}
......
......@@ -189,11 +189,13 @@ class CodeGenLLVM :
llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr};
llvm::FunctionType* ftype_tvm_par_for_lambda_{nullptr};
llvm::StructType* t_tvm_parallel_group_env_{nullptr};
llvm::FunctionType* ftype_tvm_parallel_lambda_{nullptr};
llvm::FunctionType* ftype_tvm_func_call_{nullptr};
llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr};
llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};
llvm::FunctionType* ftype_tvm_parallel_for_{nullptr};
llvm::FunctionType* ftype_tvm_parallel_launch_{nullptr};
llvm::FunctionType* ftype_tvm_parallel_barrier_{nullptr};
llvm::FunctionType* ftype_tvm_register_system_symbol_{nullptr};
// The acting body
llvm::BasicBlock* block_{nullptr};
......@@ -203,13 +205,22 @@ class CodeGenLLVM :
std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
private:
// the parallel group information
struct ParallelEnv {
VarExpr task_id;
VarExpr num_task;
bool stride_pattern{false};
bool hit_parallel_loop{false};
llvm::Value* penv{nullptr};
};
// Get runtime functions
llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);
llvm::Value* RuntimeTVMFuncCall();
llvm::Value* RuntimeTVMGetFuncFromEnv();
llvm::Value* RuntimeTVMAPISetLastError();
llvm::Value* RuntimeTVMParallelFor();
llvm::Value* RuntimeTVMParallelLaunch();
llvm::Value* RuntimeTVMParallelBarrier();
// comparison op
llvm::Value* GetVarValue(const Variable* v) const;
llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
......@@ -230,10 +241,12 @@ class CodeGenLLVM :
llvm::Value* CreateVecFlip(llvm::Value* vec);
llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
// Create parallel for.
void CreateParallelFor(const For* op);
// Create parallel launch
void CreateParallelLaunch(const Stmt& body, int num_task);
// Create serial for
void CreateSerialFor(llvm::Value* begin, llvm::Value* end,
void CreateSerialFor(llvm::Value* begin,
llvm::Value* end,
llvm::Value* stride,
const VarExpr& loop_var, const Stmt& body);
// Create a new compute scope.
void CreateComputeScope(const AttrStmt* op);
......@@ -262,14 +275,18 @@ class CodeGenLLVM :
llvm::GlobalVariable* gv_tvm_func_call_{nullptr};
llvm::GlobalVariable* gv_tvm_get_func_from_env_{nullptr};
llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_for_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_launch_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr};
std::unordered_map<std::string, llvm::GlobalVariable*> gv_func_map_;
// context for direct dynamic lookup
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr};
llvm::Function* f_tvm_parallel_launch_{nullptr};
llvm::Function* f_tvm_parallel_barrier_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
// Current parallel environment scope.
ParallelEnv parallel_env_;
// global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
// List of symbols to be exported to TVM system lib.
......
......@@ -70,6 +70,10 @@ MakeLoopNest(const Stage& stage,
<< it_attr->iter_type
<< " in the iter_var_attrs";
}
for (Expr p : it_attr->pragmas) {
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::pragma_scope, p, no_op));
}
}
if (is_one(dom->extent)) {
nest[i + 1].emplace_back(
......
......@@ -14,8 +14,6 @@
#include <algorithm>
#include <string>
#include <cstdlib>
#include <thread>
#include <mutex>
#include "./runtime_base.h"
namespace tvm {
......@@ -158,24 +156,6 @@ struct TVMRuntimeEntry {
std::string ret_str;
std::string last_error;
TVMByteArray ret_bytes;
// threads used in parallel for
std::vector<std::thread> par_threads;
// errors created in parallel for.
std::vector<std::string> par_errors;
// number of parallel threads
int num_par_threads{1};
TVMRuntimeEntry() {
const char *val = getenv("TVM_NUM_THREADS");
if (val == nullptr) {
val = getenv("OMP_NUM_THREADS");
}
if (val != nullptr) {
num_par_threads = atoi(val);
} else {
num_par_threads = std::thread::hardware_concurrency() / 2;
}
}
};
typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
......@@ -254,46 +234,6 @@ int TVMBackendFreeWorkspace(int device_type,
return 0;
}
int TVMBackendParallelFor(
int64_t begin,
int64_t end,
int (*lambda)(int64_t begin, int64_t end, void* env),
void* env) {
TVMRuntimeEntry* rt = TVMAPIRuntimeStore::Get();
int nthread = rt->num_par_threads;
rt->par_threads.resize(nthread);
rt->par_errors.clear();
rt->par_errors.resize(nthread);
int64_t step = (end - begin + nthread - 1) / nthread;
auto fexec = [lambda, env, begin, end, step, rt](int i) {
int64_t ibegin = std::min(end, begin + step * i);
int64_t iend = std::min(end, begin + step * (i + 1));
int rv = (*lambda)(ibegin, iend, env);
if (rv != 0) {
std::ostringstream os;
os << "Thread " << i << " error:" << TVMGetLastError();
rt->par_errors[i] = os.str();
}
};
for (int i = 0; i < nthread; ++i) {
rt->par_threads[i] = std::thread(fexec, i);
}
int ret = 0;
for (int i = 0; i < nthread; ++i) {
rt->par_threads[i].join();
if (rt->par_errors[i].length() != 0) ret = -1;
}
if (ret == 0) return ret;
std::ostringstream os;
for (int i = 0; i < nthread; ++i) {
if (rt->par_errors[i].length() != 0) {
os << rt->par_errors[i] << '\n';
}
}
rt->last_error = os.str();
return -1;
}
int TVMFuncFree(TVMFunctionHandle func) {
API_BEGIN();
delete static_cast<PackedFunc*>(func);
......
......@@ -40,30 +40,21 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
*/
template<typename FLookup>
void InitContextFunctions(FLookup flookup) {
if (auto *fp = reinterpret_cast<decltype(&TVMFuncCall)*>
(flookup("__TVMFuncCall"))) {
*fp = TVMFuncCall;
}
if (auto *fp = reinterpret_cast<decltype(&TVMAPISetLastError)*>
(flookup("__TVMAPISetLastError"))) {
*fp = TVMAPISetLastError;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendGetFuncFromEnv)*>
(flookup("__TVMBackendGetFuncFromEnv"))) {
*fp = TVMBackendGetFuncFromEnv;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendAllocWorkspace)*>
(flookup("__TVMBackendAllocWorkspace"))) {
*fp = TVMBackendAllocWorkspace;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendFreeWorkspace)*>
(flookup("__TVMBackendFreeWorkspace"))) {
*fp = TVMBackendFreeWorkspace;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendParallelFor)*>
(flookup("__TVMBackendParallelFor"))) {
*fp = TVMBackendParallelFor;
}
#define TVM_INIT_CONTEXT_FUNC(FuncName) \
if (auto *fp = reinterpret_cast<decltype(&FuncName)*> \
(flookup("__" #FuncName))) { \
*fp = FuncName; \
}
// Initialize the functions
TVM_INIT_CONTEXT_FUNC(TVMFuncCall);
TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError);
TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv);
TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace);
TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier);
#undef TVM_INIT_CONTEXT_FUNC
}
} // namespace runtime
} // namespace tvm
......
/*!
* Copyright (c) 2017 by Contributors
* \file thread_pool.cc
* \brief Threadpool for multi-threading runtime.
*/
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <thread>
#include <condition_variable>
#include <mutex>
#include <atomic>
#include <vector>
#include <string>
#include <cstring>
#include <memory>
#include <sstream>
namespace tvm {
namespace runtime {
// stride in the page, fit to cache line.
constexpr int kSyncStride = 64 / sizeof(std::atomic<int>);
/*!
* \brief Thread local master environment.
*/
class ParallelLauncher {
public:
// Reset the the task request.
void Init(FTVMParallelLambda flambda,
void* cdata,
int num_task,
bool need_sync) {
std::lock_guard<std::mutex> lock(mutex_);
num_pending_ = num_task;
this->cdata = cdata;
this->flambda = flambda;
this->env.num_task = num_task;
has_error_ = false;
// reshape
if (static_cast<size_t>(num_task) > par_errors_.size()) {
par_errors_.resize(num_task + 1);
if (need_sync) {
delete[] sync_counter_;
sync_counter_ = new std::atomic<int>[num_task * kSyncStride];
}
}
if (need_sync) {
for (int i = 0; i < num_task; ++i) {
sync_counter_[i * kSyncStride].store(
0, std::memory_order_relaxed);
}
this->env.sync_handle = sync_counter_;
} else {
this->env.sync_handle = nullptr;
}
}
~ParallelLauncher() {
delete[] sync_counter_;
}
// Wait n jobs to finish
int WaitForJobs() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] {
return num_pending_ == 0;
});
if (!has_error_) return 0;
std::ostringstream os;
for (size_t i = 0; i < par_errors_.size(); ++i) {
if (par_errors_[i].length() != 0) {
os << "Task " << i << " error: " << par_errors_[i] << '\n';
par_errors_[i].clear();
}
}
TVMAPISetLastError(os.str().c_str());
return -1;
}
// Signal that one job has finished.
void SignalJobError(int task_id) {
std::unique_lock<std::mutex> lock(mutex_);
--num_pending_;
par_errors_[task_id] = TVMGetLastError();
has_error_ = true;
if (num_pending_ == 0) {
lock.unlock();
cv_.notify_one();
}
}
// Signal that one job has finished.
void SignalJobFinish() {
std::unique_lock<std::mutex> lock(mutex_);
--num_pending_;
if (num_pending_ == 0) {
lock.unlock();
cv_.notify_one();
}
}
// Get thread local version of the store.
static ParallelLauncher* ThreadLocal() {
return dmlc::ThreadLocalStore<ParallelLauncher>::Get();
}
// The parallel lambda
FTVMParallelLambda flambda;
// The closure data
void* cdata;
// Local env
TVMParallelGroupEnv env;
// Whether this thread is worker of the pool.
// used to prevent recursive launch.
bool is_worker{false};
private:
// The mutex to access local env.
std::mutex mutex_;
// The conditional variable.
std::condition_variable cv_;
// The pending jobs.
uint32_t num_pending_;
// Whether error has been countered.
bool has_error_;
// The counter page.
std::atomic<int32_t>* sync_counter_{nullptr};
// The error message
std::vector<std::string> par_errors_;
};
/*! \brief Working queue for each thread */
class ParallelTaskQueue {
public:
/*! \brief The task entry */
struct Task {
ParallelLauncher* launcher;
int32_t task_id;
};
ParallelTaskQueue() {
ring_.resize(2);
}
/*!
* \brief Signal to kill the job.
*/
void SignalForKill() {
std::lock_guard<std::mutex> lock(mutex_);
exit_now_.store(true);
cv_.notify_all();
}
/*!
* \brief Push task into the queue.
* \param task The task to be pushed.
*/
void Push(Task task) {
std::unique_lock<std::mutex> lock(mutex_);
if (num_pending_ < ring_.size()) {
CHECK_NE(ring_.size(), 0U);
ring_[(head_ + num_pending_) % ring_.size()] = task;
++num_pending_;
} else {
size_t old_size = ring_.size();
ring_.resize(old_size * 2);
if (head_ + num_pending_ > old_size) {
// copy the ring overflow part into the tail.
size_t ncopy = head_ + num_pending_ - old_size;
memcpy(&ring_[0] + old_size, &ring_[0], ncopy * sizeof(Task));
}
ring_[(head_ + num_pending_) % ring_.size()] = task;
++num_pending_;
}
if (nwait_consumer_ != 0) {
lock.unlock();
cv_.notify_one();
}
}
/*!
* \brief Pop task from the queue
* \param task The task to be poped.
* \param timeout The number of cycles to spin before sleep.
* \return Whether pop is successful or we need to exit now.
*/
bool Pop(Task* task, int timeout = 10) {
std::unique_lock<std::mutex> lock(mutex_);
if (num_pending_ != 0) {
*task = ring_[head_];
head_ = (head_ + 1) % ring_.size();
--num_pending_;
if (exit_now_.load()) return false;
} else {
lock.unlock();
// do a bit spin and busy waiting before sleep.
for (int i = 0; i < timeout && num_pending_ == 0; ++i) {
std::this_thread::yield();
}
lock.lock();
++nwait_consumer_;
cv_.wait(lock, [this] {
return num_pending_ != 0 || exit_now_.load();
});
--nwait_consumer_;
*task = ring_[head_];
head_ = (head_ + 1) % ring_.size();
--num_pending_;
if (exit_now_.load()) return false;
}
return true;
}
private:
// Number of the elments in the queue
uint32_t num_pending_{0};
// Queue head
uint32_t head_{0};
// Number of consumers to wait.
uint32_t nwait_consumer_{0};
// internal mutex
std::mutex mutex_;
// cv for consumer
std::condition_variable cv_;
// signal for exit now
std::atomic<bool> exit_now_{false};
// The internal ring.
std::vector<Task> ring_;
};
// The thread pool
class ThreadPool {
public:
ThreadPool() {
const char *val = getenv("TVM_NUM_THREADS");
if (val == nullptr) {
val = getenv("OMP_NUM_THREADS");
}
if (val != nullptr) {
num_workers_ = atoi(val);
} else {
#if defined(_M_X64) || defined(__x86_64__)
// Half to not count hyper threading.
num_workers_ = std::thread::hardware_concurrency() / 2;
#else
num_workers_ = std::thread::hardware_concurrency();
#endif
}
num_workers_ = std::max(num_workers_, 1);
this->Init();
}
~ThreadPool() {
for (std::unique_ptr<ParallelTaskQueue>& q : queues_) {
q->SignalForKill();
}
for (std::thread& t : threads_) {
t.join();
}
}
int Launch(FTVMParallelLambda flambda,
void* cdata,
int num_task,
int need_sync) {
ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
CHECK(!launcher->is_worker)
<< "Cannot launch parallel job inside worker, consider fuse then parallel";
if (num_task == 0) {
num_task = num_workers_;
}
if (need_sync != 0) {
CHECK_LE(num_task, num_workers_)
<< "Request parallel sync task larger than number of threads available "
<< " workers=" << num_workers_ << " request=" << num_task;
}
launcher->Init(flambda, cdata, num_task, need_sync != 0);
ParallelTaskQueue::Task tsk;
tsk.launcher = launcher;
for (int i = 0; i < num_task; ++i) {
tsk.task_id = i;
queues_[i]->Push(tsk);
}
return launcher->WaitForJobs();
}
static ThreadPool* Global() {
static ThreadPool inst;
return &inst;
}
private:
// Initialize the pool.
void Init() {
for (int i = 0; i < num_workers_; ++i) {
queues_.emplace_back(
std::unique_ptr<ParallelTaskQueue>(new ParallelTaskQueue()));
}
threads_.resize(num_workers_);
for (int i = 0; i < num_workers_; ++i) {
threads_[i] = std::thread([this, i] {
this->RunWorker(queues_[i].get());
});
}
}
// Internal worker function.
void RunWorker(ParallelTaskQueue* queue) {
ParallelTaskQueue::Task task;
ParallelLauncher::ThreadLocal()->is_worker = true;
while (queue->Pop(&task)) {
CHECK(task.launcher != nullptr);
TVMParallelGroupEnv* penv = &(task.launcher->env);
void* cdata = task.launcher->cdata;
if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {
task.launcher->SignalJobFinish();
} else {
task.launcher->SignalJobError(task.task_id);
}
}
}
// Number of workers
int num_workers_;
std::vector<std::unique_ptr<ParallelTaskQueue> > queues_;
std::vector<std::thread> threads_;
};
} // namespace runtime
} // namespace tvm
int TVMBackendParallelLaunch(
FTVMParallelLambda flambda,
void* cdata,
int num_task) {
return tvm::runtime::ThreadPool::Global()->Launch(
flambda, cdata, num_task, 1);
}
int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
using tvm::runtime::kSyncStride;
int num_task = penv->num_task;
std::atomic<int>* sync_counter =
reinterpret_cast<std::atomic<int>*>(penv->sync_handle);
int old_counter = sync_counter[task_id * kSyncStride].fetch_add(
1, std::memory_order_release);
for (int i = 0; i < num_task; ++i) {
if (i != task_id) {
while (sync_counter[i * kSyncStride].load(
std::memory_order_relaxed) <= old_counter) {
std::this_thread::yield();
}
}
}
std::atomic_thread_fence(std::memory_order_acquire);
return 0;
}
......@@ -340,6 +340,19 @@ Stage& Stage::parallel(IterVar var) { // NOLINT(*)
return *this;
}
Stage& Stage::pragma(IterVar var, const std::string& pragma_type) { // NOLINT(*)
if (pragma_type == "unroll") {
this->unroll(var);
} else if (pragma_type == "vectorize") {
this->vectorize(var);
} else {
UpdateIterVarAttr(operator->(), var, [pragma_type](IterVarAttrNode* n) {
n->pragmas.push_back(ir::StringImm::make(pragma_type));
});
}
return *this;
}
Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) {
StageNode *self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
......
......@@ -28,8 +28,13 @@ def test_llvm_add_pipeline():
C = tvm.compute(A.shape, lambda *i: T(*i), name='C')
s = tvm.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=4)
s[C].parallel(xo)
xo1, xo2 = s[C].split(xo, factor=13)
s[C].parallel(xo2)
s[C].pragma(xo1, "parallel_launch_point")
s[C].pragma(xo2, "parallel_stride_pattern")
s[C].pragma(xo2, "parallel_barrier_when_finish")
s[C].vectorize(xi)
def check_llvm():
if not tvm.module.enabled("llvm"):
return
......@@ -167,9 +172,9 @@ def test_multiple_func():
if __name__ == "__main__":
test_llvm_add_pipeline()
test_llvm_intrin()
test_multiple_func()
test_llvm_add_pipeline()
test_llvm_flip_pipeline()
test_llvm_madd_pipeline()
test_llvm_temp_space()
......@@ -74,7 +74,27 @@ def test_stack_vm_cond():
np.testing.assert_equal(a.asnumpy(), y)
run_jit(fapi, check)
def test_vm_parallel():
dtype = 'int64'
n = tvm.var('n')
Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i')
ib = tvm.ir_builder.create()
A = ib.buffer_ptr(Ab)
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
run_jit(fapi, check)
if __name__ == "__main__":
test_vm_parallel()
test_stack_vm_loop()
test_stack_vm_basic()
test_stack_vm_cond()
test_stack_vm_loop()
......@@ -30,6 +30,7 @@ def test_schedule_create():
assert isinstance(s_loaded, tvm.schedule.Schedule)
assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))
def test_reorder():
m = tvm.var('m')
A = tvm.placeholder((m,), name='A')
......@@ -91,6 +92,21 @@ def test_vectorize():
assert s[T].iter_var_attrs[xi].iter_type == UNROLL
assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE
def test_pragma():
m = 100
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i: A[i])
s = tvm.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=10)
s[T].pragma(xo, "pragma1")
s[T].pragma(xi, "vectorize")
VECTORIZE = tvm.schedule.IterVar.Vectorized
assert s[T].iter_var_attrs[xo].pragmas[0].value == "pragma1"
assert s[T].iter_var_attrs[xi].iter_type == VECTORIZE
def test_rfactor():
n = tvm.var('n')
k1 = tvm.reduce_axis((0, n), name="k1")
......@@ -141,6 +157,7 @@ def test_tensor_intrin():
if __name__ == "__main__":
test_pragma()
test_tensor_intrin()
test_rfactor()
test_schedule_create()
......
......@@ -50,3 +50,16 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
});
} // namespace contrib
} // namespace tvm
// dummy parallel runtime
int TVMBackendParallelLaunch(
FTVMParallelLambda flambda,
void* cdata,
int num_task) {
TVMAPISetLastError("Parallel is not supported in Web runtime");
return -1;
}
int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
return 0;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment