Commit 46b4a914 by Tianqi Chen (committed by GitHub)

[PASS] Refactor build config, allow implicit unroll pragma (#167)

parent 86e56824
@@ -3,3 +3,5 @@ tvm.build
 .. autofunction:: tvm.lower
 .. autofunction:: tvm.build
+.. autofunction:: tvm.build_config
@@ -95,15 +95,12 @@ def test_gemm():
     s[BB].bind(ty, thread_y)
     s[BB].bind(tx, thread_x)
     s[BB].vectorize(xi)
-    max_auto_unroll_step = 8
     # correctness
     def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        f = tvm.build(s, [A, B, C], device,
-                      max_auto_unroll_step=max_auto_unroll_step)
+        f = tvm.build(s, [A, B, C], device)
         ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
         # launch the kernel.
         n, m, l = nn, nn, nn
@@ -117,7 +114,10 @@ def test_gemm():
     np.testing.assert_allclose(
         c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)

-    check_device("cuda")
+    with tvm.build_config(auto_unroll_max_step=32,
+                          auto_unroll_min_depth=0,
+                          unroll_explicit=False):
+        check_device("cuda")

 if __name__ == "__main__":
     test_gemm()
@@ -147,8 +147,7 @@ def lstm():
     def check_device(target):
         num_step = n_num_step
         flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c],
-                          target,
-                          detect_global_barrier=DETECT_GLOBAL_BARRIER)
+                          target)
         ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
         # launch the kernel.
         scan_h_np = np.zeros(
@@ -172,7 +171,12 @@ def lstm():
     tgap = time.time() - tstart
     print("Time cost=%g" % tgap)

-    check_device("cuda")
+    # set unroll_explicit for more readable code.
+    with tvm.build_config(
+            detect_global_barrier=DETECT_GLOBAL_BARRIER,
+            auto_unroll_max_step=128,
+            unroll_explicit=False):
+        check_device("cuda")

 if __name__ == "__main__":
     lstm()
@@ -15,7 +15,7 @@ from tvm.contrib import nvcc_compiler
 import numpy as np

 # Quick knobs
-TASK="rnn_matexp"
+TASK="matexp"
 USE_MANUAL_CODE = False
 PERSIST_KERNEL = True
 DETECT_GLOBAL_BARRIER = PERSIST_KERNEL
@@ -44,7 +44,6 @@ def rnn_matexp():
     n_num_step = 128
     n_num_hidden = 1152
     n_batch_size = 4
-    max_auto_unroll_step = 0
     detect_global_barrier = DETECT_GLOBAL_BARRIER
     num_step = tvm.var("num_step")
@@ -111,10 +110,12 @@ def rnn_matexp():
     s[SS].bind(tx, thread_x)

     def check_device(target):
-        f = tvm.build(s, [s_scan, Whh],
-                      target,
-                      max_auto_unroll_step=max_auto_unroll_step,
-                      detect_global_barrier=detect_global_barrier)
+        with tvm.build_config(
+                detect_global_barrier=detect_global_barrier,
+                auto_unroll_min_depth=2,
+                auto_unroll_max_step=128,
+                unroll_explicit=False):
+            f = tvm.build(s, [s_scan, Whh], target)
         ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
         # launch the kernel.
         res_np = np.zeros(
......
@@ -144,12 +144,16 @@ Stmt SplitPipeline(Stmt stmt, bool split_load);
 Stmt NarrowChannelAccess(Stmt stmt);

 /*!
- * \brief unroll the constant loops
+ * \brief Unroll the constant loops marked by unroll.
+ *  This pass also automatically attaches a pragma unroll tag to loops that meet the criteria.
+ *
  * \param stmt The statement to be unrolled.
- * \param max_auto_step The maximum step to stop performing automatic unrolling.
+ * \param auto_max_step The maximum extent up to which automatic unroll is attached.
+ * \param auto_min_depth The minimum loop nest depth before automatic unroll can start.
+ * \param explicit_unroll Whether to explicitly unroll the loop, or leave the unroll annotation to codegen.
  * \return Transformed stmt.
  */
-Stmt UnrollLoop(Stmt stmt, int max_auto_step);
+Stmt UnrollLoop(Stmt stmt, int auto_max_step, int auto_min_depth, bool explicit_unroll);
 /*!
  * \brief vectorize the constant loops
......
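For reference, the refactored pass can also be driven directly from Python through tvm.ir_pass.UnrollLoop, mirroring the unit test updated later in this commit. A minimal sketch, assuming a small toy compute (the tensors A and B and their shapes are illustrative, not part of this commit):

import tvm

A = tvm.placeholder((8,), name="A")
B = tvm.compute((8,), lambda i: A[i] + 1, name="B")
s = tvm.create_schedule(B.op)
# simple_mode=True returns the statement after lowering, before API wrapping.
stmt = tvm.lower(s, [A, B], simple_mode=True)

# Explicitly expand serial loops with extent <= 8 at any nesting depth (>= 0).
unrolled = tvm.ir_pass.UnrollLoop(stmt, 8, 0, True)
# Only tag qualifying loops as ForType::Unrolled; codegen may then emit
# "#pragma unroll" instead of expanding the loop body.
tagged = tvm.ir_pass.UnrollLoop(stmt, 8, 0, False)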
@@ -161,7 +161,7 @@ class Stage : public NodeRef {
   Stage& vectorize(IterVar var);   // NOLINT(*)
   /*!
    * \brief Unroll iteration.
-   * \param var The axis to be vectorized.
+   * \param var The axis to be unrolled.
    * \return reference to self.
    */
   Stage& unroll(IterVar var);   // NOLINT(*)
......
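The schedule-level entry point that produces these unroll-marked loops is the existing unroll primitive. A short sketch of how a loop ends up with ForType::Unrolled before the pass or the codegen sees it (the tensors here are hypothetical placeholders):

import tvm

A = tvm.placeholder((32, 8), name="A")
B = tvm.compute((32, 8), lambda i, j: A[i, j] * 2, name="B")
s = tvm.create_schedule(B.op)
# Mark the inner axis for unrolling; UnrollLoop (or the CUDA codegen, when
# unroll_explicit=False) acts on the resulting ForType::Unrolled loop.
s[B].unroll(B.op.axis[1])
print(tvm.lower(s, [A, B], simple_mode=True))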
@@ -26,4 +26,4 @@ from .intrin import *
 from .node import register_node
 from .ndarray import register_extension
 from .schedule import create_schedule
-from .build import build, lower
+from .build import build, lower, build_config
@@ -13,6 +13,77 @@ from . import collections
 from . import module
 from . import codegen

+class BuildConfig(object):
+    """Configuration scope to set a build config option.
+
+    Parameters
+    ----------
+    kwargs
+        Keyword arguments of configurations to set.
+    """
+    current = None
+    defaults = {
+        'auto_unroll_max_step': 0,
+        'auto_unroll_min_depth': 1,
+        'unroll_explicit': True,
+        'detect_global_barrier': True
+    }
+    def __init__(self, **kwargs):
+        self._old_scope = None
+        for k, _ in kwargs.items():
+            if k not in BuildConfig.defaults:
+                raise ValueError(
+                    "invalid argument %s, candidates are %s" % (k, BuildConfig.defaults.keys()))
+        self._attr = kwargs
+
+    def __getattr__(self, name):
+        if name not in self._attr:
+            return BuildConfig.defaults[name]
+        return self._attr[name]
+
+    def __enter__(self):
+        # pylint: disable=protected-access
+        self._old_scope = BuildConfig.current
+        attr = BuildConfig.current._attr.copy()
+        attr.update(self._attr)
+        self._attr = attr
+        BuildConfig.current = self
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        assert self._old_scope
+        BuildConfig.current = self._old_scope
+
+BuildConfig.current = BuildConfig()
+
+def build_config(**kwargs):
+    """Configure the build behavior by setting config variables.
+
+    Parameters
+    ----------
+    auto_unroll_max_step: int, default=0
+        Threshold of loop extent to be automatically unrolled.
+
+    auto_unroll_min_depth: int, default=1
+        The minimum loop nest level before a loop can be automatically unrolled.
+
+    unroll_explicit: bool, default=True
+        Whether to explicitly unroll the loop. If set to False, the unroll hint
+        is passed to the CodeGen phase, which may generate a pragma unroll hint.
+        Set this to False when the CodeGen supports the unroll pragma and
+        more readable generated code is desired.
+
+    detect_global_barrier: bool, default=True
+        Whether to detect the global barrier.
+
+    Returns
+    -------
+    config: BuildConfig
+        The build configuration
+    """
+    return BuildConfig(**kwargs)
+
 def get_binds(args, binds=None):
     """Internal function to get binds and arg_list given arguments.
@@ -49,12 +120,12 @@ def get_binds(args, binds=None):
             raise ValueError("args must be Tensor, Buffer or Var")
     return binds, arg_list

 def lower(sch,
           args,
           name="default_function",
           binds=None,
-          simple_mode=False,
-          max_auto_unroll_step=0):
+          simple_mode=False):
     """Lowering step before build into target.

     Parameters
@@ -76,9 +147,6 @@ def lower(sch,
         Whether only output simple and compact statement, this will skip
         LoopPartition, api wrapper generation and Unrolling.

-    max_auto_unroll_step: int, optional
-        Maximum step to perform automatic unrolling
-
     Returns
     -------
     f : LoweredFunc or Stmt
@@ -97,8 +165,12 @@ def lower(sch,
     stmt = ir_pass.VectorizeLoop(stmt)
     stmt = ir_pass.InjectVirtualThread(stmt)
     stmt = ir_pass.StorageRewrite(stmt)
-    if not simple_mode:
-        stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
+    cfg = BuildConfig.current
+    stmt = ir_pass.UnrollLoop(
+        stmt,
+        cfg.auto_unroll_max_step,
+        cfg.auto_unroll_min_depth,
+        cfg.unroll_explicit)
     stmt = ir_pass.Simplify(stmt)
     if simple_mode:
         return stmt
@@ -110,9 +182,7 @@ def build(sch,
           target="llvm",
           target_host=None,
           name="default_function",
-          binds=None,
-          max_auto_unroll_step=0,
-          detect_global_barrier=True):
+          binds=None):
     """Build a function with arguments as signature.

     Parameters
@@ -142,12 +212,6 @@ def build(sch,
         Dictionary that maps the binding of symbolic buffer to Tensor.
         By default, a new buffer is created for each tensor in the argument.

-    max_auto_unroll_step: int, optional
-        Maximum step to perform automatic unrolling
-
-    detect_global_barrier: boolean, optional
-        Whether to detect and insert the global barrier
-
     Returns
     -------
     f : Function, or pair of functions
@@ -158,8 +222,7 @@ def build(sch,
             raise ValueError("args must be given for build from schedule")
         fapi = lower(sch, args,
                      name=name,
-                     binds=binds,
-                     max_auto_unroll_step=max_auto_unroll_step)
+                     binds=binds)
     elif isinstance(sch, collections.LoweredFunc):
         if args:
             raise ValueError("args must be done when build from LoweredFunc")
@@ -167,7 +230,7 @@ def build(sch,
     else:
         raise ValueError("sch have to be Schedule or LoweredFunc")
     # device related lowering
-    if detect_global_barrier:
+    if BuildConfig.current.detect_global_barrier:
         fapi = ir_pass.StorageSync(fapi, "global")
     fapi = ir_pass.StorageSync(fapi, "shared")
     warp_size = 32 if target == "cuda" else 1
......
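Putting the new pieces together: build behavior is now configured through a scoped tvm.build_config instead of per-call keyword arguments. A minimal usage sketch (the tensors and shapes are illustrative; any schedule works the same way):

import tvm

A = tvm.placeholder((16,), name="A")
B = tvm.compute((16,), lambda i: A[i] + 1, name="B")
s = tvm.create_schedule(B.op)

# Options set here override the defaults (auto_unroll_max_step=0,
# auto_unroll_min_depth=1, unroll_explicit=True, detect_global_barrier=True)
# for every lower()/build() call made inside the with block.
with tvm.build_config(auto_unroll_max_step=32,
                      auto_unroll_min_depth=0,
                      unroll_explicit=False):
    # The extent-16 loop qualifies for auto unroll, so it comes out tagged
    # as an unrolled for loop rather than being expanded.
    print(tvm.lower(s, [A, B], simple_mode=True))

# On exit the previous scope is restored, so this call sees the defaults again.
print(tvm.lower(s, [A, B], simple_mode=True))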
@@ -51,6 +51,12 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
     *ret = PassName(args[0], args[1]); \
   }) \

+#define REGISTER_PASS3(PassName) \
+  TVM_REGISTER_API("ir_pass."#PassName) \
+  .set_body([](TVMArgs args, TVMRetValue *ret) { \
+      *ret = PassName(args[0], args[1], args[2]); \
+    }) \
+
 #define REGISTER_PASS4(PassName) \
   TVM_REGISTER_API("ir_pass."#PassName) \
   .set_body([](TVMArgs args, TVMRetValue *ret) { \
@@ -64,7 +70,7 @@ REGISTER_PASS4(Inline);
 REGISTER_PASS2(StorageFlatten);
 REGISTER_PASS1(VectorizeLoop);
 REGISTER_PASS2(ExprUseVar);
-REGISTER_PASS2(UnrollLoop);
+REGISTER_PASS4(UnrollLoop);
 REGISTER_PASS2(StorageSync);
 REGISTER_PASS4(MakeAPI);
 REGISTER_PASS1(SplitHostDevice);
......
@@ -27,10 +27,8 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
 }

 void CodeGenCUDA::VisitStmt_(const ir::For* op) {
-  int ext;
   CHECK(is_zero(op->min));
-  if (arith::GetConstInt(op->extent, &ext) &&
-      ext <= max_auto_unroll_) {
+  if (op->for_type == ir::ForType::Unrolled) {
     PrintIndent();
     stream << "#pragma unroll\n";
   }
......
@@ -36,9 +36,6 @@ class CodeGenCUDA final : public CodeGenC {
   void VisitStmt_(const Evaluate *op) final;

  private:
-  // magic number to add pragma unroll to it.
-  // used to generate code that is compact but still unrolls.
-  int max_auto_unroll_{32};
   // Whether global barrier is needed.
   bool need_global_barrier_{false};
   // Global barrier state
......
@@ -16,8 +16,12 @@ namespace ir {
 class LoopUnroller : public IRMutator {
  public:
-  explicit LoopUnroller(int max_auto_step)
-      : max_auto_step_(max_auto_step) {
+  explicit LoopUnroller(int auto_max_step,
+                        int auto_min_depth,
+                        bool explicit_unroll)
+      : auto_max_step_(auto_max_step),
+        auto_min_depth_(auto_min_depth),
+        explicit_unroll_(explicit_unroll) {
   }

   Stmt Mutate_(const For* op, const Stmt& s) {
@@ -33,15 +37,16 @@ class LoopUnroller : public IRMutator {
     if (v2 != nullptr) {
       value = static_cast<int>(v2->value);
     }
-    bool allow_unroll = (op->for_type == ForType::Serial &&
-                         value >= 0 && value <= max_auto_step_);
+    bool auto_unroll = (op->for_type == ForType::Serial &&
+                        value >= 0 && value <= auto_max_step_ &&
+                        loop_depth_ >= auto_min_depth_);
     if (op->for_type == ForType::Unrolled) {
       CHECK_GE(value, 0)
           << "Cannot unroll non-constant loop";
-      allow_unroll = true;
+      auto_unroll = true;
     }
-    if (allow_unroll) {
+    if (auto_unroll && explicit_unroll_) {
       using arith::ComputeExpr;
       if (value == 0) return Evaluate::make(0);
       Stmt body = op->body;
@@ -59,20 +64,48 @@ class LoopUnroller : public IRMutator {
           unrolled = step;
         }
       }
-      return this->Mutate(unrolled);
+      ++loop_depth_;
+      Stmt ret = this->Mutate(unrolled);
+      --loop_depth_;
+      return ret;
     } else {
-      return IRMutator::Mutate_(op, stmt);
+      ++loop_depth_;
+      Stmt ret = IRMutator::Mutate_(op, stmt);
+      if (auto_unroll) {
+        op = ret.as<For>();
+        if (op->for_type != ForType::Unrolled) {
+          ret = For::make(
+              op->loop_var, op->min, op->extent,
+              ForType::Unrolled, op->device_api, op->body);
+        }
+      }
+      --loop_depth_;
+      return ret;
     }
   }

  private:
-  int max_auto_step_;
+  // maximum number of step to perform auto unroll.
+  int auto_max_step_;
+  int auto_min_depth_;
+  bool explicit_unroll_;
+  int loop_depth_{0};
 };

-Stmt UnrollLoop(Stmt stmt, int max_auto_step) {
-  Stmt ret = LoopUnroller(max_auto_step).Mutate(stmt);
-  return ConvertSSA(ret);
+Stmt UnrollLoop(Stmt stmt,
+                int auto_max_step,
+                int auto_min_depth,
+                bool explicit_unroll) {
+  Stmt ret = LoopUnroller(
+      auto_max_step,
+      auto_min_depth,
+      explicit_unroll).Mutate(stmt);
+  if (!ret.same_as(stmt)) {
+    return ConvertSSA(ret);
+  } else {
+    return ret;
+  }
 }

 }  // namespace ir
......
@@ -58,7 +58,6 @@ def test_gemm():
     s[BB].bind(ty, thread_y)
     s[BB].bind(tx, thread_x)

-    max_auto_unroll_step = 0
     # lowering test
     s = s.normalize()
@@ -68,8 +67,7 @@ def test_gemm():
             print("skip because %s is not enabled.." % device)
             return
-        f = tvm.build(s, [A, B, C], device,
-                      max_auto_unroll_step=max_auto_unroll_step)
+        f = tvm.build(s, [A, B, C], device)
         ctx = tvm.context(device, 0)
         # launch the kernel.
         n = nn
......
@@ -14,9 +14,11 @@ def test_unroll_loop():
             tvm.make.Load(dtype, Ab.data, i) + 1,
             j + 1)))
     assert isinstance(stmt, tvm.stmt.For)
-    stmt = tvm.ir_pass.UnrollLoop(stmt, 4)
-    assert not isinstance(stmt, tvm.stmt.For)
-    print(stmt)
+    ret = tvm.ir_pass.UnrollLoop(stmt, 2, 0, True)
+    assert not isinstance(ret, tvm.stmt.For)
+    ret = tvm.ir_pass.UnrollLoop(stmt, 4, 0, False)
+    assert isinstance(ret, tvm.stmt.For)
+    assert ret.for_type == tvm.stmt.For.Unrolled

 if __name__ == "__main__":
     test_unroll_loop()