Commit d9bab380 by Tianqi Chen, committed by GitHub

[IR] Change pragma convention, enable pass unroll option via pragma (#1112)

* [IR] Change pragma scope convention, enable pass unroll option via pragma

* add coverage test

* add explicit unroll as option
parent a20c741b
@@ -177,8 +177,8 @@ constexpr const char* device_context_type = "device_context_type";
 constexpr const char* loop_scope = "loop_scope";
 /*! \brief Mark of reduce scope */
 constexpr const char* reduce_scope = "reduce_scope";
-/*! \brief Mark region is guarded by the pragma */
-constexpr const char* pragma_scope = "pragma_scope";
+/*! \brief Mark region is guarded by the pragma extension */
+constexpr const char* pragma_scope_prefix = "pragma_";
 /*!
  * \brief Mark of prefetch scope, value=offset,
  * run prefetch of Tensor on the current loop scope
@@ -233,6 +233,16 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
  * Store statement.
  */
 constexpr const char* opengl_stage_scope = "opengl_stage_scope";
+/*!
+ * \brief Check if attr_key is a pragma key extension
+ * \param attr_key The attr key to be compared
+ * \return true if it is a pragma key
+ */
+inline bool IsPragmaKey(const std::string& attr_key) {
+  return attr_key.compare(0, 7, "pragma_") == 0;
+}
 }  // namespace attr
 /*! \brief namespace of TVM Intrinsic functions */
...
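The header change above replaces the single pragma_scope attribute key (whose StringImm value named the pragma) with a "pragma_" key prefix: each pragma now lowers to its own attribute key, leaving the AttrStmt value free to carry the pragma's argument. A minimal Python-side sketch of the same check; the attr keys used here are only illustrations:

# Sketch of the new convention (Python equivalent of attr::IsPragmaKey).
# A schedule pragma named "foo" now lowers to an AttrStmt with key "pragma_foo",
# and the AttrStmt value holds the pragma value instead of the pragma name.
def is_pragma_key(attr_key):
    return attr_key.startswith("pragma_")

assert is_pragma_key("pragma_auto_unroll_max_step")
assert not is_pragma_key("thread_extent")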
@@ -185,10 +185,13 @@ class Stage : public NodeRef {
    *
    * \param var The axis to be parallelized.
    * \param pragma_type The pragma type.
+   * \param pragma_value The pragma value
    *
    * \return reference to self.
    */
-  EXPORT Stage& pragma(IterVar var, const std::string& pragma_type);  // NOLINT(*)
+  EXPORT Stage& pragma(IterVar var,
+                       const std::string& pragma_type,
+                       const Expr& pragma_value = Expr());  // NOLINT(*)
   /*!
    * \brief Fetch data in advance.
    * \param domain the tensor to be prefetched
@@ -539,9 +542,13 @@ class IterVarAttrNode : public Node {
   /*! \brief Alignment offset of buffer dimension */
   int dim_align_offset{0};
   /*!
-   * \brief Additional pragmas, array of StringImm
+   * \brief Additional pragma keys, array of StringImm
    */
-  Array<Expr> pragmas;
+  Array<Expr> pragma_keys;
+  /*!
+   * \brief Additional values of pragma, if any
+   */
+  Array<Expr> pragma_values;
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("iter_type", &iter_type);
@@ -551,7 +558,8 @@ class IterVarAttrNode : public Node {
     v->Visit("tensor_intrin", &tensor_intrin);
     v->Visit("dim_align_factor", &dim_align_factor);
     v->Visit("dim_align_offset", &dim_align_offset);
-    v->Visit("pragmas", &pragmas);
+    v->Visit("pragma_keys", &pragma_keys);
+    v->Visit("pragma_values", &pragma_values);
   }
   static constexpr const char* _type_key = "IterVarAttr";
...
@@ -552,7 +552,7 @@ class Stage(NodeBase):
         """
         _api_internal._StageParallel(self, var)
-    def pragma(self, var, pragma_type):
+    def pragma(self, var, pragma_type, pragma_value=None):
         """Annotate the iteration with pragma
         This will translate to a pragma_scope surrounding
@@ -567,6 +567,9 @@ class Stage(NodeBase):
         pragma_type : str
              The pragma string to be annotated
+        pragma_value : Expr, optional
+             The pragma value to pass along the pragma
         Note
         ----
         Most pragmas are advanced/experimental features
@@ -597,7 +600,7 @@ class Stage(NodeBase):
         :code:`for (int i = task_id; i < end; i += num_task)`
         """
-        _api_internal._StagePragma(self, var, pragma_type)
+        _api_internal._StagePragma(self, var, pragma_type, pragma_value)
     def prefetch(self, tensor, var, offset):
         """Prefetch the specified variable
...
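With the new optional pragma_value, a schedule can pass an option to a downstream pass directly through the pragma. A minimal usage sketch, assuming the same-era TVM Python API and made-up tensor names; when no value is given, lowering substitutes the constant 1 (see the MakeLoopNest change below):

import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
C = tvm.compute((n,), lambda i: A[i] + 1, name="C")
s = tvm.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=8)
# The lowered IR should carry an attribute like
#   pragma_auto_unroll_max_step = 16
# on the xo loop, which the unroll pass can read.
s[C].pragma(xo, "auto_unroll_max_step", 16)
print(tvm.lower(s, [A, C], simple_mode=True))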
@@ -380,7 +380,7 @@ TVM_REGISTER_API("_StageParallel")
 TVM_REGISTER_API("_StagePragma")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     args[0].operator Stage()
-        .pragma(args[1], args[2]);
+        .pragma(args[1], args[2], args[3]);
   });
 TVM_REGISTER_API("_StagePrefetch")
...
@@ -683,16 +683,15 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
     this->CreateStaticInit(op->value.as<StringImm>()->value, op->body);
   } else if (op->attr_key == ir::attr::compute_scope) {
     this->CreateComputeScope(op);
-  } else if (op->attr_key == ir::attr::pragma_scope) {
-    const std::string& pname = op->value.as<StringImm>()->value;
-    if (pname == "parallel_stride_pattern") {
+  } else if (attr::IsPragmaKey(op->attr_key)) {
+    if (op->attr_key == "pragma_parallel_stride_pattern") {
       CHECK(parallel_env_.penv != nullptr)
           << "Pragma parallel_stride_pattern only valid in parallel launch";
       parallel_env_.stride_pattern = true;
       this->VisitStmt(op->body);
-    } else if (pname == "parallel_launch_point") {
+    } else if (op->attr_key == "pragma_parallel_launch_point") {
       CreateParallelLaunch(op->body, 0);
-    } else if (pname == "parallel_barrier_when_finish") {
+    } else if (op->attr_key == "pragma_parallel_barrier_when_finish") {
       CHECK(parallel_env_.penv != nullptr)
           << "Cannot run barrier without parallel environment";
       CHECK(!parallel_env_.in_parallel_loop)
@@ -703,7 +702,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
           RuntimeTVMParallelBarrier(),
           {MakeValue(parallel_env_.task_id), parallel_env_.penv});
     } else {
-      LOG(WARNING) << "Unknown pragma " << pname;
+      LOG(WARNING) << "Unknown pragma " << op->attr_key;
       this->VisitStmt(op->body);
     }
   } else {
...
@@ -77,7 +77,7 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(
   spirv::Value v;
   if (ts.rank == 1) {
     v = builder_->GetLocalID(ts.dim_index);
-    int size;
+    int size = 0;
     CHECK(arith::GetConstInt(extent, &size))
         << "SPIRV only allows constant thread group size " << " get " << extent;
     CHECK_LT(ts.dim_index, 3);
...
@@ -71,9 +71,15 @@ MakeLoopNest(const Stage& stage,
             << it_attr->iter_type
             << " in the iter_var_attrs";
       }
-      for (Expr p : it_attr->pragmas) {
+      CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
+      for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
+        const std::string& pkey = it_attr->pragma_keys[k].as<StringImm>()->value;
+        Expr pvalue = it_attr->pragma_values[k];
+        if (!pvalue.defined()) {
+          pvalue = make_const(Int(32), 1);
+        }
         nest[i + 1].emplace_back(
-            AttrStmt::make(iv, ir::attr::pragma_scope, p, no_op));
+            AttrStmt::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
       }
     }
     if (!debug_keep_trivial_loop && is_one(dom->extent)) {
...
@@ -17,7 +17,7 @@ class CopyIntrinInjector : public IRMutator {
  public:
   CopyIntrinInjector(const std::string& pragma_key,
                      const PackedFunc& flower_copy_fromto)
-      : pragma_key_(pragma_key),
+      : pragma_key_(attr::pragma_scope_prefix+ pragma_key),
         flower_copy_fromto_(flower_copy_fromto) {
   }
@@ -25,14 +25,11 @@ class CopyIntrinInjector : public IRMutator {
     if (op->attr_key == attr::storage_scope) {
       const Variable* buf = op->node.as<Variable>();
       storage_scope_[buf] = op->value.as<StringImm>()->value;
-    } else if (op->attr_key == ir::attr::pragma_scope) {
-      const std::string& pname = op->value.as<StringImm>()->value;
-      if (pname == pragma_key_) {
-        Stmt ret;
-        CHECK(MatchCopyPattern(op->body, &ret))
-            << "Cannot match copy pattern of " << op->body;
-        return ret;
-      }
+    } else if (op->attr_key == pragma_key_) {
+      Stmt ret;
+      CHECK(MatchCopyPattern(op->body, &ret))
+          << "Cannot match copy pattern of " << op->body;
+      return ret;
     }
     return IRMutator::Mutate_(op, s);
   }
...
@@ -20,11 +20,8 @@ class NoOpRemover : public IRMutator {
     return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
   }
   Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
-    if (op->attr_key == ir::attr::pragma_scope) {
-      const std::string& pname = op->value.as<StringImm>()->value;
-      if (pname == "debug_skip_region") {
-        return MakeEvaluate(0);
-      }
+    if (op->attr_key == "pragma_debug_skip_region") {
+      return MakeEvaluate(0);
     }
     Stmt stmt = IRMutator::Mutate_(op, s);
     op = stmt.as<AttrStmt>();
...
@@ -401,7 +401,7 @@ class StoragePlanRewriter : public IRMutator {
       return this->Mutate(op->body);
     } else if (op->attr_key == attr::thread_extent ||
                op->attr_key == attr::virtual_thread ||
-               op->attr_key == attr::pragma_scope) {
+               attr::IsPragmaKey(op->attr_key)) {
       // remake all the allocation at the attach scope.
       if (attach_map_.count(op)) {
         auto& svec = attach_map_[op];
@@ -737,8 +737,8 @@ class StoragePlanRewriter : public IRMutator {
       if (s.stmt->is_type<AttrStmt>()) {
         const auto* op = static_cast<const AttrStmt*>(s.stmt);
         if (op->attr_key == attr::thread_extent ||
-            op->attr_key == attr::pragma_scope ||
-            op->attr_key == attr::virtual_thread) {
+            op->attr_key == attr::virtual_thread ||
+            attr::IsPragmaKey(op->attr_key)) {
           PlanNewScope(op);
         } else {
           CHECK(op->attr_key == attr::extern_scope);
...
@@ -27,6 +27,27 @@ class LoopUnroller : public IRMutator {
         explicit_unroll_(explicit_unroll) {
   }
+  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
+    if (op->attr_key == "pragma_auto_unroll_max_step") {
+      int value;
+      CHECK(arith::GetConstInt(op->value, &value));
+      std::swap(value, auto_max_step_);
+      Stmt ret = this->Mutate(op->body);
+      std::swap(value, auto_max_step_);
+      return ret;
+    } else if (op->attr_key == "pragma_unroll_explicit") {
+      int value;
+      CHECK(arith::GetConstInt(op->value, &value));
+      bool explicit_unroll = value;
+      std::swap(explicit_unroll, explicit_unroll_);
+      Stmt ret = this->Mutate(op->body);
+      std::swap(explicit_unroll, explicit_unroll_);
+      return ret;
+    } else {
+      return IRMutator::Mutate_(op, stmt);
+    }
+  }
   Stmt Mutate_(const For* op, const Stmt& s) {
     Stmt stmt = IRMutator::Mutate_(op, s);
     op = stmt.as<For>();
...
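The new AttrStmt mutator lets the unroll pass pick up pragma_auto_unroll_max_step and pragma_unroll_explicit from the IR itself and temporarily override the options it was constructed with for the annotated region. A small sketch of driving the pass this way, assuming the same-era TVM Python API and mirroring the pattern of the updated unit test below:

import tvm

# Build a simple inner loop first.
ib = tvm.ir_builder.create()
n = tvm.var("n")
Ab = tvm.decl_buffer((n,), "int64")
Aptr = ib.buffer_ptr(Ab)
with ib.for_range(0, 8, name="i") as i:
    Aptr[i] = Aptr[i] + 1
stmt = ib.get()

# Wrap it in the pragma attribute and run the pass with auto_max_step=0;
# the attribute should still enable auto-unrolling inside the wrapped region.
ib = tvm.ir_builder.create()
ib.scope_attr(tvm.const(0), "pragma_auto_unroll_max_step", 16)
ib.emit(stmt)
wrapped = ib.get()
print(tvm.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False))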
@@ -350,15 +350,19 @@ Stage& Stage::parallel(IterVar var) {   // NOLINT(*)
   return *this;
 }
-Stage& Stage::pragma(IterVar var, const std::string& pragma_type) {  // NOLINT(*)
+Stage& Stage::pragma(IterVar var,
+                     const std::string& pragma_type,
+                     const Expr& pragma_value) {  // NOLINT(*)
   if (pragma_type == "unroll") {
     this->unroll(var);
   } else if (pragma_type == "vectorize") {
     this->vectorize(var);
   } else {
-    UpdateIterVarAttr(operator->(), var, [pragma_type](IterVarAttrNode* n) {
-        n->pragmas.push_back(ir::StringImm::make(pragma_type));
-      });
+    UpdateIterVarAttr(
+        operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) {
+          n->pragma_keys.push_back(ir::StringImm::make(pragma_type));
+          n->pragma_values.push_back(pragma_value);
+        });
   }
   return *this;
 }
...
@@ -110,7 +110,7 @@ def test_pragma():
     s[T].pragma(xo, "pragma1")
     s[T].pragma(xi, "vectorize")
     VECTORIZE = tvm.schedule.IterVar.Vectorized
-    assert s[T].iter_var_attrs[xo].pragmas[0].value == "pragma1"
+    assert s[T].iter_var_attrs[xo].pragma_keys[0].value == "pragma1"
     assert s[T].iter_var_attrs[xi].iter_type == VECTORIZE
...
 import tvm
 import os
 def test_unroll_loop():
+    ib = tvm.ir_builder.create()
     dtype = 'int64'
     n = tvm.var('n')
     Ab = tvm.decl_buffer((n, ), dtype)
-    i = tvm.var('i')
-    j = tvm.var('j')
+    Aptr = ib.buffer_ptr(Ab)
     # for i in 0 to n-1:
-    stmt = tvm.make.For(
-        i, n, 2, 0, 0,
-        tvm.make.For(j, 0, 8, 3, 0,
-                     tvm.make.Store(Ab.data,
-                                    tvm.make.Load(dtype, Ab.data, i) + 1,
-                                    j + 1)))
+    with ib.for_range(n, n + 2, name="i") as i:
+        with ib.for_range(0, 8, name="i", for_type="unroll") as j:
+            Aptr[j + 1] = Aptr[i] + 1
+    stmt = ib.get()
     assert isinstance(stmt, tvm.stmt.For)
     ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, True)
     assert not isinstance(ret, tvm.stmt.For)
@@ -23,23 +23,18 @@ def test_unroll_loop():
     assert isinstance(ret, tvm.stmt.For)
     assert ret.for_type == tvm.stmt.For.Unrolled
+    ib = tvm.ir_builder.create()
+    ib.scope_attr(tvm.const(0), "pragma_auto_unroll_max_step", 16)
+    ib.emit(stmt)
+    wrapped = ib.get()
+    wrapped = tvm.make.Block(wrapped, stmt)
+    assert isinstance(ret, tvm.stmt.For)
+    ret = tvm.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
+    assert isinstance(ret.first, tvm.stmt.For)
+    assert ret.first.for_type == tvm.stmt.For.Unrolled
+    assert isinstance(ret.rest, tvm.stmt.For)
+    assert ret.rest.for_type != tvm.stmt.For.Unrolled
-if __name__ == "__main__":
-    with tvm.build_config(dump_pass_ir=True):
-        test_unroll_loop()
-    def end_with(*suffix):
-        ends = suffix
-        def run(s):
-            f = map(s.endswith, ends)
-            if True in f: return s
-        return run
-    file_list = os.listdir('./')
-    cc_file = end_with('.cc')
-    cc_file = filter(cc_file, file_list)
-    cc_file = [f for f in cc_file]
-    assert len(cc_file) == 3
-    for i in cc_file:
-        os.remove(i)
+if __name__ == "__main__":
+    test_unroll_loop()
@@ -20,6 +20,7 @@ def test_schedule1():
     s = tvm.create_schedule(A1.op)
     xo, xi = s[A1].split(A1.op.axis[0], 8)
+    s[A1].pragma(xo, "auto_unroll_max_step", 10)
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
...