Commit e15aae2b by Tianqi Chen, committed by GitHub

[SCHEDULE][PASS] Enable Warp memory and lower to shuffle (#1050)

* [SCHEDULE][PASS] Enable Warp memory and lower to shuffle

* OpenCL dispatches, for now, to the Intel shuffle extension
parent cc71d505
......@@ -412,6 +412,14 @@ constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered";
*/
constexpr const char* tvm_storage_sync = "tvm_storage_sync";
/*!
 * \brief Get the value passed in by another lane of the same warp.
 *  See pseudo code:
 *
 *  Type tvm_warp_shuffle(Type value, warp_id) {
 *    return (value passed in by the lane indicated by warp_id);
 *  }
 */
constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
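To make the contract concrete, here is a minimal pure-Python model of the pseudo code above (illustrative only; warp_shuffle and lane_values are made-up names, not TVM API):

    # Illustrative model: every lane in a warp contributes one value, and the
    # shuffle returns the value held by the lane selected by warp_id.
    def warp_shuffle(lane_values, warp_id):
        return lane_values[warp_id]

    # Lanes 0..3 hold [10, 11, 12, 13]; a lane shuffling with warp_id = 2 gets 12.
    assert warp_shuffle([10, 11, 12, 13], 2) == 12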
/*!
* \brief Initialize the global barrier.
* Call this at the beginning of a kernel that needs the global barrier.
*/
......
......@@ -408,6 +408,15 @@ LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope);
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
/*!
* \brief Lower warp memory in stmt.
* \param f The device function to be lowered.
* \param warp_size The size of a warp, within which no synchronization is needed.
*        This pass only takes effect when warp_size is greater than one.
* \return Transformed function.
*/
LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
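For orientation, a sketch of how this pass is driven from Python; it mirrors the test_lower_warp_mem test added in this commit (the split factors are the test's choices, not requirements):

    import tvm

    # Stage A through "warp" memory and bind the lane axis to threadIdx.x.
    A = tvm.placeholder((128,), name='A')
    B = tvm.compute((128,), lambda i: A[i] + 3, name='B')
    s = tvm.create_schedule(B.op)
    AA = s.cache_read(A, "warp", [B])
    xo, xi = s[B].split(B.op.axis[0], 32)
    xi0, xi1 = s[B].split(xi, factor=16)
    tx = tvm.thread_axis("threadIdx.x")
    s[B].bind(xi1, tx)
    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
    s[AA].compute_at(s[B], xo)
    xo, xi = s[AA].split(s[AA].op.axis[0], 16)
    s[AA].bind(xi, tx)

    # Lower, split host/device, then rewrite warp memory on the device side.
    f = tvm.lower(s, [A, B])
    fhost, fdevice = tvm.ir_pass.SplitHostDevice(f)
    fdevice = tvm.ir_pass.LowerWarpMemory(fdevice, 16)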
/*!
* \brief Lower packed function call.
* \param f The function to be lowered.
* \return Transformed function.
......
......@@ -450,6 +450,10 @@ def build(sch,
        else:
            raise ValueError("unknown function type %d" % func.func_type)

    for i, func in enumerate(fdevice):
        warp_size = target.thread_warp_size
        fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)

    if "gpu" in target.keys and not fdevice:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do bind?" % target)
......
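For reference, the warp_size fed into the pass above comes from the build target; a small sketch (the attribute value shown is the usual CUDA default, not a guarantee):

    import tvm

    # CUDA targets carry thread_warp_size as a target attribute, which is
    # what the build loop above passes to LowerWarpMemory.
    tgt = tvm.target.create("cuda")
    print(tgt.thread_warp_size)   # typically 32 on CUDA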
......@@ -125,6 +125,7 @@ REGISTER_PASS2(SplitPipeline);
REGISTER_PASS2(LiftAttrScope);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerWarpMemory);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
......
......@@ -5,6 +5,7 @@
#include <iomanip>
#include <cctype>
#include "./codegen_c.h"
#include "../pass/ir_util.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
......@@ -544,15 +545,6 @@ void CodeGenC::PrintVecBinaryOp(
}
}
inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {
  const Ramp* r = index.as<Ramp>();
  if (!r) return false;
  if (!is_one(r->stride)) return false;
  CHECK_EQ(r->lanes, lanes);
  *base = r->base;
  return true;
}
void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
  int lanes = op->type.lanes();
  // declare type.
......@@ -563,7 +555,7 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
  CHECK(is_one(op->predicate))
      << "predicated load is not supported";
  Expr base;
  if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
  if (GetRamp1Base(op->index, op->type.lanes(), &base)) {
    std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base);
    os << ref;
  } else {
......@@ -617,7 +609,7 @@ void CodeGenC::VisitStmt_(const Store* op) {
  CHECK(is_one(op->predicate))
      << "Predicated store is not supported";
  Expr base;
  if (TryGetRamp1Base(op->index, t.lanes(), &base)) {
  if (GetRamp1Base(op->index, t.lanes(), &base)) {
    std::string value = this->PrintExpr(op->value);
    this->PrintVecStore(op->buffer_var.get(), t, base, value);
  } else {
......
......@@ -49,6 +49,12 @@ struct CUDAPopcount {
  }
};

struct CUDAShuffle {
  std::string operator()(Type t, std::string name) const {
    return "__shfl";
  }
};
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);
......@@ -67,6 +73,10 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
.set_body(DispatchExtern<CUDAPopcount>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
.set_body(DispatchExtern<CUDAShuffle>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
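One hedged way to observe this rule (assuming a CUDA-enabled build) is to dump the kernel generated for a warp-memory schedule, such as the one in the try_warp_memory test below, and look for __shfl:

    # Hypothetical addition to check_device in try_warp_memory (not part of
    # this commit): after f = tvm.build(s, [A, B], "cuda"), the generated
    # kernel source should contain the __shfl intrinsic.
    dev_src = f.imported_modules[0].get_source()
    assert "__shfl" in dev_src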
......@@ -27,6 +27,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
.set_body(DispatchExtern<Direct>);
// There is no warp shuffle instruction in standard OpenCL;
// when shuffle is used, we assume Intel's subgroup shuffle extension is available.
struct IntelShuffle {
  std::string operator()(Type t, std::string name) const {
    return "intel_sub_group_shuffle";
  }
};
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
.set_body(DispatchExtern<IntelShuffle>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
......@@ -161,6 +161,23 @@ inline int GetTempAllocaAlignment(Type type, int32_t const_size) {
  }
  return align;
}
/*!
 * \brief Pattern match index to Ramp with stride == 1.
 *        This is a common pattern in contiguous memory loads.
 * \param index The index formula.
 * \param lanes The number of lanes in the ramp.
 * \param base The result base.
 * \return true if the pattern matches; the ramp base is stored in *base.
 */
inline bool GetRamp1Base(Expr index, int lanes, Expr *base) {
  const Ramp* r = index.as<Ramp>();
  if (!r) return false;
  if (!is_one(r->stride)) return false;
  CHECK_EQ(r->lanes, lanes);
  *base = r->base;
  return true;
}
} // namespace ir
} // namespace tvm
#endif // TVM_PASS_IR_UTIL_H_
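As a quick illustration of the pattern this helper matches, a sketch in this era's Python API (tvm.make.Ramp constructs the Ramp node directly):

    import tvm

    # A contiguous vector access has index Ramp(base, stride=1, lanes);
    # GetRamp1Base succeeds exactly on this shape and extracts `base`.
    base = tvm.var("base")
    matches = tvm.make.Ramp(base, 1, 4)    # stride == 1: pattern matches
    strided = tvm.make.Ramp(base, 2, 4)    # stride != 1: GetRamp1Base returns false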
......@@ -42,7 +42,16 @@ bool NeedRelax(const IterVar& iv,
  if (tag.length() == 0 || tag == "pipeline") {
    return !found_attach;
  }
  return static_cast<int>(scope.rank) <= ThreadScope::make(tag).rank;
  ThreadScope ts = ThreadScope::make(tag);
  // When warp memory is used, threadIdx.x must be set to be the warp index,
  // so warp-scope storage must always be relaxed over threadIdx.x.
  if (scope.rank == StorageRank::kWarp &&
      ts.rank == 1 &&
      ts.dim_index == 0) {
    return true;
  }
  return static_cast<int>(scope.rank) <= ts.rank;
}
// infer storage scope, if not given
......
import tvm
from tvm.contrib import nvcc
import numpy as np
import time
......@@ -155,7 +156,46 @@ def test_add():
run("uint64")
def try_warp_memory():
"""skip this in default test because it require higher arch"""
m = 128
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i] + 3, name='B')
warp_size = 32
s = tvm.create_schedule(B.op)
AA = s.cache_read(A, "warp", [B])
xo, xi = s[B].split(B.op.axis[0], warp_size * 2)
xi0, xi1 = s[B].split(xi, factor=warp_size)
tx = tvm.thread_axis("threadIdx.x")
s[B].bind(xi1, tx)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[AA].compute_at(s[B], xo)
xo, xi = s[AA].split(s[AA].op.axis[0], warp_size)
s[AA].bind(xi, tx)
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx")
return ptx
# one line to build the function.
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
f = tvm.build(s, [A, B], device)
a = tvm.nd.array((np.random.uniform(size=m) * 256).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
f(a, b)
np.testing.assert_allclose(
b.asnumpy(), a.asnumpy() + 3, rtol=1e-6)
check_device("cuda")
if __name__ == "__main__":
try_warp_memory()
test_add()
test_log_pow_llvm()
test_exp()
......
import tvm

def test_lower_warp_mem():
    m = 128
    A = tvm.placeholder((m,), name='A')
    B = tvm.compute((m,), lambda i: A[i] + 3, name='B')

    s = tvm.create_schedule(B.op)
    AA = s.cache_read(A, "warp", [B])
    xo, xi = s[B].split(B.op.axis[0], 32)
    xi0, xi1 = s[B].split(xi, factor=16)
    tx = tvm.thread_axis("threadIdx.x")
    s[B].bind(xi1, tx)
    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
    s[AA].compute_at(s[B], xo)
    xo, xi = s[AA].split(s[AA].op.axis[0], 16)
    s[AA].bind(xi, tx)

    f = tvm.lower(s, [A, B])
    fhost, fdevice = tvm.ir_pass.SplitHostDevice(f)
    fdevice = tvm.ir_pass.LowerWarpMemory(fdevice, 16)
    # After lowering, the "warp" buffer becomes a "local" allocation whose
    # extent is 32 / 16 = 2 elements per lane.
    assert(fdevice.body.body.value.value == "local")
    assert(fdevice.body.body.body.extents[0].value == 2)

if __name__ == "__main__":
    test_lower_warp_mem()
......@@ -53,6 +53,29 @@ def test_bound3():
    assert(bounds[A1.op.axis[0]].extent.value == 32)
    assert(bounds[A1.op.axis[1]].extent.value == 16)

def test_bound_warp():
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = tvm.create_schedule(A2.op)
    s[A1].set_scope("warp")
    xo, xi = s[A2].split(A2.op.axis[0], 32)
    xi0, xi1 = s[A2].split(xi, factor=16)
    tx = tvm.thread_axis("threadIdx.x")
    s[A2].bind(xi1, tx)
    s[A2].bind(xi0, tvm.thread_axis("threadIdx.y"))
    y = s[A2].op.axis[1]
    s[A1].compute_at(s[A2], y)
    xo, xi = s[A1].split(s[A1].op.axis[0], factor=16)
    s[A1].bind(xi, tx)

    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    # A1 is in warp scope, so its bound is relaxed over threadIdx.x (16 lanes).
    assert(bounds[A1.op.axis[0]].extent.value == 16)

def test_bound_scan():
    m = tvm.var("m")
    n = tvm.var("n")
......@@ -249,3 +272,4 @@ if __name__ == "__main__":
    test_bound_conv1d()
    test_bound2()
    test_gemm_bound()
    test_bound_warp()