Commit d9bab380 by Tianqi Chen Committed by GitHub

[IR] Change pragma convention, enable pass unroll option via pragma (#1112)

* [IR] Change pragma scope convention, enable pass unroll option via pragma

* add coverage test

* add explicit unroll as option
parent a20c741b
......@@ -177,8 +177,8 @@ constexpr const char* device_context_type = "device_context_type";
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope";
/*! \brief Mark region is guarded by the pragma */
constexpr const char* pragma_scope = "pragma_scope";
/*! \brief Mark region is guarded by the pragma extension */
constexpr const char* pragma_scope_prefix = "pragma_";
/*!
* \brief Mark of prefetch scope, value=offset,
* run prefetch of Tensor on the current loop scope
......@@ -233,6 +233,16 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
* Store statement.
*/
constexpr const char* opengl_stage_scope = "opengl_stage_scope";
/*!
 * \brief Check whether an attribute key is a pragma key extension.
 * \param attr_key The attribute key to be tested.
 * \return true if attr_key starts with the pragma prefix.
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  // A pragma attribute key always begins with the "pragma_" prefix;
  // rfind with pos=0 is the standard starts-with idiom for std::string.
  return attr_key.rfind("pragma_", 0) == 0;
}
} // namespace attr
/*! \brief namespace of TVM Intrinsic functions */
......
......@@ -185,10 +185,13 @@ class Stage : public NodeRef {
*
* \param var The axis to be parallelized.
* \param pragma_type The pragma type.
* \param pragma_value The pragma value
*
* \return reference to self.
*/
EXPORT Stage& pragma(IterVar var, const std::string& pragma_type); // NOLINT(*)
EXPORT Stage& pragma(IterVar var,
const std::string& pragma_type,
const Expr& pragma_value = Expr()); // NOLINT(*)
/*!
* \brief Fetch data in advance.
* \param domain the tensor to be prefetched
......@@ -539,9 +542,13 @@ class IterVarAttrNode : public Node {
/*! \brief Alignment offset of buffer dimension */
int dim_align_offset{0};
/*!
* \brief Additional pragmas, array of StringImm
* \brief Additional pragma keys, array of StringImm
*/
Array<Expr> pragmas;
Array<Expr> pragma_keys;
/*!
* \brief Additional values of pragma, if any
*/
Array<Expr> pragma_values;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter_type", &iter_type);
......@@ -551,7 +558,8 @@ class IterVarAttrNode : public Node {
v->Visit("tensor_intrin", &tensor_intrin);
v->Visit("dim_align_factor", &dim_align_factor);
v->Visit("dim_align_offset", &dim_align_offset);
v->Visit("pragmas", &pragmas);
v->Visit("pragma_keys", &pragma_keys);
v->Visit("pragma_values", &pragma_values);
}
static constexpr const char* _type_key = "IterVarAttr";
......
......@@ -552,7 +552,7 @@ class Stage(NodeBase):
"""
_api_internal._StageParallel(self, var)
def pragma(self, var, pragma_type):
def pragma(self, var, pragma_type, pragma_value=None):
"""Annotate the iteration with pragma
This will translate to a pragma_scope surrounding
......@@ -567,6 +567,9 @@ class Stage(NodeBase):
pragma_type : str
The pragma string to be annotated
pragma_value : Expr, optional
The pragma value to pass along the pragma
Note
----
Most pragmas are advanced/experimental features
......@@ -597,7 +600,7 @@ class Stage(NodeBase):
:code:`for (int i = task_id; i < end; i += num_task)`
"""
_api_internal._StagePragma(self, var, pragma_type)
_api_internal._StagePragma(self, var, pragma_type, pragma_value)
def prefetch(self, tensor, var, offset):
"""Prefetch the specified variable
......
......@@ -380,7 +380,7 @@ TVM_REGISTER_API("_StageParallel")
// FFI registration: backs Python's Stage.pragma(var, pragma_type, pragma_value).
// args[0] is the Stage, args[1] the IterVar, args[2] the pragma key string,
// args[3] the optional pragma value expression.
TVM_REGISTER_API("_StagePragma")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .pragma(args[1], args[2], args[3]);
  });
TVM_REGISTER_API("_StagePrefetch")
......
......@@ -683,16 +683,15 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
this->CreateStaticInit(op->value.as<StringImm>()->value, op->body);
} else if (op->attr_key == ir::attr::compute_scope) {
this->CreateComputeScope(op);
} else if (op->attr_key == ir::attr::pragma_scope) {
const std::string& pname = op->value.as<StringImm>()->value;
if (pname == "parallel_stride_pattern") {
} else if (attr::IsPragmaKey(op->attr_key)) {
if (op->attr_key == "pragma_parallel_stride_pattern") {
CHECK(parallel_env_.penv != nullptr)
<< "Pragma parallel_stride_pattern only valid in parallel launch";
parallel_env_.stride_pattern = true;
this->VisitStmt(op->body);
} else if (pname == "parallel_launch_point") {
} else if (op->attr_key == "pragma_parallel_launch_point") {
CreateParallelLaunch(op->body, 0);
} else if (pname == "parallel_barrier_when_finish") {
} else if (op->attr_key == "pragma_parallel_barrier_when_finish") {
CHECK(parallel_env_.penv != nullptr)
<< "Cannot run barrier without parallel environment";
CHECK(!parallel_env_.in_parallel_loop)
......@@ -703,7 +702,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
RuntimeTVMParallelBarrier(),
{MakeValue(parallel_env_.task_id), parallel_env_.penv});
} else {
LOG(WARNING) << "Unknown pragma " << pname;
LOG(WARNING) << "Unknown pragma " << op->attr_key;
this->VisitStmt(op->body);
}
} else {
......
......@@ -77,7 +77,7 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(
spirv::Value v;
if (ts.rank == 1) {
v = builder_->GetLocalID(ts.dim_index);
int size;
int size = 0;
CHECK(arith::GetConstInt(extent, &size))
<< "SPIRV only allows constant thread group size " << " get " << extent;
CHECK_LT(ts.dim_index, 3);
......
......@@ -71,9 +71,15 @@ MakeLoopNest(const Stage& stage,
<< it_attr->iter_type
<< " in the iter_var_attrs";
}
for (Expr p : it_attr->pragmas) {
CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
const std::string& pkey = it_attr->pragma_keys[k].as<StringImm>()->value;
Expr pvalue = it_attr->pragma_values[k];
if (!pvalue.defined()) {
pvalue = make_const(Int(32), 1);
}
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::pragma_scope, p, no_op));
AttrStmt::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
}
}
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
......
......@@ -17,7 +17,7 @@ class CopyIntrinInjector : public IRMutator {
public:
CopyIntrinInjector(const std::string& pragma_key,
const PackedFunc& flower_copy_fromto)
: pragma_key_(pragma_key),
: pragma_key_(attr::pragma_scope_prefix+ pragma_key),
flower_copy_fromto_(flower_copy_fromto) {
}
......@@ -25,14 +25,11 @@ class CopyIntrinInjector : public IRMutator {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] = op->value.as<StringImm>()->value;
} else if (op->attr_key == ir::attr::pragma_scope) {
const std::string& pname = op->value.as<StringImm>()->value;
if (pname == pragma_key_) {
Stmt ret;
CHECK(MatchCopyPattern(op->body, &ret))
<< "Cannot match copy pattern of " << op->body;
return ret;
}
} else if (op->attr_key == pragma_key_) {
Stmt ret;
CHECK(MatchCopyPattern(op->body, &ret))
<< "Cannot match copy pattern of " << op->body;
return ret;
}
return IRMutator::Mutate_(op, s);
}
......
......@@ -20,11 +20,8 @@ class NoOpRemover : public IRMutator {
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == ir::attr::pragma_scope) {
const std::string& pname = op->value.as<StringImm>()->value;
if (pname == "debug_skip_region") {
return MakeEvaluate(0);
}
if (op->attr_key == "pragma_debug_skip_region") {
return MakeEvaluate(0);
}
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<AttrStmt>();
......
......@@ -401,7 +401,7 @@ class StoragePlanRewriter : public IRMutator {
return this->Mutate(op->body);
} else if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread ||
op->attr_key == attr::pragma_scope) {
attr::IsPragmaKey(op->attr_key)) {
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
auto& svec = attach_map_[op];
......@@ -737,8 +737,8 @@ class StoragePlanRewriter : public IRMutator {
if (s.stmt->is_type<AttrStmt>()) {
const auto* op = static_cast<const AttrStmt*>(s.stmt);
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pragma_scope ||
op->attr_key == attr::virtual_thread) {
op->attr_key == attr::virtual_thread ||
attr::IsPragmaKey(op->attr_key)) {
PlanNewScope(op);
} else {
CHECK(op->attr_key == attr::extern_scope);
......
......@@ -27,6 +27,27 @@ class LoopUnroller : public IRMutator {
explicit_unroll_(explicit_unroll) {
}
// Intercept unroll-related pragma attributes and scope their settings to the
// annotated region: the option is installed before visiting the body and
// restored afterwards via the second std::swap.
Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
if (op->attr_key == "pragma_auto_unroll_max_step") {
// Pragma value must be a compile-time constant (enforced by the CHECK).
int value;
CHECK(arith::GetConstInt(op->value, &value));
// Save/restore pattern: swap in the new max step, visit, swap back.
std::swap(value, auto_max_step_);
Stmt ret = this->Mutate(op->body);
std::swap(value, auto_max_step_);
return ret;
} else if (op->attr_key == "pragma_unroll_explicit") {
int value;
CHECK(arith::GetConstInt(op->value, &value));
// Non-zero constant enables explicit unrolling for the region.
bool explicit_unroll = value;
std::swap(explicit_unroll, explicit_unroll_);
Stmt ret = this->Mutate(op->body);
std::swap(explicit_unroll, explicit_unroll_);
return ret;
} else {
// Not an unroll pragma: fall through to the default mutator.
return IRMutator::Mutate_(op, stmt);
}
}
Stmt Mutate_(const For* op, const Stmt& s) {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<For>();
......
......@@ -350,15 +350,19 @@ Stage& Stage::parallel(IterVar var) { // NOLINT(*)
return *this;
}
Stage& Stage::pragma(IterVar var, const std::string& pragma_type) { // NOLINT(*)
Stage& Stage::pragma(IterVar var,
                     const std::string& pragma_type,
                     const Expr& pragma_value) {  // NOLINT(*)
  // "vectorize" and "unroll" are aliases for the dedicated schedule primitives.
  if (pragma_type == "vectorize") {
    this->vectorize(var);
  } else if (pragma_type == "unroll") {
    this->unroll(var);
  } else {
    // Any other pragma is recorded as a key/value pair on the iter var's
    // attributes; the value may be an undefined Expr when none was given.
    auto record = [pragma_type, pragma_value](IterVarAttrNode* n) {
      n->pragma_keys.push_back(ir::StringImm::make(pragma_type));
      n->pragma_values.push_back(pragma_value);
    };
    UpdateIterVarAttr(operator->(), var, record);
  }
  return *this;
}
......
......@@ -110,7 +110,7 @@ def test_pragma():
s[T].pragma(xo, "pragma1")
s[T].pragma(xi, "vectorize")
VECTORIZE = tvm.schedule.IterVar.Vectorized
assert s[T].iter_var_attrs[xo].pragmas[0].value == "pragma1"
assert s[T].iter_var_attrs[xo].pragma_keys[0].value == "pragma1"
assert s[T].iter_var_attrs[xi].iter_type == VECTORIZE
......
import tvm
import os
def test_unroll_loop():
ib = tvm.ir_builder.create()
dtype = 'int64'
n = tvm.var('n')
Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i')
j = tvm.var('j')
Aptr = ib.buffer_ptr(Ab)
# for i in 0 to n-1:
stmt = tvm.make.For(
i, n, 2, 0, 0,
tvm.make.For(j, 0, 8, 3, 0,
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
j + 1)))
with ib.for_range(n, n + 2, name="i") as i:
with ib.for_range(0, 8, name="i", for_type="unroll") as j:
Aptr[j + 1] = Aptr[i] + 1
stmt = ib.get()
assert isinstance(stmt, tvm.stmt.For)
ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, True)
assert not isinstance(ret, tvm.stmt.For)
......@@ -23,23 +23,18 @@ def test_unroll_loop():
assert isinstance(ret, tvm.stmt.For)
assert ret.for_type == tvm.stmt.For.Unrolled
ib = tvm.ir_builder.create()
ib.scope_attr(tvm.const(0), "pragma_auto_unroll_max_step", 16)
ib.emit(stmt)
wrapped = ib.get()
wrapped = tvm.make.Block(wrapped, stmt)
assert isinstance(ret, tvm.stmt.For)
ret = tvm.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
assert isinstance(ret.first, tvm.stmt.For)
assert ret.first.for_type == tvm.stmt.For.Unrolled
assert isinstance(ret.rest, tvm.stmt.For)
assert ret.rest.for_type != tvm.stmt.For.Unrolled
if __name__ == "__main__":
with tvm.build_config(dump_pass_ir=True):
test_unroll_loop()
def end_with(*suffix):
    """Build a predicate that returns its argument when it ends with a suffix.

    Parameters
    ----------
    *suffix : str
        One or more suffixes, tested with ``str.endswith``.

    Returns
    -------
    callable
        ``run(s)`` returns ``s`` if it ends with any of the suffixes,
        otherwise ``None``.
    """
    ends = suffix

    def run(s):
        # any() with a generator short-circuits and is clearer than the
        # original `True in map(...)`, which probed (and consumed) a map object.
        if any(s.endswith(e) for e in ends):
            return s
        return None

    return run
file_list = os.listdir('./')
cc_file = end_with('.cc')
cc_file = filter(cc_file, file_list)
cc_file = [f for f in cc_file]
assert len(cc_file) == 3
for i in cc_file:
os.remove(i)
if __name__ == "__main__":
test_unroll_loop()
......@@ -20,6 +20,7 @@ def test_schedule1():
s = tvm.create_schedule(A1.op)
xo, xi = s[A1].split(A1.op.axis[0], 8)
s[A1].pragma(xo, "auto_unroll_max_step", 10)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment