Commit bf8a5c07 by Tianqi Chen Committed by GitHub

[SCHEDULE] Add store_predicate (#131)

parent 2112a1f9
...@@ -76,7 +76,7 @@ else ...@@ -76,7 +76,7 @@ else
endif endif
# llvm configuration # llvm configuration
ifeq ($(USE_LLVM), 1) ifdef LLVM_CONFIG
LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3) LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3)
LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags)) LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags))
LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs) LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs)
......
...@@ -82,10 +82,22 @@ class Stage : public NodeRef { ...@@ -82,10 +82,22 @@ class Stage : public NodeRef {
*/ */
Stage& bind(IterVar ivar, IterVar thread_ivar); Stage& bind(IterVar ivar, IterVar thread_ivar);
/*! /*!
* \brief Set predicate under which store to the array can be performed.
* Use this when there are duplicated threads doing the same store and we only
* need one of them to do the store.
*
* \note This is a dangerous scheduling primitive that can change behavior of program.
* Only do when we are certain that thare are duplicated store.
* \param predicate The condition to be checked.
* \return reference to self.
*/
Stage& set_store_predicate(Expr predicate);
/*!
* \brief Specify environment threads that launched around the group's scope. * \brief Specify environment threads that launched around the group's scope.
* This can only be used in group stage. * This can only be used in group stage.
* \param threads The threads to be launched around the scope. * \param threads The threads to be launched around the scope.
* \note Each thread can only appear in one env_threads. * \note Each thread can only appear in one env_threads.
* This is a beta feature.
* \return reference to self. * \return reference to self.
*/ */
Stage& env_threads(Array<IterVar> threads); Stage& env_threads(Array<IterVar> threads);
...@@ -341,8 +353,15 @@ class StageNode : public Node { ...@@ -341,8 +353,15 @@ class StageNode : public Node {
/*! /*!
* \brief Specify threads to be launched at the stage. * \brief Specify threads to be launched at the stage.
* This is only valid for composite ops such as Scan. * This is only valid for composite ops such as Scan.
* \note Experimental primitive: used for thread persistence.
*/ */
Array<IterVar> env_threads; Array<IterVar> env_threads;
/*!
* \brief The predicate under which store can happen
* Use this when there can be duplicated threads doing the same store.
* \note Experimental primitive: used by cross thread-reduction.
*/
Expr store_predicate;
/*! \brief The relation bwteen of IterVars */ /*! \brief The relation bwteen of IterVars */
Array<IterVarRelation> relations; Array<IterVarRelation> relations;
/*! \brief additional attributes about iter var. */ /*! \brief additional attributes about iter var. */
......
...@@ -39,9 +39,9 @@ USE_METAL = 0 ...@@ -39,9 +39,9 @@ USE_METAL = 0
# whether build with LLVM support # whether build with LLVM support
# Requires LLVM version >= 4.0 # Requires LLVM version >= 4.0
# Set LLVM_CONFIG to your version # Set LLVM_CONFIG to your version, uncomment to build with llvm support
# LLVM_CONFIG = llvm-config-4.0 #
USE_LLVM = 0 # LLVM_CONFIG = llvm-config
#--------------------------------------------- #---------------------------------------------
# Contrib optional libraries. # Contrib optional libraries.
......
...@@ -85,7 +85,7 @@ def build(sch, ...@@ -85,7 +85,7 @@ def build(sch,
target_host=None, target_host=None,
name="default_function", name="default_function",
binds=None, binds=None,
max_auto_unroll_step=8, max_auto_unroll_step=0,
detect_global_barrier=True): detect_global_barrier=True):
"""Build a function with arguments as signiture. """Build a function with arguments as signiture.
......
...@@ -63,7 +63,7 @@ class ExprOp(object): ...@@ -63,7 +63,7 @@ class ExprOp(object):
return _make.LE(self, other) return _make.LE(self, other)
def __eq__(self, other): def __eq__(self, other):
return _make.EQ(self, other) return self.equal(other)
def __ne__(self, other): def __ne__(self, other):
return _make.NE(self, other) return _make.NE(self, other)
...@@ -74,6 +74,21 @@ class ExprOp(object): ...@@ -74,6 +74,21 @@ class ExprOp(object):
def __ge__(self, other): def __ge__(self, other):
return _make.GE(self, other) return _make.GE(self, other)
def equal(self, other):
"""Build an equal check expression with other expr.
Parameters
----------
other : Expr
The other expression
Returns
-------
ret : Expr
The equality expression.
"""
return _make.EQ(self, other)
class Expr(NodeBase, ExprOp): class Expr(NodeBase, ExprOp):
"""Base class of all tvm Expressions""" """Base class of all tvm Expressions"""
......
...@@ -276,6 +276,19 @@ class Stage(NodeBase): ...@@ -276,6 +276,19 @@ class Stage(NodeBase):
threads = [threads] threads = [threads]
_api_internal._StageEnvThreads(self, threads) _api_internal._StageEnvThreads(self, threads)
def set_store_predicate(self, predicate):
"""Set predicate under which store to the array can be performed.
Use this when there are duplicated threads doing the same store and we only
need one of them to do the store.
Parameters
----------
predicate : Expr
The guard condition fo store.
"""
_api_internal._StageSetStorePredicate(self, predicate)
def compute_at(self, parent, scope): def compute_at(self, parent, scope):
"""Attach the stage at parent's scope """Attach the stage at parent's scope
......
...@@ -307,6 +307,12 @@ TVM_REGISTER_API("_StageEnvThreads") ...@@ -307,6 +307,12 @@ TVM_REGISTER_API("_StageEnvThreads")
.env_threads(args[1]); .env_threads(args[1]);
}); });
TVM_REGISTER_API("_StageSetStorePredicate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.set_store_predicate(args[1]);
});
TVM_REGISTER_API("_StageUnroll") TVM_REGISTER_API("_StageUnroll")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
......
...@@ -38,7 +38,7 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -38,7 +38,7 @@ class CodeGenCUDA final : public CodeGenC {
private: private:
// magic number to add pragma unroll to it. // magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls. // used to generate code that is compact but still unrolls.
int max_auto_unroll_{64}; int max_auto_unroll_{256};
// Whether global barrier is needed. // Whether global barrier is needed.
bool need_global_barrier_{false}; bool need_global_barrier_{false};
// Global barrier state // Global barrier state
......
...@@ -9,13 +9,13 @@ namespace tvm { ...@@ -9,13 +9,13 @@ namespace tvm {
namespace codegen { namespace codegen {
namespace intrin { namespace intrin {
// Add float suffix to the intrinsics, CUDA fast math. // Add float suffix to the intrinsics, CUDA fast math.
struct CUDAFastMath { struct CUDAMath {
std::string operator()(Type t, std::string name) const { std::string operator()(Type t, std::string name) const {
if (t.lanes() == 1) { if (t.lanes() == 1) {
if (t.is_float()) { if (t.is_float()) {
switch (t.bits()) { switch (t.bits()) {
case 64: return name; case 64: return name;
case 32: return "__" + name + 'f'; case 32: return name + 'f';
case 16: return 'h' + name; case 16: return 'h' + name;
default: return ""; default: return "";
} }
...@@ -25,6 +25,17 @@ struct CUDAFastMath { ...@@ -25,6 +25,17 @@ struct CUDAFastMath {
} }
}; };
struct CUDAFastMath : public CUDAMath {
std::string operator()(Type t, std::string name) const {
if (t.lanes() == 1 && t.is_float() && t.bits() == 32) {
return "__" + name + 'f';
} else {
return CUDAMath::operator()(t, name);
}
return "";
}
};
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>); .set_body(DispatchExtern<CUDAFastMath>);
...@@ -32,7 +43,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log") ...@@ -32,7 +43,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.set_body(DispatchExtern<CUDAFastMath>); .set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh")
.set_body(DispatchExtern<CUDAFastMath>); .set_body(DispatchExtern<CUDAMath>);
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
......
...@@ -242,7 +242,6 @@ Stmt MakeCrossThreadReduction( ...@@ -242,7 +242,6 @@ Stmt MakeCrossThreadReduction(
freduce_args.push_back(reduce->source); freduce_args.push_back(reduce->source);
freduce_args.push_back(cond); freduce_args.push_back(cond);
std::vector<Expr> thread_head_check;
for (IterVar iv : stage->leaf_iter_vars) { for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) { if (iv->iter_type == kCommReduce) {
auto it = stage->iter_var_attrs.find(iv); auto it = stage->iter_var_attrs.find(iv);
...@@ -250,10 +249,14 @@ Stmt MakeCrossThreadReduction( ...@@ -250,10 +249,14 @@ Stmt MakeCrossThreadReduction(
(*it).second->bind_thread.defined()) { (*it).second->bind_thread.defined()) {
IterVar tv = (*it).second->bind_thread; IterVar tv = (*it).second->bind_thread;
freduce_args.push_back(tv->var); freduce_args.push_back(tv->var);
thread_head_check.push_back(tv->var == 0);
} }
} }
} }
// Checks for the thread.
std::vector<Expr> thread_head_check;
if (stage->store_predicate.defined()) {
thread_head_check.emplace_back(stage->store_predicate);
}
Type t = reduce->type; Type t = reduce->type;
Expr pred = const_true(t.lanes()); Expr pred = const_true(t.lanes());
Stmt reduce_body = Store::make(res_handle, Stmt reduce_body = Store::make(res_handle,
...@@ -311,6 +314,9 @@ Stmt ComputeOpNode::BuildProvide( ...@@ -311,6 +314,9 @@ Stmt ComputeOpNode::BuildProvide(
nest.push_back(op::MakeIfNest(op::MakeBoundCheck( nest.push_back(op::MakeIfNest(op::MakeBoundCheck(
stage, dom_map, false, stage, dom_map, false,
std::unordered_set<IterVar>(), value_map))); std::unordered_set<IterVar>(), value_map)));
if (stage->store_predicate.defined()) {
nest.emplace_back(op::MakeIfNest({stage->store_predicate}));
}
provide = Substitute(provide, value_map); provide = Substitute(provide, value_map);
if (init.defined()) { if (init.defined()) {
......
...@@ -200,6 +200,12 @@ Stage& Stage::env_threads(Array<IterVar> threads) { ...@@ -200,6 +200,12 @@ Stage& Stage::env_threads(Array<IterVar> threads) {
return *this; return *this;
} }
Stage& Stage::set_store_predicate(Expr predicate) {
StageNode* self = operator->();
self->store_predicate = predicate;
return *this;
}
Stage& Stage::split( Stage& Stage::split(
IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
Split(operator->(), parent, factor, Expr(), p_outer, p_inner); Split(operator->(), parent, factor, Expr(), p_outer, p_inner);
......
...@@ -98,8 +98,10 @@ def test_rfactor_threads(): ...@@ -98,8 +98,10 @@ def test_rfactor_threads():
s[B].bind(bx, tvm.thread_axis("blockIdx.x")) s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(ty, tvm.thread_axis("threadIdx.y")) s[B].bind(ty, tvm.thread_axis("threadIdx.y"))
tx = s[B].op.reduce_axis[0] tx = s[B].op.reduce_axis[0]
s[B].bind(tx, tvm.thread_axis("threadIdx.x")) thread_x = tvm.thread_axis("threadIdx.x")
s[B].bind(tx, thread_x)
s[BF].compute_at(s[B], tx) s[BF].compute_at(s[B], tx)
s[B].set_store_predicate(thread_x.var.equal(0))
# one line to build the function. # one line to build the function.
def check_target(device, host="stackvm"): def check_target(device, host="stackvm"):
......
"""LSTM Example, still work in progress.."""
import tvm
import time
import os
import argparse
from tvm.contrib import nvcc_compiler
import numpy as np
# Quick knobs
TASK="lstm"
USE_MANUAL_CODE = False
PERSIST_KERNEL = True
DETECT_GLOBAL_BARRIER = PERSIST_KERNEL
SKIP_CHECK = False
UNROLL_WLOAD = True
@tvm.register_func
def tvm_callback_cuda_compile(code):
"""Use nvcc compiler for better perf."""
ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"])
return ptx
def write_code(code, fname):
with open(fname, "w") as f:
f.write(code)
@tvm.register_func
def tvm_callback_cuda_postproc(code):
if not os.path.exists("perf"):
os.mkdir("perf")
write_code(code, "perf/%s_generated.cu" % TASK)
if USE_MANUAL_CODE:
code = open("perf/%s_manual.cu" % TASK).read()
return code
def lstm():
if not PERSIST_KERNEL:
raise ValueError("Non persist LSTM not yet supported")
detect_global_barrier = DETECT_GLOBAL_BARRIER
num_thread_y = 8
num_thread_x = 16 * 3 / 2
num_sm = 24
n_num_step = 128
num_step = tvm.var('num_step')
num_hidden = 1152 / 2
batch_size = 1
# Global transition matrix
# Input hidden channel can be pre-caculated by a gemm
Xi2h = tvm.placeholder((num_step, batch_size, 4, num_hidden), name="Xi2h")
# Only handle hidden transition, saves space.
Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
# h: output hidden state, c: cell state.
s_state_h = tvm.placeholder((num_step, batch_size, num_hidden))
s_state_c = tvm.placeholder((num_step, batch_size, num_hidden))
s_init_c = tvm.compute((1, batch_size, num_hidden),
lambda *i: 0.0, name="init_c")
s_init_h = tvm.compute((1, batch_size, num_hidden),
lambda *i: 0.0, name="init_h")
# LSTM transition
k = tvm.reduce_axis((0, num_hidden), name="ki2h")
s_h2h = tvm.compute(
(num_step, batch_size, 4, num_hidden),
lambda t, i, x, j: tvm.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k], axis=k),
name="s_h2h")
# Gate rules
gates = tvm.compute(Xi2h.shape, lambda *i:
Xi2h(*i) + s_h2h(*i), name="gates")
gshape = (num_step, batch_size, num_hidden)
in_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 0, j]), name="in_gate")
in_transform = tvm.compute(gshape, lambda t, i, j: tvm.tanh(gates[t, i, 1, j]), name="in_transform")
forget_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 2, j]), name="forget_gate")
out_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 3, j]), name="out_gate")
next_c = tvm.compute(gshape,
lambda t, i, j:
forget_gate[t, i, j] * s_state_c[t - 1, i, j] +
in_gate[t, i, j] * in_transform[t, i, j], name="next_c")
next_h = tvm.compute(gshape,
lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]), name="next_h")
update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")
# schedule
scan_h, scan_c = tvm.scan(
[s_init_h, s_init_c],
[update_h, update_c],
[s_state_h, s_state_c],
inputs=[Xi2h],
name="lstm_scan")
# schedule
s = tvm.create_schedule(scan_h.op)
# Inline gate computations
s[gates].compute_inline()
s[in_gate].compute_inline()
s[in_transform].compute_inline()
s[forget_gate].compute_inline()
s[out_gate].compute_inline()
block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
s_state_h_S = s.cache_read(s_state_h, "shared", [s_h2h])
s_state_c_S = s.cache_read(s_state_c, "shared", [next_c])
Wh2hL = s.cache_read(Wh2h, "local", [s_h2h])
ko, ki = s[s_h2h].split(s[s_h2h].op.reduce_axis[0], nparts=num_thread_y)
s_h2h_rf = s.rfactor(s_h2h, ko)
s[s_h2h].bind(s[s_h2h].op.reduce_axis[0], thread_y)
s[s_h2h_rf].compute_at(s[s_h2h], s[s_h2h].op.reduce_axis[0])
if PERSIST_KERNEL:
s[scan_h.op].env_threads([block_x, thread_y, thread_x])
s[Wh2hL].compute_at(s[scan_h.op], thread_x)
else:
s[Wh2hL].compute_at(s[s_h2h], s[s_h2h].op.axis[3])
if UNROLL_WLOAD:
s[Wh2hL].unroll(Wh2hL.op.axis[0])
s[Wh2hL].unroll(Wh2hL.op.axis[2])
s[s_state_h_S].compute_at(s[s_h2h_rf], s[s_h2h_rf].op.axis[3])
s[s_state_c_S].compute_at(s[scan_h.op], s[scan_h].op.scan_axis)
for ss in [s_state_h_S]:
xo, xi = s[ss].split(ss.op.axis[2], factor=num_thread_x * num_thread_y)
ty, xi = s[ss].split(xi, nparts=num_thread_y)
tx, xi = s[ss].split(xi, nparts=num_thread_x)
s[ss].bind(ty, thread_y)
s[ss].bind(tx, thread_x)
for init in [s_init_c, s_init_h]:
bx, xi = s[init].split(init.op.axis[2], nparts=num_sm)
tx, xi = s[init].split(xi, nparts=num_thread_x)
s[init].bind(bx, block_x)
s[init].bind(tx, thread_x)
s[next_c].set_store_predicate(thread_y.equal(0))
s[next_h].set_store_predicate(thread_y.equal(0))
for update in [update_c, update_h]:
bx, xi = s[update].split(s[update].op.axis[2], nparts=num_sm)
tx, xi = s[update].split(xi, nparts=num_thread_x)
s[update].bind(bx, block_x)
s[update].bind(tx, thread_x)
s[update].set_store_predicate(thread_y.equal(0))
# verify we can lower correctly
def check_device(target):
num_step = n_num_step
flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c],
target,
detect_global_barrier=DETECT_GLOBAL_BARRIER)
ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
# launch the kernel.
scan_h_np = np.zeros(
(num_step, batch_size, num_hidden)).astype("float32")
scan_c_np = np.zeros(
(num_step, batch_size, num_hidden)).astype("float32")
Xi2h_np = np.random.normal(
size=(num_step, batch_size, 4, num_hidden)).astype("float32")
Wh2h_np = np.random.normal(
size=(4, num_hidden, num_hidden)).astype("float32")
scan_h_a = tvm.nd.array(scan_h_np, ctx)
scan_c_a = tvm.nd.array(scan_c_np, ctx)
Xi2h_a = tvm.nd.array(Xi2h_np, ctx)
Wh2h_a = tvm.nd.array(Wh2h_np, ctx)
flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
ctx.sync()
# measure time cost of second step.
tstart = time.time()
flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
ctx.sync()
tgap = time.time() - tstart
print("Time cost=%g" % tgap)
check_device("cuda")
if __name__ == "__main__":
lstm()
...@@ -90,6 +90,8 @@ def rnn_matexp(): ...@@ -90,6 +90,8 @@ def rnn_matexp():
s[s_update].bind(tx, thread_x) s[s_update].bind(tx, thread_x)
s[CL].bind(s[CL].op.reduce_axis[0], thread_y) s[CL].bind(s[CL].op.reduce_axis[0], thread_y)
s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0]) s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0])
# Duplicate store predicate.
s[CL].set_store_predicate(thread_y.equal(0))
if PERSIST_KERNEL: if PERSIST_KERNEL:
s[WhhL].compute_at(s[s_scan], thread_x) s[WhhL].compute_at(s[s_scan], thread_x)
...@@ -109,7 +111,6 @@ def rnn_matexp(): ...@@ -109,7 +111,6 @@ def rnn_matexp():
s[SS].bind(tx, thread_x) s[SS].bind(tx, thread_x)
def check_device(target): def check_device(target):
codes = []
f = tvm.build(s, [s_scan, Whh], f = tvm.build(s, [s_scan, Whh],
target, target,
max_auto_unroll_step=max_auto_unroll_step, max_auto_unroll_step=max_auto_unroll_step,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment