Commit b03c3243 by Tianqi Chen (committed by GitHub)

[CODEGEN] Multiple parallel in one launch (#399)

parent ad8733ea
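
In brief: CodeGenCPU previously tracked a single hit_parallel_loop flag, so a parallel launch could contain at most one parallel loop. This commit splits that state into in_parallel_loop (set only while a parallel loop is being emitted, to reject nesting) and parallel_loop_count (incremented after each loop finishes), so one launch may now emit several parallel loops in sequence; the launch itself only checks that parallel_loop_count is nonzero. StorageRewrite also stops opening a new scope at the parallel_launch_point pragma, and a new test, test_llvm_persist_parallel, exercises the pattern end to end.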
@@ -404,7 +404,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
   std::swap(var_map_, new_vmap);
   std::swap(parallel_env_, par_env);
   std::swap(function_, f);
-  CHECK(par_env.hit_parallel_loop)
+  CHECK_NE(par_env.parallel_loop_count, 0)
       << "Cannot find parallel loop within parallel launch";
   builder_->SetInsertPoint(par_launch_end);
 }
@@ -679,7 +679,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
   } else if (pname == "parallel_barrier_when_finish") {
     CHECK(parallel_env_.penv != nullptr)
         << "Cannot run barrier without parallel environment";
-    CHECK(!parallel_env_.hit_parallel_loop)
+    CHECK(!parallel_env_.in_parallel_loop)
         << "Cannot place barrier within parallel loop as the workload may differ, "
         << "place it between parallel and parallel_launch_point";
     this->VisitStmt(op->body);
@@ -713,9 +713,9 @@ void CodeGenCPU::VisitStmt_(const For* op) {
     Type t = op->extent.type();
     Expr num_task = cast(t, parallel_env_.num_task);
     Expr task_id = cast(t, parallel_env_.task_id);
-    CHECK(!parallel_env_.hit_parallel_loop)
+    CHECK(!parallel_env_.in_parallel_loop)
         << "Nested parallel loop is not supported by threadpool, try fusing them instead";
-    parallel_env_.hit_parallel_loop = true;
+    parallel_env_.in_parallel_loop = true;
     if (parallel_env_.stride_pattern) {
       CreateSerialFor(MakeValue(task_id),
                       MakeValue(op->extent),
@@ -732,6 +732,8 @@ void CodeGenCPU::VisitStmt_(const For* op) {
                       op->loop_var,
                       op->body);
     }
+    parallel_env_.in_parallel_loop = false;
+    ++parallel_env_.parallel_loop_count;
   }
 } else {
   LOG(FATAL) << "cannot handle for type " << op->for_type;
...
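To see the execution model this enables, here is a minimal plain-Python sketch (illustrative only, not TVM API; launch_body, num_task, and the data layout are our own names) of one parallel launch whose closure runs two stride-pattern parallel loops separated by a barrier, mirroring the test added below:

    import threading

    n, num_task = 128, 4
    A = list(range(n))
    B = [0] * n
    C = [0] * n
    # Models the "parallel_barrier_when_finish" pragma.
    barrier = threading.Barrier(num_task)

    def launch_body(task_id):
        # First parallel loop: each task strides over B
        # ("parallel_stride_pattern" lowering).
        for i in range(task_id, n, num_task):
            B[i] = A[i] + 1
        # All tasks synchronize before the second loop reads B.
        barrier.wait()
        # Second parallel loop inside the same launch,
        # which this commit makes legal.
        for i in range(task_id, n, num_task):
            C[i] = B[i] + 2

    tasks = [threading.Thread(target=launch_body, args=(t,))
             for t in range(num_task)]
    for t in tasks:
        t.start()
    for t in tasks:
        t.join()
    assert C == [a + 3 for a in A]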
@@ -24,7 +24,6 @@ class CodeGenCPU : public CodeGenLLVM {
                 bool dynamic_lookup) override;
   void AddFunction(const LoweredFunc& f) override;
   void AddMainFunction(const std::string& entry_func_name) override;
-  void VisitStmt_(const AssertStmt* op) override;
   void VisitStmt_(const AttrStmt* op) override;
   void VisitStmt_(const For* op) override;
@@ -60,7 +59,8 @@ class CodeGenCPU : public CodeGenLLVM {
     VarExpr task_id;
     VarExpr num_task;
     bool stride_pattern{false};
-    bool hit_parallel_loop{false};
+    bool in_parallel_loop{false};
+    int parallel_loop_count{0};
     llvm::Value* penv{nullptr};
   };
   // Get runtime functions
...
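For context, ParallelEnv is the per-launch codegen state: task_id and num_task bind the current task's index and the task count, stride_pattern selects the strided lowering of the loop, and penv points to the runtime parallel environment. The renamed in_parallel_loop answers "are we inside a parallel loop right now?", while the new parallel_loop_count records how many parallel loops the launch has emitted so far.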
@@ -138,9 +138,6 @@ class LinearAccessPatternFinder final : public IRVisitor {
     in_thread_env_ = true;
     VisitNewScope(op);
     in_thread_env_ = false;
-  } else if (op->attr_key == attr::pragma_scope &&
-             op->value.as<StringImm>()->value == "parallel_launch_point") {
-    VisitNewScope(op);
   } else if (op->attr_key == attr::storage_scope) {
     const Variable* buf = op->node.as<Variable>();
     storage_scope_[buf] =
...
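This deletion is the pass-side counterpart of the codegen change: parallel_launch_point no longer forces a fresh linear-access scope in StorageRewrite, so an allocation used across the loops of a launch can live outside the pragma region. Presumably this is also why the assertion in test_parallel_alloc below gains one extra .body level before reaching the Allocate node.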
@@ -61,6 +61,36 @@ def test_llvm_add_pipeline():
     check_llvm()
 
+def test_llvm_persist_parallel():
+    n = 128
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
+    C = tvm.compute(A.shape, lambda *i: B(*i) + 2, name='C')
+    s = tvm.create_schedule(C.op)
+    xo, xi = s[C].split(C.op.axis[0], factor=8)
+    xo1, xo2 = s[C].split(xo, nparts=1)
+    s[B].compute_at(s[C], xo1)
+    s[B].parallel(s[B].op.axis[0])
+    s[B].pragma(s[B].op.axis[0], "parallel_barrier_when_finish")
+    s[C].parallel(xi)
+    s[C].pragma(xo1, "parallel_launch_point")
+    s[C].pragma(xi, "parallel_stride_pattern")
+
+    def check_llvm():
+        if not tvm.module.enabled("llvm"):
+            return
+        # BUILD and invoke the kernel.
+        f = tvm.build(s, [A, C], "llvm")
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        f(a, c)
+        np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 3)
+    check_llvm()
+
 def test_llvm_flip_pipeline():
     def check_llvm(nn, base):
         if not tvm.module.enabled("llvm"):
...
@@ -222,6 +252,7 @@ def test_llvm_select():
 
 if __name__ == "__main__":
+    test_llvm_persist_parallel()
     test_llvm_select()
     test_llvm_vadd_pipeline()
     test_llvm_add_pipeline()
...
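The new test_llvm_persist_parallel above is the end-to-end check for this commit: B and C are each parallel, B is computed at xo1 inside C's outer loop, the region is marked as a single parallel_launch_point, and the parallel_barrier_when_finish pragma on B ensures every element of B is written before any task reads it to compute C; the output must equal A + 3.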
@@ -121,7 +121,7 @@ def test_parallel_alloc():
             A[j] = A[j] + 2
     body = ib.get()
     body = tvm.ir_pass.StorageRewrite(body)
-    assert(isinstance(body.body.body.body, tvm.stmt.Allocate))
+    assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate))
...