Unverified commit f1e87f1b by boh_inspur, committed by GitHub

[CODEGEN][CUDA] Fix a bug when vectorized load & store was involved for "char2" (#5428)

* [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for "char2"

* Add unittest for char2

* Support char2 in vector element load & add unit tests for vector element load

* Merge up common logic, support char3, and add a unit test for char3
parent 95a816c9
...@@ -272,9 +272,17 @@ void CodeGenCUDA::PrintVecElemLoad( ...@@ -272,9 +272,17 @@ void CodeGenCUDA::PrintVecElemLoad(
static const char access[] = {'x', 'y', 'z', 'w'}; static const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4)); CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
if ((t.is_int()) && t.bits() == 8) { if ((t.is_int()) && t.bits() == 8) {
if (t.lanes() == 2 || t.lanes() == 3) {
os << vec << "." << access[i % t.lanes()];
} else {
os << "((char)(" << vec << " >> " << i * 8 << "))"; os << "((char)(" << vec << " >> " << i * 8 << "))";
}
} else if ((t.is_uint()) && t.bits() == 8) { } else if ((t.is_uint()) && t.bits() == 8) {
if (t.lanes() == 2 || t.lanes() == 3) {
os << vec << "." << access[i % t.lanes()];
} else {
os << "((unsigned char)(" << vec << " >> " << i * 8 << "))"; os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
}
} else if (t.is_float16()) { } else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2]; << access[i % 2];
...@@ -289,12 +297,17 @@ void CodeGenCUDA::PrintVecElemStore( ...@@ -289,12 +297,17 @@ void CodeGenCUDA::PrintVecElemStore(
static const char access[] = {'x', 'y', 'z', 'w'}; static const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4)); CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (t.lanes() == 2 || t.lanes() == 3) {
stream << vec << '.' << access[i % t.lanes()] << "="
<< "(" << value << ");\n";
} else {
stream << vec << "="; stream << vec << "=";
// Do not read the first undef lane. // Do not read the first undef lane.
if (i != 0) { if (i != 0) {
stream << vec << " & ~(0x000000ff << " << i * 8 << ") |"; stream << vec << " & ~(0x000000ff << " << i * 8 << ") |";
} }
stream << "(" << value << " << " << i * 8 << ");\n"; stream << "(" << value << " << " << i * 8 << ");\n";
}
} else if (t.is_float16()) { } else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n"; << access[i % 2] << " = " << value << ";\n";
...@@ -789,12 +802,14 @@ void CodeGenCUDA::PrintVecElemLoadExpr( ...@@ -789,12 +802,14 @@ void CodeGenCUDA::PrintVecElemLoadExpr(
DataType t, int i, const std::string& value, std::ostream& os) { DataType t, int i, const std::string& value, std::ostream& os) {
CHECK_GT(t.lanes(), 1); CHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (!(t.lanes() == 2 || t.lanes() == 3)) {
if (i != 0) { if (i != 0) {
os << "|"; os << "|";
} }
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
return; return;
} }
}
if (t.is_float16()) { if (t.is_float16()) {
if (i == 0) { if (i == 0) {
......
...@@ -55,7 +55,12 @@ def test_cuda_vectorize_add(): ...@@ -55,7 +55,12 @@ def test_cuda_vectorize_add():
check_cuda("float32", 64, 2) check_cuda("float32", 64, 2)
check_cuda("float32", 64, 3) check_cuda("float32", 64, 3)
check_cuda("float32", 64, 4) check_cuda("float32", 64, 4)
check_cuda("int8", 64, 2)
check_cuda("int8", 64, 3)
check_cuda("int8", 64, 4) check_cuda("int8", 64, 4)
check_cuda("uint8", 64, 2)
check_cuda("uint8", 64, 3)
check_cuda("uint8", 64, 4)
check_cuda("float16", 64, 2) check_cuda("float16", 64, 2)
check_cuda("float16", 64, 4) check_cuda("float16", 64, 4)
check_cuda("float16", 64, 6) check_cuda("float16", 64, 6)
...@@ -112,15 +117,17 @@ def test_cuda_vectorize_load(): ...@@ -112,15 +117,17 @@ def test_cuda_vectorize_load():
b = tvm.nd.empty((n,), B.dtype, ctx) b = tvm.nd.empty((n,), B.dtype, ctx)
fun(a,b) fun(a,b)
tvm.testing.assert_allclose(a.asnumpy(), b.asnumpy()) tvm.testing.assert_allclose(a.asnumpy(), b.asnumpy())
check_cuda("int8", 64, 2)
check_cuda("int8", 64, 3)
check_cuda("int8", 64, 4)
check_cuda("int8", 64, 8) check_cuda("int8", 64, 8)
check_cuda("int8", 64, 16) check_cuda("int8", 64, 16)
def test_cuda_make_int8x4(): def test_cuda_make_int8():
def check_cuda(n, value): def check_cuda(n, value, lanes):
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..") print("skip because cuda is not enabled..")
return return
lanes = 4
dtype = 'int8' dtype = 'int8'
ctx = tvm.gpu(0) ctx = tvm.gpu(0)
A = te.compute((n, lanes), lambda i,j: tvm.tir.const(value, dtype=dtype)) A = te.compute((n, lanes), lambda i,j: tvm.tir.const(value, dtype=dtype))
...@@ -133,9 +140,15 @@ def test_cuda_make_int8x4(): ...@@ -133,9 +140,15 @@ def test_cuda_make_int8x4():
a = tvm.nd.empty(np_a.shape, dtype, ctx) a = tvm.nd.empty(np_a.shape, dtype, ctx)
fun(a) fun(a)
np.testing.assert_equal(a.asnumpy(), np_a) np.testing.assert_equal(a.asnumpy(), np_a)
check_cuda(64, 0xAB) check_cuda(64, 0xAB, 4)
check_cuda(64, 0) check_cuda(64, 0, 4)
check_cuda(64, -3) check_cuda(64, -3, 4)
check_cuda(64, 0xAB, 3)
check_cuda(64, 0, 3)
check_cuda(64, -3, 3)
check_cuda(64, 0xAB, 2)
check_cuda(64, 0, 2)
check_cuda(64, -3, 2)
def test_cuda_inf_nan(): def test_cuda_inf_nan():
...@@ -579,6 +592,8 @@ def test_cuda_vectorize_load_permute_pad(): ...@@ -579,6 +592,8 @@ def test_cuda_vectorize_load_permute_pad():
(0, 0)), mode='constant', constant_values=0) (0, 0)), mode='constant', constant_values=0)
tvm.testing.assert_allclose(b.asnumpy(), ref) tvm.testing.assert_allclose(b.asnumpy(), ref)
check_cuda("int8", 64, 16, 3, 2)
check_cuda("uint8", 64, 16, 3, 2)
check_cuda("int8", 64, 16, 3, 4) check_cuda("int8", 64, 16, 3, 4)
check_cuda("uint8", 64, 16, 3, 4) check_cuda("uint8", 64, 16, 3, 4)
check_cuda("int32", 64, 16, 3, 4) check_cuda("int32", 64, 16, 3, 4)
...@@ -589,7 +604,7 @@ if __name__ == "__main__": ...@@ -589,7 +604,7 @@ if __name__ == "__main__":
test_cuda_vectorize_add() test_cuda_vectorize_add()
test_cuda_multiply_add() test_cuda_multiply_add()
test_cuda_vectorize_load() test_cuda_vectorize_load()
test_cuda_make_int8x4() test_cuda_make_int8()
test_cuda_inf_nan() test_cuda_inf_nan()
test_cuda_shuffle() test_cuda_shuffle()
test_vectorized_casts() test_vectorized_casts()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment