Commit 46b4a914 by Tianqi Chen Committed by GitHub

[PASS] Refactor build config, allow implicit unroll pragma (#167)

parent 86e56824
......@@ -3,3 +3,5 @@ tvm.build
.. autofunction:: tvm.lower
.. autofunction:: tvm.build
.. autofunction:: tvm.build_config
......@@ -95,15 +95,12 @@ def test_gemm():
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
s[BB].vectorize(xi)
max_auto_unroll_step = 8
# correctness
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
f = tvm.build(s, [A, B, C], device,
max_auto_unroll_step=max_auto_unroll_step)
f = tvm.build(s, [A, B, C], device)
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
# launch the kernel.
n, m, l = nn, nn, nn
......@@ -117,7 +114,10 @@ def test_gemm():
np.testing.assert_allclose(
c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)
check_device("cuda")
with tvm.build_config(auto_unroll_max_step=32,
auto_unroll_min_depth=0,
unroll_explicit=False):
check_device("cuda")
if __name__ == "__main__":
test_gemm()
......@@ -147,8 +147,7 @@ def lstm():
def check_device(target):
num_step = n_num_step
flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c],
target,
detect_global_barrier=DETECT_GLOBAL_BARRIER)
target)
ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
# launch the kernel.
scan_h_np = np.zeros(
......@@ -172,7 +171,12 @@ def lstm():
tgap = time.time() - tstart
print("Time cost=%g" % tgap)
check_device("cuda")
# set unroll_explicit=False for more readable generated code.
with tvm.build_config(
detect_global_barrier=DETECT_GLOBAL_BARRIER,
auto_unroll_max_step=128,
unroll_explicit=False):
check_device("cuda")
if __name__ == "__main__":
lstm()
......@@ -15,7 +15,7 @@ from tvm.contrib import nvcc_compiler
import numpy as np
# Quick knobs
TASK="rnn_matexp"
TASK="matexp"
USE_MANUAL_CODE = False
PERSIST_KERNEL = True
DETECT_GLOBAL_BARRIER = PERSIST_KERNEL
......@@ -44,7 +44,6 @@ def rnn_matexp():
n_num_step = 128
n_num_hidden = 1152
n_batch_size = 4
max_auto_unroll_step = 0
detect_global_barrier = DETECT_GLOBAL_BARRIER
num_step = tvm.var("num_step")
......@@ -111,10 +110,12 @@ def rnn_matexp():
s[SS].bind(tx, thread_x)
def check_device(target):
f = tvm.build(s, [s_scan, Whh],
target,
max_auto_unroll_step=max_auto_unroll_step,
detect_global_barrier=detect_global_barrier)
with tvm.build_config(
detect_global_barrier=detect_global_barrier,
auto_unroll_min_depth=2,
auto_unroll_max_step=128,
unroll_explicit=False):
f = tvm.build(s, [s_scan, Whh], target)
ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
# launch the kernel.
res_np = np.zeros(
......
......@@ -144,12 +144,16 @@ Stmt SplitPipeline(Stmt stmt, bool split_load);
Stmt NarrowChannelAccess(Stmt stmt);
/*!
* \brief unroll the constant loops
* \brief unroll the constant loop marked by unroll.
* This pass also automatically attaches the pragma unroll tag to loops that meet the criteria.
*
* \param stmt The statement to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
* \param auto_max_step The maximum loop extent up to which the automatic unroll tag is attached.
* \param auto_min_depth The minimum depth before we can start automatic unroll
* \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
* \return Transformed stmt.
*/
Stmt UnrollLoop(Stmt stmt, int max_auto_step);
Stmt UnrollLoop(Stmt stmt, int auto_max_step, int auto_min_depth, bool explicit_unroll);
/*!
* \brief vectorize the constant loops
......
......@@ -161,7 +161,7 @@ class Stage : public NodeRef {
Stage& vectorize(IterVar var); // NOLINT(*)
/*!
* \brief Unroll iteration.
* \param var The axis to be vectorized.
* \param var The axis to be unrolled.
* \return reference to self.
*/
Stage& unroll(IterVar var); // NOLINT(*)
......
......@@ -26,4 +26,4 @@ from .intrin import *
from .node import register_node
from .ndarray import register_extension
from .schedule import create_schedule
from .build import build, lower
from .build import build, lower, build_config
......@@ -13,6 +13,77 @@ from . import collections
from . import module
from . import codegen
class BuildConfig(object):
    """Configuration scope to set a build config option.

    An instance is meant to be used as a context manager: entering it makes
    it ``BuildConfig.current`` and leaving it restores the previous scope.
    Options that are not set explicitly are inherited from the enclosing
    scope on entry, and otherwise fall back to ``BuildConfig.defaults``.

    Parameters
    ----------
    kwargs
        Keyword arguments of configurations to set.

    Raises
    ------
    ValueError
        If a keyword is not one of the keys of ``BuildConfig.defaults``.
    """
    # The innermost active configuration scope.
    current = None
    # Every recognised option together with its default value.
    defaults = {
        'auto_unroll_max_step': 0,
        'auto_unroll_min_depth': 1,
        'unroll_explicit': True,
        'detect_global_barrier': True
    }

    def __init__(self, **kwargs):
        self._old_scope = None
        for k in kwargs:
            if k not in BuildConfig.defaults:
                raise ValueError(
                    "invalid argument %s, candidates are %s" % (
                        k, sorted(BuildConfig.defaults)))
        # Options given by the user; kept separate from the effective
        # options so that inherited values never leak between scopes.
        self._overrides = dict(kwargs)
        # Effective options; extended with inherited values in __enter__.
        self._attr = dict(kwargs)

    def __getattr__(self, name):
        """Return the value of option ``name``.

        Only called when regular attribute lookup fails. Unknown names raise
        AttributeError so that hasattr(), getattr() with a default and the
        copy/pickle protocols behave normally.
        """
        # Go through __dict__ so that a missing _attr (e.g. an instance
        # created without __init__) cannot re-enter __getattr__ forever.
        attr = self.__dict__.get('_attr')
        if attr is not None and name in attr:
            return attr[name]
        if name in BuildConfig.defaults:
            return BuildConfig.defaults[name]
        raise AttributeError(
            "%s has no attribute %s" % (type(self).__name__, name))

    def __enter__(self):
        # pylint: disable=protected-access
        self._old_scope = BuildConfig.current
        # Start from the enclosing scope, then apply only the options this
        # object was created with, so re-using the object under another
        # enclosing scope picks up that scope's values.
        attr = BuildConfig.current._attr.copy()
        attr.update(self._overrides)
        self._attr = attr
        BuildConfig.current = self
        return self

    def __exit__(self, ptype, value, trace):
        assert self._old_scope
        BuildConfig.current = self._old_scope


# The outermost scope: a config without overrides, i.e. all defaults.
BuildConfig.current = BuildConfig()
def build_config(**kwargs):
    """Configure the build behavior by setting config variables.

    The returned object is intended to be used in a ``with`` statement;
    the options apply to lowering and building done inside that block.

    Parameters
    ----------
    auto_unroll_max_step: int, default=0
        Threshold of loop extent to be automatically unrolled.

    auto_unroll_min_depth: int, default=1
        The minimum loop nest level before the loop can be automatically unrolled.

    unroll_explicit: bool, default=True
        Whether to explicitly unroll the loop. If set to False, the unroll
        hint is passed on to the CodeGen phase, which may generate a pragma
        unroll hint instead.
        Set this to False if the CodeGen supports the unroll pragma and
        more compact, readable generated code is wanted.

    detect_global_barrier: bool, default=True
        Whether to detect global barriers.

    Returns
    -------
    config: BuildConfig
        The build configuration

    Raises
    ------
    ValueError
        If a keyword is not a known configuration option.
    """
    return BuildConfig(**kwargs)
def get_binds(args, binds=None):
"""Internal function to get binds and arg_list given arguments.
......@@ -49,12 +120,12 @@ def get_binds(args, binds=None):
raise ValueError("args must be Tensor, Buffer or Var")
return binds, arg_list
def lower(sch,
args,
name="default_function",
binds=None,
simple_mode=False,
max_auto_unroll_step=0):
simple_mode=False):
"""Lowering step before build into target.
Parameters
......@@ -76,9 +147,6 @@ def lower(sch,
Whether only output simple and compact statement, this will skip
LoopPartition, api wrapper generation and Unrolling.
max_auto_unroll_step: int, optional
Maximum step to perform automatic unrolling
Returns
-------
f : LoweredFunc or Stmt
......@@ -97,8 +165,12 @@ def lower(sch,
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.StorageRewrite(stmt)
if not simple_mode:
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
cfg = BuildConfig.current
stmt = ir_pass.UnrollLoop(
stmt,
cfg.auto_unroll_max_step,
cfg.auto_unroll_min_depth,
cfg.unroll_explicit)
stmt = ir_pass.Simplify(stmt)
if simple_mode:
return stmt
......@@ -110,9 +182,7 @@ def build(sch,
target="llvm",
target_host=None,
name="default_function",
binds=None,
max_auto_unroll_step=0,
detect_global_barrier=True):
binds=None):
"""Build a function with arguments as signature.
Parameters
......@@ -142,12 +212,6 @@ def build(sch,
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
max_auto_unroll_step: int, optional
Maximum step to perform automatic unrolling
detect_global_barrier: boolean, optional
Whether detect and inser global barrier
Returns
-------
f : Function, or pair of functions
......@@ -158,8 +222,7 @@ def build(sch,
raise ValueError("args must be given for build from schedule")
fapi = lower(sch, args,
name=name,
binds=binds,
max_auto_unroll_step=max_auto_unroll_step)
binds=binds)
elif isinstance(sch, collections.LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc")
......@@ -167,7 +230,7 @@ def build(sch,
else:
raise ValueError("sch have to be Schedule or LoweredFunc")
# device related lowering
if detect_global_barrier:
if BuildConfig.current.detect_global_barrier:
fapi = ir_pass.StorageSync(fapi, "global")
fapi = ir_pass.StorageSync(fapi, "shared")
warp_size = 32 if target == "cuda" else 1
......
......@@ -51,6 +51,12 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
*ret = PassName(args[0], args[1]); \
}) \
#define REGISTER_PASS3(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1], args[2]); \
}) \
#define REGISTER_PASS4(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
......@@ -64,7 +70,7 @@ REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS2(ExprUseVar);
REGISTER_PASS2(UnrollLoop);
REGISTER_PASS4(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
......
......@@ -27,10 +27,8 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
}
void CodeGenCUDA::VisitStmt_(const ir::For* op) {
int ext;
CHECK(is_zero(op->min));
if (arith::GetConstInt(op->extent, &ext) &&
ext <= max_auto_unroll_) {
if (op->for_type == ir::ForType::Unrolled) {
PrintIndent();
stream << "#pragma unroll\n";
}
......
......@@ -36,9 +36,6 @@ class CodeGenCUDA final : public CodeGenC {
void VisitStmt_(const Evaluate *op) final;
private:
// magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls.
int max_auto_unroll_{32};
// Whether global barrier is needed.
bool need_global_barrier_{false};
// Global barrier state
......
......@@ -16,8 +16,12 @@ namespace ir {
class LoopUnroller : public IRMutator {
public:
explicit LoopUnroller(int max_auto_step)
: max_auto_step_(max_auto_step) {
explicit LoopUnroller(int auto_max_step,
int auto_min_depth,
bool explicit_unroll)
: auto_max_step_(auto_max_step),
auto_min_depth_(auto_min_depth),
explicit_unroll_(explicit_unroll) {
}
Stmt Mutate_(const For* op, const Stmt& s) {
......@@ -33,15 +37,16 @@ class LoopUnroller : public IRMutator {
if (v2 != nullptr) {
value = static_cast<int>(v2->value);
}
bool allow_unroll = (op->for_type == ForType::Serial &&
value >= 0 && value <= max_auto_step_);
bool auto_unroll = (op->for_type == ForType::Serial &&
value >= 0 && value <= auto_max_step_ &&
loop_depth_ >= auto_min_depth_);
if (op->for_type == ForType::Unrolled) {
CHECK_GE(value, 0)
<< "Cannot unroll non-constant loop";
allow_unroll = true;
auto_unroll = true;
}
if (allow_unroll) {
if (auto_unroll && explicit_unroll_) {
using arith::ComputeExpr;
if (value == 0) return Evaluate::make(0);
Stmt body = op->body;
......@@ -59,20 +64,48 @@ class LoopUnroller : public IRMutator {
unrolled = step;
}
}
return this->Mutate(unrolled);
++loop_depth_;
Stmt ret = this->Mutate(unrolled);
--loop_depth_;
return ret;
} else {
return IRMutator::Mutate_(op, stmt);
++loop_depth_;
Stmt ret = IRMutator::Mutate_(op, stmt);
if (auto_unroll) {
op = ret.as<For>();
if (op->for_type != ForType::Unrolled) {
ret = For::make(
op->loop_var, op->min, op->extent,
ForType::Unrolled, op->device_api, op->body);
}
}
--loop_depth_;
return ret;
}
}
private:
int max_auto_step_;
// maximum number of step to perform auto unroll.
int auto_max_step_;
int auto_min_depth_;
bool explicit_unroll_;
int loop_depth_{0};
};
Stmt UnrollLoop(Stmt stmt, int max_auto_step) {
Stmt ret = LoopUnroller(max_auto_step).Mutate(stmt);
return ConvertSSA(ret);
// Run LoopUnroller over stmt with the given auto-unroll settings.
// ConvertSSA is applied only when the unroller actually changed the
// statement; an untouched statement is returned as is.
Stmt UnrollLoop(Stmt stmt,
                int auto_max_step,
                int auto_min_depth,
                bool explicit_unroll) {
  LoopUnroller unroller(auto_max_step, auto_min_depth, explicit_unroll);
  Stmt unrolled = unroller.Mutate(stmt);
  if (unrolled.same_as(stmt)) {
    return unrolled;
  }
  return ConvertSSA(unrolled);
}
} // namespace ir
......
......@@ -58,7 +58,6 @@ def test_gemm():
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
max_auto_unroll_step = 0
# lowering test
s = s.normalize()
......@@ -68,8 +67,7 @@ def test_gemm():
print("skip because %s is not enabled.." % device)
return
f = tvm.build(s, [A, B, C], device,
max_auto_unroll_step=max_auto_unroll_step)
f = tvm.build(s, [A, B, C], device)
ctx = tvm.context(device, 0)
# launch the kernel.
n = nn
......
......@@ -14,9 +14,11 @@ def test_unroll_loop():
tvm.make.Load(dtype, Ab.data, i) + 1,
j + 1)))
assert isinstance(stmt, tvm.stmt.For)
stmt = tvm.ir_pass.UnrollLoop(stmt, 4)
assert not isinstance(stmt, tvm.stmt.For)
print(stmt)
ret = tvm.ir_pass.UnrollLoop(stmt, 2, 0, True)
assert not isinstance(ret, tvm.stmt.For)
ret = tvm.ir_pass.UnrollLoop(stmt, 4, 0, False)
assert isinstance(ret, tvm.stmt.For)
assert ret.for_type == tvm.stmt.For.Unrolled
if __name__ == "__main__":
test_unroll_loop()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment