Unverified Commit 4b5f324a by Tang, Shizhi Committed by GitHub

[TIR] Make lower_warp_memory support extent(threadIdx.x) < warp_size (#5307)

* support extent(threadIdx.x) < warp_size in lower_warp_memory

* more docs for lower_warp_memory
parent 49d304fc
@@ -1228,9 +1228,17 @@ constexpr const char* tvm_storage_sync = "tvm_storage_sync";
 /*!
  * \brief See pseudo code
  *
- *  Type tvm_warp_shuffle(Type value, warp_id) {
+ *  Type tvm_warp_shuffle(Type value, warp_id, width, warp_size) {
  *    return (value passed in by warp indicated by warp_id);
  *  }
+ *
+ * Parameter warp_id indicates the source thread ID in a warp.
+ *
+ * Parameter width indicates the number of threads involved in one
+ * shuffle. See CUDA document for __shfl.
+ *
+ * Parameter warp_size is the size of a warp, which helps a backend
+ * to determine whether the width parameter is legal.
  */
 constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
 /*!
...
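To make the extended intrinsic concrete: each contiguous group of `width` lanes shuffles independently, mirroring CUDA's __shfl with an explicit width argument. A minimal Python sketch of this behavior (illustration only, not TVM code; the helper name is hypothetical):

# Simulates tvm_warp_shuffle(value, warp_id, width, warp_size) across a
# whole warp; lane_values holds one value per lane.
def simulate_warp_shuffle(lane_values, warp_id, width, warp_size):
    assert len(lane_values) == warp_size
    assert 0 < width <= warp_size and warp_size % width == 0
    out = []
    for lane in range(warp_size):
        base = (lane // width) * width  # first lane of this lane's sub-warp
        out.append(lane_values[base + warp_id % width])
    return out

# With width=4 on an 8-lane warp, each sub-warp reads from its own lane 1:
print(simulate_warp_shuffle(list(range(8)), warp_id=1, width=4, warp_size=8))
# -> [1, 1, 1, 1, 5, 5, 5, 5]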
@@ -81,11 +81,15 @@ struct CUDAPopcount {
   }
 };
 
-struct CUDAShuffle {
-  std::string operator()(DataType t, std::string name) const {
-    return "__shfl";
-  }
-};
+static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
+  PrimExpr e = args[0];
+  const CallNode* call = e.as<CallNode>();
+  CHECK(call != nullptr);
+  CHECK_EQ(call->args.size(), 4);  // value, warp_id, width, warp_size
+  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
+  *rv = CallNode::make(
+      call->dtype, "__shfl", cuda_args, CallNode::PureExtern);
+}
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
 .set_body(DispatchExtern<CUDAMath>);
@@ -154,7 +158,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
 .set_body(DispatchExtern<CUDAPopcount>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
-.set_body(DispatchExtern<CUDAShuffle>);
+.set_body(DispatchCUDAShuffle);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
 .set_body(DispatchExtern<CUDAMath>);
...
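Note that the CUDA rule drops warp_size when forwarding to the extern call, because __shfl(value, srcLane, width) carries the width itself; warp_size is only needed by backends that must validate the width. In Python terms (a hypothetical sketch, not TVM API):

# CUDA's __shfl needs just the first three intrinsic arguments.
def cuda_shuffle_args(value, warp_id, width, warp_size):
    return (value, warp_id, width)  # warp_size is validation-only

assert cuda_shuffle_args("v", 1, 16, 32) == ("v", 1, 16)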
@@ -21,6 +21,7 @@
  * \file intrin_rule_opencl.cc
  * \brief OpenCL intrinsic rules.
  */
+#include <tvm/arith/analyzer.h>
 #include "../intrin_rule.h"
 
 namespace tvm {
@@ -89,14 +90,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh")
 
 // There is no warp shuffle instruction in standard OpenCL
 // When shuffle is used, we assume it is Intel's shuffle extension
-struct IntelShuffle {
-  std::string operator()(DataType t, std::string name) const {
-    return "intel_sub_group_shuffle";
-  }
-};
+static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) {
+  PrimExpr e = args[0];
+  const CallNode* call = e.as<CallNode>();
+  CHECK(call != nullptr);
+  CHECK_EQ(call->args.size(), 4);  // value, warp_id, width, warp_size
+  arith::Analyzer analyzer;
+  CHECK(analyzer.CanProve(call->args[2] == call->args[3]))
+      << "Intel warp shuffle does not support width != warp_size";
+  Array<PrimExpr> cuda_args{{call->args[0], call->args[1]}};
+  *rv = CallNode::make(
+      call->dtype, "intel_sub_group_shuffle", cuda_args, CallNode::PureExtern);
+}
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
-.set_body(DispatchExtern<IntelShuffle>);
+.set_body(DispatchIntelShuffle);
 
 } // namespace intrin
 } // namespace codegen
...
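Taken together, the backend rules above and the lowering pass below enforce different legality rules for width. A plain-Python sketch of those constraints (assumed names, not TVM API):

# WarpIndexFinder (below) requires width to divide the warp size; the
# Intel OpenCL extension additionally requires full-warp shuffles.
def shuffle_width_is_legal(width, warp_size, intel_opencl=False):
    if not (0 < width <= warp_size and warp_size % width == 0):
        return False                  # rejected by WarpIndexFinder
    if intel_opencl:
        return width == warp_size     # rejected by DispatchIntelShuffle
    return True                       # CUDA __shfl supports sub-warp widths

assert shuffle_width_is_legal(16, 32)
assert not shuffle_width_is_legal(16, 32, intel_opencl=True)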
@@ -60,28 +60,45 @@ namespace tir {
 //
 // Before rewrite,
 //
-//   alloc warp warp_mem[n * warp_size * m]
-//   store warp_mem[m * warp_index + (warp_size * m) * y + x]
-//   load warp_mem[m * z + (warp_size * m) * y + x]
+//   alloc warp warp_mem[n * width * m]
+//   store warp_mem[m * warp_index + (width * m) * y + x]
+//   load warp_mem[m * z + (width * m) * y + x]
 //   subject to x \in [0, m), y \in [0, n)
 //
+// where width equals the extent of threadIdx.x, which should
+// be no larger than the warp size
+//
 // After rewrite:
 //
 //   alloc local local_mem[n * m]
 //   store warp_mem[m * y + x]
 //   warp_shuffle(load warp_mem[m * y + x], z)
 //   subject to (m * y + x) is invariant to warp_index
 //
+// If width == warp size, we are shuffling on full warps.
+// Otherwise, we are virtually shuffling on sub-warps,
+// whose size equals width. In this case, you can imagine
+// a warp consisting of only `width` threads. Width is passed
+// as an argument to the shuffle primitive, and will be
+// lowered to device code if the target supports it.
+//
+// A limitation of this sub-warp approach is that users
+// cannot shuffle across the sub-warp boundary (i.e. shuffle
+// with threadIdx.y or threadIdx.z indices). This can be solved
+// by fusing threadIdx.x to the warp size, or by improving the
+// analyzer to detect all three thread axes, which is left for
+// future improvement.
+//
 // Algorithm
 //
-//   To implement this rewrite rule, we take the following steps:
+//   To implement this rewrite rule, we take the following steps:
 //   For each warp memory alloc
 //   - Use linear pattern detector on load index to find m
-//   - Deduce n given warp_size and alloc size
-//   - Now that we have m, n, warp_size, we can proceed with the rewrite
+//   - Deduce n given width and alloc size
+//   - Now that we have m, n, width, we can proceed with the rewrite
 //
 // Visitor to find m in pattern
-// store warp_mem[m * warp_index + (warp_size * m) * y + x]
+// store warp_mem[m * warp_index + (width * m) * y + x]
 class WarpStoreCoeffFinder : private StmtVisitor {
  public:
   WarpStoreCoeffFinder(const VarNode* buffer,
@@ -153,12 +170,12 @@ class WarpIndexFinder : private StmtVisitor {
   explicit WarpIndexFinder(int warp_size)
       : warp_size_(warp_size) {
   }
-  // find the warp co-efficient in the statement given the warp size
-  IterVar Find(const Stmt& stmt) {
+  // find the warp co-efficient and the shuffle width in the statement
+  std::pair<Var, int> Find(const Stmt& stmt) {
     this->VisitStmt(stmt);
     CHECK(warp_index_.defined())
         << "Cannot find warp index(threadIdx.x) within the scope of warp memory";
-    return warp_index_;
+    return std::make_pair(warp_index_->var, width_);
   }
 
  private:
@@ -167,11 +184,12 @@ class WarpIndexFinder : private StmtVisitor {
     if (op->attr_key == attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
       if (iv->thread_tag == "threadIdx.x") {
-        int value;
+        int value = 0;
         CHECK(arith::GetConstInt(op->value, &value) &&
-              value == warp_size_)
-            << "Expect threadIdx.x's extent to be equal to the warp size ("
-            << warp_size_ << ")" << " to enable warp memory"
+              value <= warp_size_ &&
+              warp_size_ % value == 0)
+            << "Expect threadIdx.x's extent to be no larger than, and a factor of,"
+            << " the warp size (" << warp_size_ << ")" << " to enable warp memory"
             << " but get " << op->value << " instead";
         if (warp_index_.defined()) {
           CHECK(warp_index_.same_as(iv))
@@ -180,6 +198,7 @@ class WarpIndexFinder : private StmtVisitor {
             << "Please create it using thread_axis once and reuse the axis "
             << "across multiple binds in the same kernel";
       } else {
+        width_ = value;
         warp_index_ = iv;
       }
     }
@@ -188,6 +207,8 @@ class WarpIndexFinder : private StmtVisitor {
   }
   // warp size
   int warp_size_{0};
+  // number of threads involved in one shuffle
+  int width_{0};
   // the warp index
   IterVar warp_index_{nullptr};
 };
@@ -204,16 +225,16 @@ class WarpAccessRewriter : protected StmtExprMutator {
     CHECK_GT(alloc_size, 0)
         << "warp memory only support constant alloc size";
     alloc_size *= op->dtype.lanes();
-    warp_index_ = WarpIndexFinder(warp_size_).Find(op->body)->var;
+    std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
     warp_coeff_ = WarpStoreCoeffFinder(
         buffer_, warp_index_, analyzer_).Find(op->body);
-    CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0)
-        << "Warp memory must be a multiple of the warp size";
-    warp_group_ = alloc_size / (warp_size_ * warp_coeff_);
+    CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0)
+        << "Warp memory must be a multiple of the extent of threadIdx.x";
+    warp_group_ = alloc_size / (width_ * warp_coeff_);
     return AllocateNode::make(
         op->buffer_var,
         op->dtype,
-        {make_const(DataType::Int(32), alloc_size / warp_size_)},
+        {make_const(DataType::Int(32), alloc_size / width_)},
         op->condition,
         this->VisitStmt(op->body));
   }
@@ -247,7 +268,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
           op->dtype, op->buffer_var, local_index, op->predicate);
       return CallNode::make(load_value.dtype(),
                             intrinsic::tvm_warp_shuffle,
-                            {load_value, group},
+                            {load_value, group, width_, warp_size_},
                             CallNode::Intrinsic);
     } else {
       return StmtExprMutator::VisitExpr_(op);
@@ -276,9 +297,9 @@ class WarpAccessRewriter : protected StmtExprMutator {
       return std::make_pair(x, z);
     } else {
       PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m));
-      PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_);
+      PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * width_);
       y = y * m + x;
-      PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)),
+      PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)),
                             m);
       return std::make_pair(analyzer_->canonical_simplify(y),
                             analyzer_->canonical_simplify(z));
@@ -290,6 +311,8 @@ class WarpAccessRewriter : protected StmtExprMutator {
   int warp_size_{0};
   // The buffer variable
   const VarNode* buffer_;
+  // number of threads involved in one shuffle
+  int width_{0};
   // Warp index
   Var warp_index_;
   // the coefficient m
...
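The heart of the rewrite is the index split above: a warp-memory index becomes a per-thread local index plus the lane to shuffle from. A plain-Python rendering of that arithmetic (hypothetical helper name; m is the coefficient warp_coeff_, width is the extent of threadIdx.x):

# Returns (local_index, source_lane) for a warp-memory index, matching
# the else-branch of WarpAccessRewriter's index split.
def split_index_by_group(index, m, width):
    x = index % m                       # offset within one lane's m elements
    y = index // (m * width)            # which warp group the index is in
    z = (index % (m * width)) // m      # lane inside the (sub-)warp
    return y * m + x, z

# With m = 2 and width = 4, warp-memory index 5 lives in lane 2,
# local slot 1, so the load becomes
#   tvm_warp_shuffle(load local[1], 2, width, warp_size)
print(split_index_by_group(5, m=2, width=4))  # -> (1, 2)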
@@ -91,6 +91,48 @@ def test_lower_warp_memory_cuda_end_to_end():
     check_cuda("float32")
     check_cuda("float16")
 
+def test_lower_warp_memory_cuda_half_a_warp():
+    def check_cuda(dtype):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        m = 16
+        A = te.placeholder((m,), name='A', dtype=dtype)
+        B = te.compute((m,), lambda i: A[(i + 1) % m], name='B')
+
+        cuda_target = tvm.target.create("cuda")
+        assert cuda_target.thread_warp_size == 2 * m
+        with cuda_target:
+            s = te.create_schedule(B.op)
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+
+            AA = s.cache_read(A, "warp", [B])
+            xo, xi = s[B].split(B.op.axis[0], nparts=1)
+            s[B].bind(xi, tx)
+            s[B].bind(xo, bx)
+            s[AA].compute_at(s[B], xo)
+            xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
+            s[AA].bind(xo, bx)
+            s[AA].bind(xi, tx)
+
+            ctx = tvm.gpu(0)
+            func = tvm.build(s, [A, B], "cuda")
+            A_np = np.array(list(range(m)), dtype=dtype)
+            B_np = np.array(list(range(1, m)) + [0], dtype=dtype)
+            A_nd = tvm.nd.array(A_np, ctx)
+            B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
+            func(A_nd, B_nd)
+            tvm.testing.assert_allclose(B_nd.asnumpy(), B_np, rtol=1e-3)
+
+    check_cuda("float32")
+    check_cuda("float16")
+
 if __name__ == "__main__":
     test_lower_warp_memory_local_scope()
     test_lower_warp_memory_cuda_end_to_end()
+    test_lower_warp_memory_cuda_half_a_warp()
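A quick way to confirm the sub-warp lowering on a real device is to dump the generated CUDA source of the built function, e.g. right after tvm.build in the test above (a sketch; the exact emitted text varies with the codegen version):

# Sketch: inspect the CUDA source produced for the half-a-warp test.
# Expect a __shfl call whose width argument is 16 rather than the
# full warp size of 32.
dev_src = func.imported_modules[0].get_source()
assert "__shfl" in dev_src
print(dev_src)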