Unverified commit afa84171 by Wei Pan, committed by GitHub

[CodeGen][CUDA] Enhance CUDA codegen for SelectNode (#4983)

- This patch allows the CUDA backend to emit correct code for
  selects with vector conditions, which may be produced by
  floordiv op lowering, among other things (see the sketch below).

- This already works for the LLVM backend, since the llvm select
  instruction supports vector conditions natively.
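
For illustration, a rough sketch of the kind of CUDA the new visitor produces for a 4-lane float select with a vector condition. The identifiers and the surrounding kernel are hypothetical (the real emitter uses SSA-generated names); only the shape of the output follows from the patch: the boolean condition vector is held as a ushort4, and the select is serialized lane by lane with a bool() cast on each condition element.

// Hypothetical sketch of the emitted code, not part of the patch itself.
__global__ void select_sketch(float4* out) {
  ushort4 _cond = make_ushort4(1, 0, 1, 0);      // bool lanes stored as ushort
  float4 _t = make_float4(1.f, 2.f, 3.f, 4.f);   // true-branch values
  float4 _f = make_float4(5.f, 6.f, 7.f, 8.f);   // false-branch values
  float4 _r;
  // One scalar conditional per lane; each condition lane is cast back to bool.
  _r.x = (bool(_cond.x) ? _t.x : _f.x);
  _r.y = (bool(_cond.y) ? _t.y : _f.y);
  _r.z = (bool(_cond.z) ? _t.z : _f.z);
  _r.w = (bool(_cond.w) ? _t.w : _f.w);
  *out = _r;
}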

Signed-off-by: Wei Pan <weip@nvidia.com>
parent 45ee7b5f
@@ -112,6 +112,10 @@ class DataType {
bool is_vector() const {
return lanes() > 1;
}
/*! \return whether type is a bool vector type. */
bool is_vector_bool() const {
return is_vector() && bits() == 1;
}
/*!
* \brief Create a new data type by changing lanes to a specified value.
* \param lanes The target number of lanes.
......
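
A minimal usage sketch of the new is_vector_bool() predicate (illustrative only, not part of the diff; DataType::Bool(lanes) is the existing constructor for 1-bit unsigned types):

#include <tvm/runtime/data_type.h>

// Illustrative check of the new predicate.
void is_vector_bool_sketch() {
  using tvm::runtime::DataType;
  DataType scalar_bool = DataType::Bool();   // bits() == 1, lanes() == 1
  DataType bool_vec = DataType::Bool(4);     // bits() == 1, lanes() == 4
  bool a = scalar_bool.is_vector_bool();     // false: not a vector
  bool b = bool_vec.is_vector_bool();        // true: vector with 1-bit lanes
  (void)a; (void)b;
}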
@@ -135,6 +135,13 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
}
} else if (t == DataType::Bool()) {
os << "bool"; return;
} else if (t.is_vector_bool()) {
// CUDA does not support bool vectors.
// Use ushort vectors to represent them instead.
int n = t.lanes();
if (n <= 4) {
os << "ushort" << n; return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
if (t.lanes() != 1) {
@@ -226,7 +233,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
}
void CodeGenCUDA::PrintVecBinaryOp(
const std::string&op, DataType t,
const std::string& op, DataType t,
PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*)
// unpacking operations.
int lanes = t.lanes();
@@ -561,6 +568,48 @@ void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) {
os << ')';
}
void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) {
// Non-vector cases.
if (!op->dtype.is_vector()) {
CodeGenC::VisitExpr_(op, os);
return;
}
// Codegen vector condition case by serializing the select op.
CHECK(op->false_value->dtype == op->dtype &&
op->true_value->dtype == op->dtype &&
op->dtype.lanes() == op->condition.dtype().lanes());
int lanes = op->dtype.lanes();
int scope = BeginScope();
std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype);
std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype);
std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype);
std::string r_var = GetUniqueName("_");
this->PrintIndent();
this->PrintType(op->dtype, stream);
stream << ' ' << r_var << ";\n";
// The condition is stored as a ushort vector.
DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes);
for (int i = 0; i < lanes; ++i) {
std::ostringstream item;
item << "(bool(";
PrintVecElemLoad(c_var, memory_ty, i, item);
item << ")?";
PrintVecElemLoad(t_var, op->dtype, i, item);
item << ':';
PrintVecElemLoad(f_var, op->dtype, i, item);
item << ')';
PrintVecElemStore(r_var, op->dtype, i, item.str());
}
os << r_var;
EndScope(scope);
}
inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
switch (op->dtype.bits()) {
case 64: case 32: {
......
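
As a hedged summary of the PrintType change above, a small helper that mirrors the new branch (a sketch, not TVM code; only bool vectors of 2 to 4 lanes take the branch, larger ones still fall through to the existing unsupported-type handling):

#include <string>
#include <tvm/runtime/data_type.h>

// Hypothetical mirror of the new PrintType branch, for illustration only.
std::string CudaBoolTypeName(const tvm::runtime::DataType& t) {
  if (t.is_vector_bool() && t.lanes() <= 4) {
    return "ushort" + std::to_string(t.lanes());   // e.g. Bool(4) -> "ushort4"
  }
  if (t.is_bool() && !t.is_vector()) {
    return "bool";                                 // scalar bool keeps the native type
  }
  return "";  // more than 4 lanes, or non-bool: outside this sketch
}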
@@ -43,11 +43,11 @@ class CodeGenCUDA final : public CodeGenC {
return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
}
// override behavior
void VisitStmt_(const tir::ForNode* op) final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
const std::string&op, DataType t,
const std::string& op, DataType t,
PrimExpr lhs, PrimExpr rhs, std::ostream& os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintVecElemLoad(
@@ -58,6 +58,7 @@ class CodeGenCUDA final : public CodeGenC {
// overload visitor
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream& os) final;
void VisitExpr_(const CallNode *op, std::ostream& os) final;
......
@@ -321,6 +321,33 @@ def test_cuda_reduction():
    check_cuda("float32")
    check_cuda("float16")

def test_cuda_floordiv_with_vectorization():
    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
        print("skip because cuda is not enabled..")
        return

    with tvm.target.cuda():
        # B[i] = A[floordiv(i, k)]
        n = 256
        k = 37
        A = te.placeholder((n,), name='A')
        B = te.compute((n,), lambda i: A[tvm.tir.floordiv(i, k)], name='B')
        s = te.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], nparts=1)
        xio, xii = s[B].split(xi, factor=4)
        s[B].vectorize(xii)
        # bx and tx are the block/thread axes defined at module scope in this test file.
        s[B].bind(xo, bx)
        s[B].bind(xio, tx)
        func = tvm.build(s, [A, B], 'cuda')

        ctx = tvm.gpu(0)
        a_np = np.random.uniform(size=(n,)).astype(A.dtype)
        b_np = np.array([a_np[i//k] for i in range(0, n)])
        a_nd = tvm.nd.array(a_np, ctx)
        b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), ctx)
        func(a_nd, b_nd)
        tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)
if __name__ == "__main__":
    test_cuda_vectorize_add()
    test_cuda_multiply_add()
@@ -331,4 +358,5 @@ if __name__ == "__main__":
    test_cuda_reducition_binding()
    test_rfactor_predicates()
    test_cuda_const_float_to_half()
    test_cuda_reduction()
\ No newline at end of file
    test_cuda_reduction()
    test_cuda_floordiv_with_vectorization()
\ No newline at end of file