Unverified Commit d2e58ad2 by LiangLiu Committed by GitHub

[CODEGEN][CUDA] Fix vector load (#5226)

* Fix high-low bit bug in __pack_half2

* Fix vector load

* Add unit8 support for PrintVecElemLoadExpr and BroadcastNode
parent 2c1ca60e
...@@ -668,15 +668,7 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) ...@@ -668,15 +668,7 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base); std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);
HandleVolatileLoads(ref, op, os); HandleVolatileLoads(ref, op, os);
} else { } else {
// The assignment below introduces side-effect, and the resulting value cannot std::ostringstream svalue_expr;
// be reused across multiple expression, thus a new scope is needed
int vec_scope = BeginScope();
// load seperately.
std::string svalue = GetUniqueName("_");
this->PrintIndent();
this->PrintType(op->dtype, stream);
stream << ' ' << svalue << ";\n";
std::string sindex = SSAGetID(PrintExpr(op->index), op->index.dtype()); std::string sindex = SSAGetID(PrintExpr(op->index), op->index.dtype());
std::string vid = GetVarID(op->buffer_var.get()); std::string vid = GetVarID(op->buffer_var.get());
DataType elem_type = op->dtype.element_of(); DataType elem_type = op->dtype.element_of();
...@@ -699,10 +691,9 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) ...@@ -699,10 +691,9 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
value_temp << '['; value_temp << '[';
PrintVecElemLoad(sindex, op->index.dtype(), i, value_temp); PrintVecElemLoad(sindex, op->index.dtype(), i, value_temp);
value_temp << ']'; value_temp << ']';
PrintVecElemStore(svalue, op->dtype, i, value_temp.str()); PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr);
} }
os << svalue; os << svalue_expr.str();
EndScope(vec_scope);
} }
} }
} }
...@@ -955,5 +946,30 @@ void CodeGenC::VisitStmt_(const ProducerConsumerNode* op) { ...@@ -955,5 +946,30 @@ void CodeGenC::VisitStmt_(const ProducerConsumerNode* op) {
PrintStmt(op->body); PrintStmt(op->body);
} }
void CodeGenC::PrintVecElemLoadExpr(
DataType t, int i, const std::string& value, std::ostream& os) {
CHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (i != 0) {
os << "|";
}
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
return;
}
if (i == 0) {
os << "((";
PrintType(t, os);
os << t.lanes() << ")(";
}
os << value;
if (i != t.lanes() - 1) {
os << ",";
} else {
os << "))";
}
return;
}
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -191,6 +191,8 @@ class CodeGenC : ...@@ -191,6 +191,8 @@ class CodeGenC :
const std::string& vec, DataType t, int i, const std::string& value); const std::string& vec, DataType t, int i, const std::string& value);
// Get a cast type from to // Get a cast type from to
virtual std::string CastFromTo(std::string value, DataType from, DataType target); virtual std::string CastFromTo(std::string value, DataType from, DataType target);
// Get load of single element with expression
virtual void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os);
protected: protected:
// Print reference to struct location // Print reference to struct location
......
...@@ -591,13 +591,17 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { ...@@ -591,13 +591,17 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
} }
void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
if (op->dtype.is_int() && op->dtype.bits() == 8 && op->lanes == 4) { if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) {
// make_int8x4 // make_int8x4
const int64_t *p = as_const_int(op->value); const int64_t *p = as_const_int(op->value);
CHECK(p); CHECK(p);
int64_t v = *p & 0xFF; int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v; v = (v << 24) | (v << 16) | (v << 8) | v;
os << "(int)" << v; if (op->dtype.is_uint()) {
os << "(uint)" << v;
} else {
os << "(int)" << v;
}
return; return;
} }
...@@ -796,5 +800,49 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, ...@@ -796,5 +800,49 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value,
} }
} }
void CodeGenCUDA::PrintVecElemLoadExpr(
DataType t, int i, const std::string& value, std::ostream& os) {
CHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (i != 0) {
os << "|";
}
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
return;
}
if (t.is_float16()) {
if (i == 0) {
os << "make_";
PrintType(t, os);
os << '(';
}
if (i % 2 == 0) {
os << "__pack_half2(" << value;
} else {
os << "," << value << ")";
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
}
return;
}
if (i == 0) {
os << "make_";
PrintType(t, os);
os << "(";
}
os << value;
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
return;
}
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -55,6 +55,7 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -55,6 +55,7 @@ class CodeGenCUDA final : public CodeGenC {
void PrintVecElemStore( void PrintVecElemStore(
const std::string& vec, DataType t, int i, const std::string& value) final; const std::string& vec, DataType t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final;
// overload visitor // overload visitor
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
......
...@@ -291,7 +291,7 @@ static inline __device__ __host__ unsigned ...@@ -291,7 +291,7 @@ static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) { __pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x); unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y); unsigned v1 = *((unsigned short *)&y);
return (v0 << 16) | v1; return (v1 << 16) | v0;
} }
)"; )";
......
...@@ -543,6 +543,44 @@ def test_vectorized_popcount(): ...@@ -543,6 +543,44 @@ def test_vectorized_popcount():
run_test("uint32") run_test("uint32")
run_test("uint64") run_test("uint64")
def test_cuda_vectorize_load_permute_pad():
def check_cuda(dtype, n, l, padding, lanes):
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return
if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
ctx = tvm.gpu(0)
A = tvm.te.placeholder((n, l), name='A', dtype=dtype)
B = tvm.te.compute((n // lanes, l + 2 * padding, lanes),
lambda i, j, k: tvm.te.if_then_else(
tvm.te.any(j < padding, j >= l + padding),
tvm.runtime.convert(0).astype(dtype), A[i * lanes + k, j - padding]),
name='B')
s = te.create_schedule(B.op)
block, thread, vectorize = s[B].op.axis
s[B].bind(block, bx)
s[B].bind(thread, tx)
s[B].vectorize(vectorize)
fun = tvm.build(s, [A, B], "cuda", name="vector_load_permute_pad")
np_a = np.random.randint(
low=-128, high=127, size=(n, l)).astype(A.dtype)
a = tvm.nd.empty((n, l), A.dtype, ctx).copyfrom(np_a)
b = tvm.nd.empty((n // lanes, l + padding * 2, lanes), B.dtype, ctx)
fun(a, b)
np_a_reshape = np_a.reshape(n // lanes, lanes, l).transpose(0, 2, 1)
ref = np.pad(np_a_reshape, ((0, 0), (padding, padding),
(0, 0)), mode='constant', constant_values=0)
tvm.testing.assert_allclose(b.asnumpy(), ref)
check_cuda("int8", 64, 16, 3, 4)
check_cuda("uint8", 64, 16, 3, 4)
check_cuda("int32", 64, 16, 3, 4)
check_cuda("float16", 64, 16, 3, 4)
check_cuda("float32", 64, 16, 3, 4)
if __name__ == "__main__": if __name__ == "__main__":
test_cuda_vectorize_add() test_cuda_vectorize_add()
test_cuda_multiply_add() test_cuda_multiply_add()
...@@ -560,3 +598,4 @@ if __name__ == "__main__": ...@@ -560,3 +598,4 @@ if __name__ == "__main__":
test_vectorized_intrin1() test_vectorized_intrin1()
test_vectorized_intrin2() test_vectorized_intrin2()
test_vectorized_popcount() test_vectorized_popcount()
test_cuda_vectorize_load_permute_pad()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment