Commit 41768cf9 by Tianqi Chen, committed by GitHub

[SCHEDULE][RUNTIME] Introduce pragma for additional extension hint, threadpool runtime. (#299)

parent fd96d285
@@ -32,23 +32,35 @@ def test_rpc_module():
    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
    temp = util.tempdir()
    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
    # Build the dynamic lib.
    # If we don't want to do metal and only use cpu, just set target to be target
    f = tvm.build(s, [A, B], "metal", target_host=target, name="myadd")
    path_dso1 = temp.relpath("dev_lib.dylib")
    f.export_library(path_dso1, xcode.create_dylib,
                     arch=arch, sdk=sdk)
    xcode.codesign(path_dso1)

    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=64)
    s[B].parallel(xi)
    s[B].pragma(xo, "parallel_launch_point")
    s[B].pragma(xi, "parallel_barrier_when_finish")
    f = tvm.build(s, [A, B], target, name="myadd_cpu")
    path_dso2 = temp.relpath("cpu_lib.dylib")
    f.export_library(path_dso2, xcode.create_dylib,
                     arch=arch, sdk=sdk)
    xcode.codesign(path_dso2)

    # Start RPC test server that contains the compiled library.
    server = xcode.popen_test_rpc(proxy_host, proxy_port, key,
                                  destination=destination,
                                  libs=[path_dso1, path_dso2])
    # connect to the proxy
    remote = rpc.connect(proxy_host, proxy_port, key=key)
    ctx = remote.metal(0)
@@ -60,5 +72,15 @@ def test_rpc_module():
    cost = time_f(a, b).mean
    print('%g secs/op' % cost)
    np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    # CPU
    ctx = remote.cpu(0)
    f2 = remote.load_module("cpu_lib.dylib")
    a_np = np.random.uniform(size=1024).astype(A.dtype)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
    time_f = f2.time_evaluator(f2.entry_name, ctx, number=10)
    cost = time_f(a, b).mean
    print('%g secs/op' % cost)
    np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

test_rpc_module()
@@ -7,6 +7,7 @@
#include "../../src/runtime/c_runtime_api.cc"
#include "../../src/runtime/cpu_device_api.cc"
#include "../../src/runtime/workspace_pool.cc"
#include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/module_util.cc"
#include "../../src/runtime/system_lib_module.cc"
#include "../../src/runtime/module.cc"
...
@@ -45,7 +45,7 @@ tvm.ir_pass
tvm.ir_pass.StorageFlatten
tvm.ir_pass.VectorizeLoop
tvm.ir_pass.UnrollLoop
tvm.ir_pass.ThreadSync
tvm.ir_pass.StorageRewrite
tvm.ir_pass.MakeAPI
tvm.ir_pass.SplitHostDevice
...
@@ -166,6 +166,8 @@ constexpr const char* device_context_type = "device_context_type";
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope";
/*! \brief Mark that the region is guarded by the pragma */
constexpr const char* pragma_scope = "pragma_scope";
/*!
 * \brief Mark of prefetch scope, value=offset,
 * run prefetch of Tensor on the current loop scope
...
@@ -66,21 +66,49 @@ TVM_DLL void* TVMBackendAllocWorkspace(int device_type,
TVM_DLL int TVMBackendFreeWorkspace(int device_type,
                                    int device_id,
                                    void* ptr);

/*!
 * \brief Environment for TVM parallel task.
 */
typedef struct {
  /*!
   * \brief Auxiliary handle used for synchronization.
   */
  void* sync_handle;
  /*! \brief Total number of tasks. */
  int32_t num_task;
} TVMParallelGroupEnv;

/*!
 * \brief The callback function to execute a parallel lambda.
 * \param task_id The task id of the function.
 * \param penv The parallel environment backing the execution.
 * \param cdata The supporting closure data.
 */
typedef int (*FTVMParallelLambda)(
    int task_id, TVMParallelGroupEnv* penv, void* cdata);

/*!
 * \brief Backend function for running parallel jobs.
 *
 * \param flambda The parallel function to be launched.
 * \param cdata The closure data.
 * \param num_task Number of tasks to launch; can be 0, meaning launch
 *           with all available threads.
 *
 * \return 0 when no error is thrown, -1 when failure happens
 */
TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
                                     void* cdata,
                                     int num_task);

/*!
 * \brief BSP barrier between parallel threads.
 * \param task_id The task id of the function.
 * \param penv The parallel environment backing the execution.
 * \return 0 when no error is thrown, -1 when failure happens
 */
TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv);

#ifdef __cplusplus
}  // TVM_EXTERN_C
...
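To make the new contract concrete, here is a hedged sketch of the shape of code the compiler is expected to emit against this API: a closure struct, an FTVMParallelLambda that strides over the iteration space by num_task, and a launch with num_task = 0 so the runtime picks all available worker threads. The names AddOneClosure, AddOneLambda, and RunAddOne are illustrative, not part of TVM.

#include <tvm/runtime/c_backend_api.h>
#include <cstdint>

// Illustrative closure: the fields the generated code would pack for the lambda.
struct AddOneClosure {
  float* a;
  float* b;
  int64_t n;
};

// An FTVMParallelLambda: task `task_id` of `penv->num_task` processes the
// interleaved indices task_id, task_id + num_task, ...
static int AddOneLambda(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
  AddOneClosure* c = static_cast<AddOneClosure*>(cdata);
  for (int64_t i = task_id; i < c->n; i += penv->num_task) {
    c->b[i] = c->a[i] + 1.0f;
  }
  return 0;  // a nonzero return marks this task as failed
}

int RunAddOne(float* a, float* b, int64_t n) {
  AddOneClosure c = {a, b, n};
  // num_task = 0: launch with all available worker threads.
  return TVMBackendParallelLaunch(AddOneLambda, &c, 0);
}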
@@ -181,6 +181,15 @@ class Stage : public NodeRef {
   */
  Stage& parallel(IterVar var);   // NOLINT(*)
  /*!
   * \brief Annotate the iteration with a pragma.
   *
   * \param var The axis to be annotated.
   * \param pragma_type The pragma type.
   *
   * \return reference to self.
   */
  Stage& pragma(IterVar var, const std::string& pragma_type);   // NOLINT(*)
  /*!
   * \brief Fetch data in advance.
   * \param domain the tensor to be prefetched
   * \param var the iteration point at which to apply prefetching
@@ -487,6 +496,10 @@ class IterVarAttrNode : public Node {
   * when the axis is marked as Tensorized
   */
  TensorIntrin tensor_intrin;
  /*!
   * \brief Additional pragmas, array of StringImm.
   */
  Array<Expr> pragmas;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("iter_type", &iter_type);
@@ -494,6 +507,7 @@ class IterVarAttrNode : public Node {
    v->Visit("prefetch_data", &prefetch_data);
    v->Visit("prefetch_offset", &prefetch_offset);
    v->Visit("tensor_intrin", &tensor_intrin);
    v->Visit("pragmas", &pragmas);
  }

  static constexpr const char* _type_key = "IterVarAttr";
...
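Since pragma() returns Stage& like the other scheduling primitives, hints compose by chaining. A minimal sketch, assuming stage, xo, and xi come from an earlier split:

// Hedged sketch: `stage`, `xo`, `xi` are assumed from an earlier split;
// only the pragma strings are defined by this commit.
stage.parallel(xi)
     .pragma(xo, "parallel_launch_point")
     .pragma(xi, "parallel_barrier_when_finish");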
@@ -78,7 +78,7 @@ class Module(ModuleBase):
        file_name : str
            The name of the shared library.

        fcompile : function(target, file_list, kwargs), optional
            Compilation function used to create the dynamic library.

        kwargs : dict, optional
...
@@ -26,7 +26,7 @@ class Buffer(NodeBase):
    WRITE = 2

    def access_ptr(self, access_mask, ptr_type="handle"):
        """Get an access pointer to the head of buffer.

        This is the recommended method to get buffer data
        address when interacting with external functions.
@@ -37,7 +37,6 @@ class Buffer(NodeBase):
            The access pattern MASK. Indicate whether the
            access will read or write to the data content.
        ptr_type : str, optional
            The data type of the result pointer. Do not specify
            unless we want to cast pointer to specific type.
@@ -45,8 +44,8 @@ class Buffer(NodeBase):
        Examples
        --------
        .. code-block:: python

          import tvm.schedule.Buffer
          # Get access ptr for read
          buffer.access_ptr("r")
          # Get access ptr for read/write with bitmask
@@ -465,6 +464,48 @@ class Stage(NodeBase):
        """
        _api_internal._StageParallel(self, var)

    def pragma(self, var, pragma_type):
        """Annotate the iteration with a pragma.

        This will translate to a pragma_scope surrounding
        the corresponding generated loop.
        Useful to support experimental features and extensions.

        Parameters
        ----------
        var : IterVar
            The iteration to be annotated.

        pragma_type : str
            The pragma string to be annotated.

        Note
        ----
        Most pragmas are advanced/experimental features
        and may be subject to change. List of supported pragmas:

        - **parallel_launch_point**

          Specify to launch parallel threads outside the
          specified iteration loop. By default the threads
          launch at the point of the parallel construct.
          This pragma moves the launch point to an outer scope.
          The threads are launched once and reused across multiple
          parallel constructs, as in a BSP-style program.

        - **parallel_barrier_when_finish**

          Insert a synchronization barrier between working threads
          after the specified loop iteration finishes.

        - **parallel_stride_pattern**

          Hint the parallel loop to execute in a strided pattern:
          :code:`for (int i = task_id; i < end; i += num_task)`
        """
        _api_internal._StagePragma(self, var, pragma_type)

    def prefetch(self, tensor, var, offset):
        """Prefetch the specified variable
...
@@ -364,6 +364,12 @@ TVM_REGISTER_API("_StageParallel")
        .parallel(args[1]);
  });

TVM_REGISTER_API("_StagePragma")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .pragma(args[1], args[2]);
  });

TVM_REGISTER_API("_StagePrefetch")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    args[0].operator Stage()
...
@@ -63,8 +63,14 @@ void CodeGenLLVM::Init(const std::string& module_name,
       t_tvm_shape_index_->getPointerTo(),
       t_int64_});
  t_tvm_value_ = llvm::StructType::create({t_float64_});
  t_tvm_parallel_group_env_ = llvm::StructType::create({
      t_int32_->getPointerTo(),
      t_int32_});
  ftype_tvm_parallel_lambda_ = llvm::FunctionType::get(
      t_int_,
      {t_int_,
       t_tvm_parallel_group_env_->getPointerTo(),
       t_void_p_}, false);
  md_builder_.reset(new llvm::MDBuilder(*ctx));
  md_very_likely_branch_ =
      md_builder_->createBranchWeights(1 << 30, 0);
@@ -90,9 +96,13 @@ void CodeGenLLVM::Init(const std::string& module_name,
        t_tvm_func_handle_->getPointerTo()}, false);
  ftype_tvm_api_set_last_error_ = llvm::FunctionType::get(
      t_void_, {t_char_->getPointerTo()}, false);
  ftype_tvm_parallel_launch_ =
      llvm::FunctionType::get(t_int_, {
          ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_},
        false);
  ftype_tvm_parallel_barrier_ =
      llvm::FunctionType::get(t_int_, {
          t_int_, t_tvm_parallel_group_env_->getPointerTo()},
        false);
  // initialize TVM runtime API
  if (system_lib) {
@@ -113,9 +123,12 @@ void CodeGenLLVM::Init(const std::string& module_name,
    f_tvm_api_set_last_error_ = llvm::Function::Create(
        ftype_tvm_api_set_last_error_,
        llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
    f_tvm_parallel_launch_ = llvm::Function::Create(
        ftype_tvm_parallel_launch_,
        llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get());
    f_tvm_parallel_barrier_ = llvm::Function::Create(
        ftype_tvm_parallel_barrier_,
        llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get());
  }
  this->InitTarget(tm);
  // initialize builder
@@ -179,8 +192,10 @@ void CodeGenLLVM::InitGlobalContext(bool dynamic_lookup) {
      ftype_tvm_get_func_from_env_->getPointerTo(), "__TVMBackendGetFuncFromEnv");
  gv_tvm_api_set_last_error_ = InitContextPtr(
      ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError");
  gv_tvm_parallel_launch_ = InitContextPtr(
      ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch");
  gv_tvm_parallel_barrier_ = InitContextPtr(
      ftype_tvm_parallel_barrier_->getPointerTo(), "__TVMBackendParallelBarrier");
  // Mark as context functions
  gv_func_map_["TVMBackendAllocWorkspace"] = nullptr;
  gv_func_map_["TVMBackendFreeWorkspace"] = nullptr;
@@ -702,9 +717,14 @@ llvm::Value* CodeGenLLVM::RuntimeTVMAPISetLastError() {
  if (f_tvm_api_set_last_error_ != nullptr) return f_tvm_api_set_last_error_;
  return GetContextPtr(gv_tvm_api_set_last_error_);
}

llvm::Value* CodeGenLLVM::RuntimeTVMParallelLaunch() {
  if (f_tvm_parallel_launch_ != nullptr) return f_tvm_parallel_launch_;
  return GetContextPtr(gv_tvm_parallel_launch_);
}

llvm::Value* CodeGenLLVM::RuntimeTVMParallelBarrier() {
  if (f_tvm_parallel_barrier_ != nullptr) return f_tvm_parallel_barrier_;
  return GetContextPtr(gv_tvm_parallel_barrier_);
}

llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
@@ -782,15 +802,9 @@ void CodeGenLLVM::CreateComputeScope(const AttrStmt* op) {
  builder_->SetInsertPoint(compute_call_end);
}

void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
  using llvm::BasicBlock;
  Array<Var> vfields = ir::UndefinedVars(body, {});
  std::vector<llvm::Type*> fields;
  for (Var v : vfields) {
    auto it = var_map_.find(v.get());
@@ -800,9 +814,9 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
  // closure data
  llvm::StructType* tcdata = llvm::StructType::create(fields);
  llvm::Function* f = llvm::Function::Create(
      ftype_tvm_parallel_lambda_,
      llvm::Function::PrivateLinkage,
      "__tvm_parallel_lambda", module_.get());
  // allocate and setup the closure, call the closure.
  llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
  llvm::Value* zero = ConstInt32(0);
@@ -812,19 +826,17 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
        var_map_.at(vfields[i].get()),
        builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
  }
  BasicBlock* par_launch_end = CheckCallSuccess(
      builder_->CreateCall(
          RuntimeTVMParallelLaunch(),
          {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)}));
  // Setup the closure function.
  BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
  builder_->SetInsertPoint(lambda_entry);
  auto it = f->arg_begin();
  llvm::Value* task_id = &(*it++);
  llvm::Value* penv = &(*it++);
  cdata = &(*it++);
  cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
  // setup new variable map, swap it with current var context.
  std::unordered_map<const Variable*, llvm::Value*> new_vmap;
@@ -833,17 +845,32 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
    builder_->CreateLoad(builder_->CreateInBoundsGEP(
        cdata, {zero, ConstInt32(i)}));
  }
  // setup parallel env
  ParallelEnv par_env;
  par_env.task_id = Var("task_id", Int(32));
  par_env.num_task = Var("num_task", Int(32));
  new_vmap[par_env.task_id.get()] = task_id;
  new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
      builder_->CreateInBoundsGEP(
          penv, {zero, ConstInt32(1)}));
  par_env.penv = penv;
  std::swap(function_, f);
  std::swap(parallel_env_, par_env);
  std::swap(var_map_, new_vmap);
  this->VisitStmt(body);
  builder_->CreateRet(ConstInt32(0));
  // swap the var map back, now we are back on track.
  std::swap(var_map_, new_vmap);
  std::swap(parallel_env_, par_env);
  std::swap(function_, f);
  CHECK(par_env.hit_parallel_loop)
      << "Cannot find parallel loop within parallel launch";
  builder_->SetInsertPoint(par_launch_end);
}
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
                                  llvm::Value* end,
                                  llvm::Value* stride,
                                  const VarExpr& loop_var, const Stmt& body) {
  using llvm::BasicBlock;
  Type t = loop_var.type();
@@ -864,7 +891,7 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
  builder_->SetInsertPoint(for_body);
  var_map_[loop_var.get()] = index;
  this->VisitStmt(body);
  llvm::Value* next_index = CreateAdd(t, index, stride);
  index->addIncoming(next_index, builder_->GetInsertBlock());
  builder_->CreateBr(for_head);
  // end of for
@@ -1481,10 +1508,45 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
void CodeGenLLVM::VisitStmt_(const For* op) {
  CHECK(is_zero(op->min));
  if (op->for_type == ForType::Serial) {
    CreateSerialFor(ConstInt32(0),
                    MakeValue(op->extent),
                    ConstInt32(1),
                    op->loop_var,
                    op->body);
  } else if (op->for_type == ForType::Parallel) {
    if (parallel_env_.penv == nullptr) {
      CreateParallelLaunch(
          For::make(
              op->loop_var, op->min, op->extent,
              op->for_type, op->device_api, op->body), 0);
    } else {
      // already in parallel env.
      CHECK(parallel_env_.task_id.defined());
      CHECK(parallel_env_.num_task.defined());
      CHECK(parallel_env_.penv != nullptr);
      Type t = op->extent.type();
      Expr num_task = cast(t, parallel_env_.num_task);
      Expr task_id = cast(t, parallel_env_.task_id);
      CHECK(!parallel_env_.hit_parallel_loop)
          << "Nested parallel loop is not supported by the thread pool, try fusing the loops instead";
      parallel_env_.hit_parallel_loop = true;
      if (parallel_env_.stride_pattern) {
        CreateSerialFor(MakeValue(task_id),
                        MakeValue(op->extent),
                        MakeValue(num_task),
                        op->loop_var,
                        op->body);
      } else {
        Expr step = (op->extent + num_task - make_const(t, 1)) / num_task;
        Expr begin = Min::make(task_id * step, op->extent);
        Expr end = Min::make((task_id + make_const(t, 1)) * step, op->extent);
        CreateSerialFor(MakeValue(begin),
                        MakeValue(end),
                        ConstInt32(1),
                        op->loop_var,
                        op->body);
      }
    }
  } else {
    LOG(FATAL) << "cannot handle for type " << op->for_type;
  }
@@ -1566,6 +1628,29 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
    this->VisitStmt(op->body);
  } else if (op->attr_key == ir::attr::compute_scope) {
    this->CreateComputeScope(op);
  } else if (op->attr_key == ir::attr::pragma_scope) {
    const std::string& pname = op->value.as<StringImm>()->value;
    if (pname == "parallel_stride_pattern") {
      CHECK(parallel_env_.penv != nullptr)
          << "Pragma parallel_stride_pattern is only valid within a parallel launch";
      parallel_env_.stride_pattern = true;
      this->VisitStmt(op->body);
    } else if (pname == "parallel_launch_point") {
      CreateParallelLaunch(op->body, 0);
    } else if (pname == "parallel_barrier_when_finish") {
      CHECK(parallel_env_.penv != nullptr)
          << "Cannot run barrier without parallel environment";
      CHECK(!parallel_env_.hit_parallel_loop)
          << "Cannot place the barrier within the parallel loop as the workload may differ, "
          << "place it between parallel and parallel_launch_point";
      this->VisitStmt(op->body);
      builder_->CreateCall(
          RuntimeTVMParallelBarrier(),
          {MakeValue(parallel_env_.task_id), parallel_env_.penv});
    } else {
      LOG(WARNING) << "Unknown pragma " << pname;
      this->VisitStmt(op->body);
    }
  } else {
    this->VisitStmt(op->body);
  }
...
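Reading the visitor above, a parallel loop of extent iterations is split across num_task workers in one of two ways. As a hedged scalar sketch of what the generated code computes (body() stands in for the generated loop body; all names are illustrative):

#include <algorithm>
#include <cstdint>

void body(int64_t i);  // stands in for the generated loop body

// Default lowering: contiguous chunks of ceil(extent / num_task) iterations.
void RunChunked(int64_t task_id, int64_t num_task, int64_t extent) {
  int64_t step = (extent + num_task - 1) / num_task;
  int64_t begin = std::min(task_id * step, extent);
  int64_t end = std::min((task_id + 1) * step, extent);
  for (int64_t i = begin; i < end; ++i) body(i);
}

// With the parallel_stride_pattern pragma: interleaved iterations.
void RunStrided(int64_t task_id, int64_t num_task, int64_t extent) {
  for (int64_t i = task_id; i < extent; i += num_task) body(i);
}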
@@ -189,11 +189,13 @@ class CodeGenLLVM :
  llvm::StructType* t_tvm_type_{nullptr};
  llvm::StructType* t_tvm_array_{nullptr};
  llvm::StructType* t_tvm_value_{nullptr};
  llvm::StructType* t_tvm_parallel_group_env_{nullptr};
  llvm::FunctionType* ftype_tvm_parallel_lambda_{nullptr};
  llvm::FunctionType* ftype_tvm_func_call_{nullptr};
  llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr};
  llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};
  llvm::FunctionType* ftype_tvm_parallel_launch_{nullptr};
  llvm::FunctionType* ftype_tvm_parallel_barrier_{nullptr};
  llvm::FunctionType* ftype_tvm_register_system_symbol_{nullptr};
  // The acting body
  llvm::BasicBlock* block_{nullptr};
@@ -203,13 +205,22 @@ class CodeGenLLVM :
  std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;

 private:
  // the parallel group information
  struct ParallelEnv {
    VarExpr task_id;
    VarExpr num_task;
    bool stride_pattern{false};
    bool hit_parallel_loop{false};
    llvm::Value* penv{nullptr};
  };
  // Get runtime functions
  llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
  llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);
  llvm::Value* RuntimeTVMFuncCall();
  llvm::Value* RuntimeTVMGetFuncFromEnv();
  llvm::Value* RuntimeTVMAPISetLastError();
  llvm::Value* RuntimeTVMParallelLaunch();
  llvm::Value* RuntimeTVMParallelBarrier();
  // comparison op
  llvm::Value* GetVarValue(const Variable* v) const;
  llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
@@ -230,10 +241,12 @@ class CodeGenLLVM :
  llvm::Value* CreateVecFlip(llvm::Value* vec);
  llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
  llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
  // Create parallel launch
  void CreateParallelLaunch(const Stmt& body, int num_task);
  // Create serial for
  void CreateSerialFor(llvm::Value* begin,
                       llvm::Value* end,
                       llvm::Value* stride,
                       const VarExpr& loop_var, const Stmt& body);
  // Create a new compute scope.
  void CreateComputeScope(const AttrStmt* op);
@@ -262,14 +275,18 @@ class CodeGenLLVM :
  llvm::GlobalVariable* gv_tvm_func_call_{nullptr};
  llvm::GlobalVariable* gv_tvm_get_func_from_env_{nullptr};
  llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr};
  llvm::GlobalVariable* gv_tvm_parallel_launch_{nullptr};
  llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr};
  std::unordered_map<std::string, llvm::GlobalVariable*> gv_func_map_;
  // context for direct dynamic lookup
  llvm::Function* f_tvm_func_call_{nullptr};
  llvm::Function* f_tvm_get_func_from_env_{nullptr};
  llvm::Function* f_tvm_api_set_last_error_{nullptr};
  llvm::Function* f_tvm_parallel_launch_{nullptr};
  llvm::Function* f_tvm_parallel_barrier_{nullptr};
  llvm::Function* f_tvm_register_system_symbol_{nullptr};
  // Current parallel environment scope.
  ParallelEnv parallel_env_;
  // global to packed function handle
  std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
  // List of symbols to be exported to TVM system lib.
...
@@ -70,6 +70,10 @@ MakeLoopNest(const Stage& stage,
          << it_attr->iter_type
          << " in the iter_var_attrs";
    }
    for (Expr p : it_attr->pragmas) {
      nest[i + 1].emplace_back(
          AttrStmt::make(iv, ir::attr::pragma_scope, p, no_op));
    }
  }
  if (is_one(dom->extent)) {
    nest[i + 1].emplace_back(
...
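Concretely, each pragma string attached to an IterVar becomes one pragma_scope AttrStmt wrapping the generated loop. A hedged sketch of the statement the loop above emits for a single pragma, with iv and no_op as in the surrounding code:

// Hedged sketch: one pragma_scope marker, as produced per pragma string.
Stmt marker = AttrStmt::make(
    iv, ir::attr::pragma_scope,
    ir::StringImm::make("parallel_launch_point"), no_op);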
@@ -14,8 +14,6 @@
#include <algorithm>
#include <string>
#include <cstdlib>
#include <thread>
#include <mutex>
#include "./runtime_base.h"

namespace tvm {
@@ -158,24 +156,6 @@ struct TVMRuntimeEntry {
std::string ret_str;
std::string last_error;
TVMByteArray ret_bytes;
// threads used in parallel for
std::vector<std::thread> par_threads;
// errors created in parallel for.
std::vector<std::string> par_errors;
// number of parallel threads
int num_par_threads{1};
TVMRuntimeEntry() {
const char *val = getenv("TVM_NUM_THREADS");
if (val == nullptr) {
val = getenv("OMP_NUM_THREADS");
}
if (val != nullptr) {
num_par_threads = atoi(val);
} else {
num_par_threads = std::thread::hardware_concurrency() / 2;
}
}
};
typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
@@ -254,46 +234,6 @@ int TVMBackendFreeWorkspace(int device_type,
return 0;
}
int TVMBackendParallelFor(
int64_t begin,
int64_t end,
int (*lambda)(int64_t begin, int64_t end, void* env),
void* env) {
TVMRuntimeEntry* rt = TVMAPIRuntimeStore::Get();
int nthread = rt->num_par_threads;
rt->par_threads.resize(nthread);
rt->par_errors.clear();
rt->par_errors.resize(nthread);
int64_t step = (end - begin + nthread - 1) / nthread;
auto fexec = [lambda, env, begin, end, step, rt](int i) {
int64_t ibegin = std::min(end, begin + step * i);
int64_t iend = std::min(end, begin + step * (i + 1));
int rv = (*lambda)(ibegin, iend, env);
if (rv != 0) {
std::ostringstream os;
os << "Thread " << i << " error:" << TVMGetLastError();
rt->par_errors[i] = os.str();
}
};
for (int i = 0; i < nthread; ++i) {
rt->par_threads[i] = std::thread(fexec, i);
}
int ret = 0;
for (int i = 0; i < nthread; ++i) {
rt->par_threads[i].join();
if (rt->par_errors[i].length() != 0) ret = -1;
}
if (ret == 0) return ret;
std::ostringstream os;
for (int i = 0; i < nthread; ++i) {
if (rt->par_errors[i].length() != 0) {
os << rt->par_errors[i] << '\n';
}
}
rt->last_error = os.str();
return -1;
}
int TVMFuncFree(TVMFunctionHandle func) {
API_BEGIN();
delete static_cast<PackedFunc*>(func);
...
@@ -40,30 +40,21 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
 */
template<typename FLookup>
void InitContextFunctions(FLookup flookup) {
#define TVM_INIT_CONTEXT_FUNC(FuncName)                     \
  if (auto *fp = reinterpret_cast<decltype(&FuncName)*>     \
      (flookup("__" #FuncName))) {                          \
    *fp = FuncName;                                         \
  }
  // Initialize the functions
  TVM_INIT_CONTEXT_FUNC(TVMFuncCall);
  TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError);
  TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv);
  TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace);
  TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace);
  TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch);
  TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier);
#undef TVM_INIT_CONTEXT_FUNC
}
}  // namespace runtime
}  // namespace tvm
...
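For reference, one expansion of the macro reproduces exactly the hand-written pattern it replaces; TVM_INIT_CONTEXT_FUNC(TVMFuncCall) expands to:

if (auto *fp = reinterpret_cast<decltype(&TVMFuncCall)*>
    (flookup("__TVMFuncCall"))) {
  *fp = TVMFuncCall;
}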
/*!
* Copyright (c) 2017 by Contributors
* \file thread_pool.cc
* \brief Threadpool for multi-threading runtime.
*/
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <thread>
#include <condition_variable>
#include <mutex>
#include <atomic>
#include <vector>
#include <string>
#include <cstring>
#include <memory>
#include <sstream>
namespace tvm {
namespace runtime {
// Stride between sync counters, sized to a cache line to avoid false sharing.
constexpr int kSyncStride = 64 / sizeof(std::atomic<int>);
/*!
* \brief Thread local master environment.
*/
class ParallelLauncher {
public:
// Reset the task request.
void Init(FTVMParallelLambda flambda,
void* cdata,
int num_task,
bool need_sync) {
std::lock_guard<std::mutex> lock(mutex_);
num_pending_ = num_task;
this->cdata = cdata;
this->flambda = flambda;
this->env.num_task = num_task;
has_error_ = false;
// reshape
if (static_cast<size_t>(num_task) > par_errors_.size()) {
par_errors_.resize(num_task + 1);
if (need_sync) {
delete[] sync_counter_;
sync_counter_ = new std::atomic<int>[num_task * kSyncStride];
}
}
if (need_sync) {
for (int i = 0; i < num_task; ++i) {
sync_counter_[i * kSyncStride].store(
0, std::memory_order_relaxed);
}
this->env.sync_handle = sync_counter_;
} else {
this->env.sync_handle = nullptr;
}
}
~ParallelLauncher() {
delete[] sync_counter_;
}
// Wait for all jobs to finish.
int WaitForJobs() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] {
return num_pending_ == 0;
});
if (!has_error_) return 0;
std::ostringstream os;
for (size_t i = 0; i < par_errors_.size(); ++i) {
if (par_errors_[i].length() != 0) {
os << "Task " << i << " error: " << par_errors_[i] << '\n';
par_errors_[i].clear();
}
}
TVMAPISetLastError(os.str().c_str());
return -1;
}
// Signal that one job has failed.
void SignalJobError(int task_id) {
std::unique_lock<std::mutex> lock(mutex_);
--num_pending_;
par_errors_[task_id] = TVMGetLastError();
has_error_ = true;
if (num_pending_ == 0) {
lock.unlock();
cv_.notify_one();
}
}
// Signal that one job has finished.
void SignalJobFinish() {
std::unique_lock<std::mutex> lock(mutex_);
--num_pending_;
if (num_pending_ == 0) {
lock.unlock();
cv_.notify_one();
}
}
// Get thread local version of the store.
static ParallelLauncher* ThreadLocal() {
return dmlc::ThreadLocalStore<ParallelLauncher>::Get();
}
// The parallel lambda
FTVMParallelLambda flambda;
// The closure data
void* cdata;
// Local env
TVMParallelGroupEnv env;
// Whether this thread is a worker of the pool,
// used to prevent recursive launch.
bool is_worker{false};
private:
// The mutex to access local env.
std::mutex mutex_;
// The conditional variable.
std::condition_variable cv_;
// The pending jobs.
uint32_t num_pending_;
// Whether an error has been encountered.
bool has_error_;
// The counter page.
std::atomic<int32_t>* sync_counter_{nullptr};
// The error message
std::vector<std::string> par_errors_;
};
/*! \brief Working queue for each thread */
class ParallelTaskQueue {
public:
/*! \brief The task entry */
struct Task {
ParallelLauncher* launcher;
int32_t task_id;
};
ParallelTaskQueue() {
ring_.resize(2);
}
/*!
* \brief Signal to kill the job.
*/
void SignalForKill() {
std::lock_guard<std::mutex> lock(mutex_);
exit_now_.store(true);
cv_.notify_all();
}
/*!
* \brief Push task into the queue.
* \param task The task to be pushed.
*/
void Push(Task task) {
std::unique_lock<std::mutex> lock(mutex_);
if (num_pending_ < ring_.size()) {
CHECK_NE(ring_.size(), 0U);
ring_[(head_ + num_pending_) % ring_.size()] = task;
++num_pending_;
} else {
size_t old_size = ring_.size();
ring_.resize(old_size * 2);
if (head_ + num_pending_ > old_size) {
// copy the ring overflow part into the tail.
size_t ncopy = head_ + num_pending_ - old_size;
memcpy(&ring_[0] + old_size, &ring_[0], ncopy * sizeof(Task));
}
ring_[(head_ + num_pending_) % ring_.size()] = task;
++num_pending_;
}
if (nwait_consumer_ != 0) {
lock.unlock();
cv_.notify_one();
}
}
/*!
* \brief Pop task from the queue
* \param task The task to be popped.
* \param timeout The number of cycles to spin before sleep.
* \return Whether pop is successful or we need to exit now.
*/
bool Pop(Task* task, int timeout = 10) {
std::unique_lock<std::mutex> lock(mutex_);
if (num_pending_ != 0) {
*task = ring_[head_];
head_ = (head_ + 1) % ring_.size();
--num_pending_;
if (exit_now_.load()) return false;
} else {
lock.unlock();
// spin and busy-wait a bit before sleeping.
for (int i = 0; i < timeout && num_pending_ == 0; ++i) {
std::this_thread::yield();
}
lock.lock();
++nwait_consumer_;
cv_.wait(lock, [this] {
return num_pending_ != 0 || exit_now_.load();
});
--nwait_consumer_;
*task = ring_[head_];
head_ = (head_ + 1) % ring_.size();
--num_pending_;
if (exit_now_.load()) return false;
}
return true;
}
private:
// Number of elements in the queue
uint32_t num_pending_{0};
// Queue head
uint32_t head_{0};
// Number of consumers to wait.
uint32_t nwait_consumer_{0};
// internal mutex
std::mutex mutex_;
// cv for consumer
std::condition_variable cv_;
// signal for exit now
std::atomic<bool> exit_now_{false};
// The internal ring.
std::vector<Task> ring_;
};
// The thread pool
class ThreadPool {
public:
ThreadPool() {
const char *val = getenv("TVM_NUM_THREADS");
if (val == nullptr) {
val = getenv("OMP_NUM_THREADS");
}
if (val != nullptr) {
num_workers_ = atoi(val);
} else {
#if defined(_M_X64) || defined(__x86_64__)
// Halve the count so hyper-threads are not counted.
num_workers_ = std::thread::hardware_concurrency() / 2;
#else
num_workers_ = std::thread::hardware_concurrency();
#endif
}
num_workers_ = std::max(num_workers_, 1);
this->Init();
}
~ThreadPool() {
for (std::unique_ptr<ParallelTaskQueue>& q : queues_) {
q->SignalForKill();
}
for (std::thread& t : threads_) {
t.join();
}
}
int Launch(FTVMParallelLambda flambda,
void* cdata,
int num_task,
int need_sync) {
ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
CHECK(!launcher->is_worker)
    << "Cannot launch parallel job inside a worker, consider fusing the parallel loops instead";
if (num_task == 0) {
num_task = num_workers_;
}
if (need_sync != 0) {
CHECK_LE(num_task, num_workers_)
    << "Requested number of sync tasks exceeds the available worker threads: "
    << "workers=" << num_workers_ << " request=" << num_task;
}
launcher->Init(flambda, cdata, num_task, need_sync != 0);
ParallelTaskQueue::Task tsk;
tsk.launcher = launcher;
for (int i = 0; i < num_task; ++i) {
tsk.task_id = i;
queues_[i]->Push(tsk);
}
return launcher->WaitForJobs();
}
static ThreadPool* Global() {
static ThreadPool inst;
return &inst;
}
private:
// Initialize the pool.
void Init() {
for (int i = 0; i < num_workers_; ++i) {
queues_.emplace_back(
std::unique_ptr<ParallelTaskQueue>(new ParallelTaskQueue()));
}
threads_.resize(num_workers_);
for (int i = 0; i < num_workers_; ++i) {
threads_[i] = std::thread([this, i] {
this->RunWorker(queues_[i].get());
});
}
}
// Internal worker function.
void RunWorker(ParallelTaskQueue* queue) {
ParallelTaskQueue::Task task;
ParallelLauncher::ThreadLocal()->is_worker = true;
while (queue->Pop(&task)) {
CHECK(task.launcher != nullptr);
TVMParallelGroupEnv* penv = &(task.launcher->env);
void* cdata = task.launcher->cdata;
if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {
task.launcher->SignalJobFinish();
} else {
task.launcher->SignalJobError(task.task_id);
}
}
}
// Number of workers
int num_workers_;
std::vector<std::unique_ptr<ParallelTaskQueue> > queues_;
std::vector<std::thread> threads_;
};
} // namespace runtime
} // namespace tvm
int TVMBackendParallelLaunch(
FTVMParallelLambda flambda,
void* cdata,
int num_task) {
return tvm::runtime::ThreadPool::Global()->Launch(
flambda, cdata, num_task, 1);
}
int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
using tvm::runtime::kSyncStride;
int num_task = penv->num_task;
std::atomic<int>* sync_counter =
reinterpret_cast<std::atomic<int>*>(penv->sync_handle);
int old_counter = sync_counter[task_id * kSyncStride].fetch_add(
1, std::memory_order_release);
for (int i = 0; i < num_task; ++i) {
if (i != task_id) {
while (sync_counter[i * kSyncStride].load(
std::memory_order_relaxed) <= old_counter) {
std::this_thread::yield();
}
}
}
std::atomic_thread_fence(std::memory_order_acquire);
return 0;
}
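The barrier above is a per-task spin barrier: each task bumps its own cache-line-padded counter with release semantics, then spins until every peer's counter passes the value it observed, finishing with an acquire fence. A hedged sketch of BSP-style use from inside a lambda (the partial buffer and its arithmetic are illustrative, not part of TVM):

#include <tvm/runtime/c_backend_api.h>

// Phase 1 writes each task's partial result; the barrier makes all phase-1
// writes visible before any task starts phase 2.
static int BspLambda(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
  double* partial = static_cast<double*>(cdata);    // one slot per task
  partial[task_id] = static_cast<double>(task_id);  // phase 1
  TVMBackendParallelBarrier(task_id, penv);
  double sum = 0.0;                                 // phase 2
  for (int i = 0; i < penv->num_task; ++i) sum += partial[i];
  return sum >= 0.0 ? 0 : -1;  // consume sum so the loop is not optimized away
}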
@@ -340,6 +340,19 @@ Stage& Stage::parallel(IterVar var) {  // NOLINT(*)
return *this;
}
Stage& Stage::pragma(IterVar var, const std::string& pragma_type) { // NOLINT(*)
if (pragma_type == "unroll") {
this->unroll(var);
} else if (pragma_type == "vectorize") {
this->vectorize(var);
} else {
UpdateIterVarAttr(operator->(), var, [pragma_type](IterVarAttrNode* n) {
n->pragmas.push_back(ir::StringImm::make(pragma_type));
});
}
return *this;
}
Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) {
StageNode *self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
...
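Note the aliasing above: "unroll" and "vectorize" are routed to the existing primitives rather than recorded as pragmas. In this hedged sketch (stage and xi assumed from an earlier split), the call shown is therefore equivalent to the one in the comment:

stage.pragma(xi, "vectorize");  // equivalent to stage.vectorize(xi)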
@@ -28,8 +28,13 @@ def test_llvm_add_pipeline():
    C = tvm.compute(A.shape, lambda *i: T(*i), name='C')
    s = tvm.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=4)
    xo1, xo2 = s[C].split(xo, factor=13)
    s[C].parallel(xo2)
    s[C].pragma(xo1, "parallel_launch_point")
    s[C].pragma(xo2, "parallel_stride_pattern")
    s[C].pragma(xo2, "parallel_barrier_when_finish")
    s[C].vectorize(xi)
    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
@@ -167,9 +172,9 @@ def test_multiple_func():
if __name__ == "__main__":
    test_llvm_add_pipeline()
    test_llvm_intrin()
    test_multiple_func()
    test_llvm_flip_pipeline()
    test_llvm_madd_pipeline()
    test_llvm_temp_space()
@@ -74,7 +74,27 @@ def test_stack_vm_cond():
        np.testing.assert_equal(a.asnumpy(), y)
    run_jit(fapi, check)

def test_vm_parallel():
    dtype = 'int64'
    n = tvm.var('n')
    Ab = tvm.decl_buffer((n, ), dtype)
    i = tvm.var('i')
    ib = tvm.ir_builder.create()
    A = ib.buffer_ptr(Ab)
    with ib.for_range(0, n, "i", for_type="parallel") as i:
        A[i] = A[i] + 1
    stmt = ib.get()
    fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
    fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
    def check(f):
        a = tvm.nd.array(np.zeros(10, dtype=dtype))
        f(a)
        np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
    run_jit(fapi, check)

if __name__ == "__main__":
    test_vm_parallel()
    test_stack_vm_loop()
    test_stack_vm_basic()
    test_stack_vm_cond()
@@ -30,6 +30,7 @@ def test_schedule_create():
    assert isinstance(s_loaded, tvm.schedule.Schedule)
    assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))

def test_reorder():
    m = tvm.var('m')
    A = tvm.placeholder((m,), name='A')
@@ -91,6 +92,21 @@ def test_vectorize():
    assert s[T].iter_var_attrs[xi].iter_type == UNROLL
    assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE

def test_pragma():
    m = 100
    A = tvm.placeholder((m,), name='A')
    T = tvm.compute((m,), lambda i: A[i])
    s = tvm.create_schedule(T.op)
    xo, xi = s[T].split(T.op.axis[0], factor=10)
    s[T].pragma(xo, "pragma1")
    s[T].pragma(xi, "vectorize")
    VECTORIZE = tvm.schedule.IterVar.Vectorized
    assert s[T].iter_var_attrs[xo].pragmas[0].value == "pragma1"
    assert s[T].iter_var_attrs[xi].iter_type == VECTORIZE

def test_rfactor():
    n = tvm.var('n')
    k1 = tvm.reduce_axis((0, n), name="k1")
@@ -141,6 +157,7 @@ def test_tensor_intrin():
if __name__ == "__main__":
    test_pragma()
    test_tensor_intrin()
    test_rfactor()
    test_schedule_create()
...
@@ -50,3 +50,16 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
  });
}  // namespace contrib
}  // namespace tvm
// dummy parallel runtime
int TVMBackendParallelLaunch(
FTVMParallelLambda flambda,
void* cdata,
int num_task) {
TVMAPISetLastError("Parallel is not supported in Web runtime");
return -1;
}
int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
return 0;
}