Unverified Commit f1e87f1b by boh_inspur Committed by GitHub

[CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for… (#5428)

* [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for "char2"

* Add unittest for char2

* Support char2 in vector element load & add some unit tests for vector element load

* Merge up common logic & support char3 & add a unit test for char3
parent 95a816c9
......@@ -272,9 +272,17 @@ void CodeGenCUDA::PrintVecElemLoad(
static const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
if ((t.is_int()) && t.bits() == 8) {
os << "((char)(" << vec << " >> " << i * 8 << "))";
if (t.lanes() == 2 || t.lanes() == 3) {
os << vec << "." << access[i % t.lanes()];
} else {
os << "((char)(" << vec << " >> " << i * 8 << "))";
}
} else if ((t.is_uint()) && t.bits() == 8) {
os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
if (t.lanes() == 2 || t.lanes() == 3) {
os << vec << "." << access[i % t.lanes()];
} else {
os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
}
} else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
......@@ -289,12 +297,17 @@ void CodeGenCUDA::PrintVecElemStore(
static const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
stream << vec << "=";
// Do not read the first undef lane.
if (i != 0) {
stream << vec << " & ~(0x000000ff << " << i * 8 << ") |";
if (t.lanes() == 2 || t.lanes() == 3) {
stream << vec << '.' << access[i % t.lanes()] << "="
<< "(" << value << ");\n";
} else {
stream << vec << "=";
// Do not read the first undef lane.
if (i != 0) {
stream << vec << " & ~(0x000000ff << " << i * 8 << ") |";
}
stream << "(" << value << " << " << i * 8 << ");\n";
}
stream << "(" << value << " << " << i * 8 << ");\n";
} else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
......@@ -789,11 +802,13 @@ void CodeGenCUDA::PrintVecElemLoadExpr(
DataType t, int i, const std::string& value, std::ostream& os) {
CHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (i != 0) {
os << "|";
if (!(t.lanes() == 2 || t.lanes() == 3)) {
if (i != 0) {
os << "|";
}
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
return;
}
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
return;
}
if (t.is_float16()) {
......
......@@ -55,7 +55,12 @@ def test_cuda_vectorize_add():
check_cuda("float32", 64, 2)
check_cuda("float32", 64, 3)
check_cuda("float32", 64, 4)
check_cuda("int8", 64, 2)
check_cuda("int8", 64, 3)
check_cuda("int8", 64, 4)
check_cuda("uint8", 64, 2)
check_cuda("uint8", 64, 3)
check_cuda("uint8", 64, 4)
check_cuda("float16", 64, 2)
check_cuda("float16", 64, 4)
check_cuda("float16", 64, 6)
......@@ -112,15 +117,17 @@ def test_cuda_vectorize_load():
b = tvm.nd.empty((n,), B.dtype, ctx)
fun(a,b)
tvm.testing.assert_allclose(a.asnumpy(), b.asnumpy())
check_cuda("int8", 64, 2)
check_cuda("int8", 64, 3)
check_cuda("int8", 64, 4)
check_cuda("int8", 64, 8)
check_cuda("int8", 64, 16)
def test_cuda_make_int8x4():
def check_cuda(n, value):
def test_cuda_make_int8():
def check_cuda(n, value, lanes):
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return
lanes = 4
dtype = 'int8'
ctx = tvm.gpu(0)
A = te.compute((n, lanes), lambda i,j: tvm.tir.const(value, dtype=dtype))
......@@ -133,9 +140,15 @@ def test_cuda_make_int8x4():
a = tvm.nd.empty(np_a.shape, dtype, ctx)
fun(a)
np.testing.assert_equal(a.asnumpy(), np_a)
check_cuda(64, 0xAB)
check_cuda(64, 0)
check_cuda(64, -3)
check_cuda(64, 0xAB, 4)
check_cuda(64, 0, 4)
check_cuda(64, -3, 4)
check_cuda(64, 0xAB, 3)
check_cuda(64, 0, 3)
check_cuda(64, -3, 3)
check_cuda(64, 0xAB, 2)
check_cuda(64, 0, 2)
check_cuda(64, -3, 2)
def test_cuda_inf_nan():
......@@ -579,6 +592,8 @@ def test_cuda_vectorize_load_permute_pad():
(0, 0)), mode='constant', constant_values=0)
tvm.testing.assert_allclose(b.asnumpy(), ref)
check_cuda("int8", 64, 16, 3, 2)
check_cuda("uint8", 64, 16, 3, 2)
check_cuda("int8", 64, 16, 3, 4)
check_cuda("uint8", 64, 16, 3, 4)
check_cuda("int32", 64, 16, 3, 4)
......@@ -589,7 +604,7 @@ if __name__ == "__main__":
test_cuda_vectorize_add()
test_cuda_multiply_add()
test_cuda_vectorize_load()
test_cuda_make_int8x4()
test_cuda_make_int8()
test_cuda_inf_nan()
test_cuda_shuffle()
test_vectorized_casts()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment