Commit 195973c0 by noituIover Committed by Wuwei Lin

Fix CUDA int8x4 vectorize (#3928)

* Fix int8x4 vectorize

* Fix gpu shared/local memory accumulate

* Add test_shared_memory for int8x4

* Adjust test format

* Fix cpplint
parent 880c2603
......@@ -207,7 +207,11 @@ void CodeGenCUDA::PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os) {  // NOLINT(*)
  // Emit the expression that reads lane i of a vectorized value.
  static const char access[] = {'x', 'y', 'z', 'w'};
  CHECK(i >= 0 && i < 4);
  if (t.is_int() && t.bits() == 8) {
    // int8x4 is stored packed in a single 32-bit int (not a char4 struct),
    // so extract byte i by shifting and masking instead of member access.
    os << "(0x000000ff & (" << vec << " >> " << i * 8 << "))";
  } else {
    os << vec << "." << access[i];
  }
}
void CodeGenCUDA::PrintVecElemStore(
......@@ -215,7 +219,12 @@ void CodeGenCUDA::PrintVecElemStore(
this->PrintIndent();
  // Emit the statement that writes `value` into lane i of a vectorized value.
  static const char access[] = {'x', 'y', 'z', 'w'};
  CHECK(i >= 0 && i < 4);
  if (t.is_int() && t.bits() == 8) {
    // int8x4 is packed in one 32-bit int: clear byte i with the inverted
    // mask, then OR in the new value shifted into position.
    stream << vec << "=" << vec << " & ~(0x000000ff << " << i * 8 << ") | ("
           << value << " << " << i * 8 << ");\n";
  } else {
    stream << vec << "." << access[i] << " = " << value << ";\n";
  }
}
void CodeGenCUDA::PrintStorageSync(const Call* op) {
......
......@@ -83,10 +83,10 @@ class GPUCodeVerifier : public IRVisitor {
// visit an allocation of a buffer in local or shared memory, record its size
    if (visited_local_buffers_.count(op->buffer_var.get()) != 0) {
      size_t size = static_cast<size_t>(op->constant_allocation_size());
      // type.bytes() is per-lane; multiply by lanes() so vector types
      // (e.g. int8x4) are accounted at their full width.
      local_memory_per_block_ += size * op->type.bytes() * op->type.lanes();
    } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) {
      size_t size = static_cast<size_t>(op->constant_allocation_size());
      shared_memory_per_block_ += size * op->type.bytes() * op->type.lanes();
    }
  }
......
......@@ -52,6 +52,7 @@ def test_cuda_vectorize_add():
check_cuda("float32", 64, 2)
check_cuda("float16", 64, 2)
check_cuda("int8", 64, 4)
def test_cuda_multiply_add():
......
......@@ -24,39 +24,45 @@ def get_verify_pass(valid, **kwargs):
return verify_pass
def test_shared_memory():
    """Check that the GPU verifier pass accounts shared-memory usage
    correctly, including vector dtypes (e.g. int8x4) whose element size
    is bits // 8 * lanes."""
    def check_shared_memory(dtype):
        # Build a trivial schedule that stages M elements of A into shared
        # memory per block, then verify against tight memory limits.
        N = 1024
        M = 128
        tvm_type = tvm.datatype._TVMType(dtype)
        type_size = tvm_type.bits // 8 * tvm_type.lanes
        A = tvm.placeholder((N,), name='A', dtype=dtype)
        B = tvm.compute((N, ), lambda i: A[i], name='B')
        s = tvm.create_schedule([B.op])
        AA = s.cache_read(A, "shared", [B])
        o, i = s[B].split(s[B].op.axis[0], M)
        s[AA].compute_at(s[B], o)
        s[B].bind(o, tvm.thread_axis("blockIdx.x"))
        s[B].bind(i, tvm.thread_axis("threadIdx.x"))
        # shared memory usage: M * sizeof(dtype) Bytes
        # thread usage: M
        for target in ['opencl', 'cuda']:
            if not tvm.context(target).exist:
                continue
            valid = [None]
            # One byte below the actual usage: the verifier must reject.
            with tvm.build_config(**{"add_lower_pass": [
                (2, get_verify_pass(valid,
                                    max_shared_memory_per_block=type_size * M - 1,
                                    max_threads_per_block=M))]}):
                tvm.build(s, [A, B], target)
            assert not valid[0]
            # Exactly the actual usage: the verifier must accept.
            with tvm.build_config(**{"add_lower_pass": [
                (2, get_verify_pass(valid,
                                    max_shared_memory_per_block=type_size * M,
                                    max_threads_per_block=M))]}):
                tvm.build(s, [A, B], target)
            assert valid[0]

    check_shared_memory('float32')
    check_shared_memory('int8x4')
def test_local_memory():
N = 1024
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment