Commit 531bb7c4 by Lianmin Zheng, committed by Tianqi Chen

[PASS] Add GPU IR verifier (#1296)

parent f216b25e
......@@ -477,6 +477,30 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
*/
bool VerifyMemory(LoweredFunc func, int device_type);
/*!
* \brief Verify the correctness of GPU code.
* It checks whether the memory usage or the number of threads
* in a block exceeds the limits.
* \param stmt The statement to be checked.
* \param constraints The dict of constraints to check.
* Possible keys are
*
* "max_local_memory_per_block": Total amount of local memory per block (in bytes).
* "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
* "max_thread_per_block": Maximum number of threads per block.
* "max_thread_x": Maximum length of threadIdx.x.
* "max_thread_y": Maximum length of threadIdx.y.
* "max_thread_z": Maximum length of threadIdx.z.
*
* If a key is missing from this argument, the pass will not check that item.
* \return valid Whether the statement is valid GPU code
*
*/
bool VerifyGPUCode(Stmt stmt,
Map<std::string, Expr> constraints);
} // namespace ir
} // namespace tvm
......
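A minimal usage sketch from the Python side (hedged: `s`, `A`, and `B` stand for any thread-bound schedule and its tensors, as in the tests below). Lowering produces the statement to check, and any key omitted from the constraint dict is simply not checked:

# Hedged sketch: s, A, B are a thread-bound schedule and its tensors (not defined here).
stmt = tvm.lower(s, [A, B], simple_mode=True)  # lowered Stmt, before codegen
ok = tvm.ir_pass.VerifyGPUCode(stmt, {"max_shared_memory_per_block": 48 * 1024,
                                      "max_thread_per_block": 1024})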
......@@ -23,7 +23,8 @@ enum DeviceAttrKind : int {
kComputeVersion = 4,
kDeviceName = 5,
kMaxClockRate = 6,
kMultiProcessorCount = 7
kMultiProcessorCount = 7,
kMaxThreadDimensions = 8
};
/*! \brief Number of bytes each allocation must align to */
......
......@@ -3,6 +3,7 @@
from __future__ import absolute_import
import ctypes
import json
import numpy as np
from .base import _LIB, check_call
from .. import _api_internal
......@@ -178,6 +179,18 @@ class TVMContext(ctypes.Structure):
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 7)
@property
def max_thread_dimensions(self):
"""Return the maximum size of each thread axis
Returns
-------
dims: List of int
The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
"""
return json.loads(_api_internal._GetDeviceAttr(
self.device_type, self.device_id, 8))
def sync(self):
"""Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
......
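The new attribute can then be queried like any other context property (a hedged sketch, assuming a visible CUDA device; the values vary by GPU):

import tvm

ctx = tvm.gpu(0)
if ctx.exist:
    # e.g. [1024, 1024, 64] on many NVIDIA GPUs
    print(ctx.max_thread_dimensions)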
......@@ -131,5 +131,6 @@ REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory);
REGISTER_PASS2(VerifyGPUCode);
} // namespace ir
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file verify_gpu_code.cc
* \brief Verify the correctness of GPU IR.
* It checks whether the memory usage or the number of threads
* in a block exceeds the limits.
*/
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <string>
#include <unordered_set>
namespace tvm {
namespace ir {
class GPUCodeVerifier : public IRVisitor {
public:
bool Verify(tvm::Stmt stmt,
int64_t max_local_memory_per_block,
int64_t max_shared_memory_per_block,
int64_t max_thread_per_block,
int64_t max_thread_x,
int64_t max_thread_y,
int64_t max_thread_z) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
max_thread_per_block_ = static_cast<size_t>(max_thread_per_block);
max_thread_x_ = static_cast<size_t>(max_thread_x);
max_thread_y_ = static_cast<size_t>(max_thread_y);
max_thread_z_ = static_cast<size_t>(max_thread_z);
Reset_();
this->Visit(stmt);
return valid_;
}
void Visit_(const ProducerConsumer *op) {
if (nest_level_ == 0) {
// enter a new kernel, reset statistics
Reset_();
}
if (op->is_producer) {
nest_level_++;
IRVisitor::Visit_(op);
nest_level_--;
} else {
IRVisitor::Visit_(op);
}
if (nest_level_ == 0) {
// exit a kernel, check the validity
valid_ &= thread_per_block_ <= max_thread_per_block_;
valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
}
}
void Visit_(const Allocate *op) {
IRVisitor::Visit_(op);
    // when visiting an allocation of a local or shared buffer, record its size
if (visited_local_buffers_.count(op->buffer_var.get()) != 0) {
size_t size = static_cast<size_t>(op->constant_allocation_size());
local_memory_per_block_ += size * op->type.bytes();
} else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) {
size_t size = static_cast<size_t>(op->constant_allocation_size());
shared_memory_per_block_ += size * op->type.bytes();
}
}
void Visit_(const AttrStmt *op) {
if (op->attr_key == attr::storage_scope) {
if (op->value.as<StringImm>()->value == "local") {
visited_local_buffers_.insert(op->node.as<tvm::Variable>());
} else if (op->value.as<StringImm>()->value == "shared") {
visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
}
} else if (op->attr_key == attr::thread_extent) {
VarExpr var = op->node.as<tvm::IterVarNode>()->var;
const auto *extent = op->value.as<IntImm>();
CHECK(extent);
// record the number of threads in a block
std::string name = var.get()->name_hint;
if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
if (!visited_threads_.count(name)) {
visited_threads_.insert(name);
size_t length = static_cast<size_t>(extent->value);
thread_per_block_ *= length;
if (name == "threadIdx.x") {
valid_ &= length <= max_thread_x_;
} else if (name == "threadIdx.y") {
valid_ &= length <= max_thread_y_;
} else if (name == "threadIdx.z") {
valid_ &= length <= max_thread_z_;
}
}
}
}
IRVisitor::Visit_(op);
}
private:
int nest_level_{0};
std::unordered_set<const tvm::Variable *> visited_local_buffers_;
std::unordered_set<const tvm::Variable *> visited_shared_buffers_;
std::unordered_set<std::string> visited_threads_;
size_t local_memory_per_block_;
size_t shared_memory_per_block_;
size_t thread_per_block_;
size_t max_local_memory_per_block_;
size_t max_shared_memory_per_block_;
size_t max_thread_per_block_;
size_t max_thread_x_, max_thread_y_, max_thread_z_;
bool valid_{true};
void Reset_() {
visited_local_buffers_.clear();
visited_shared_buffers_.clear();
local_memory_per_block_ = 0;
shared_memory_per_block_ = 0;
visited_threads_.clear();
thread_per_block_ = 1;
}
};
bool VerifyGPUCode(Stmt stmt,
Map<std::string, Expr> constraints) {
GPUCodeVerifier verifier;
auto get_int = [&constraints](std::string key, int64_t def) {
auto iter = constraints.find(key);
if (iter != constraints.end()) {
return ((*iter).second).as<IntImm>()->value;
} else {
return def;
}
};
int64_t max_local_memory_per_block = get_int("max_local_memory_per_block", INT64_MAX);
int64_t max_shared_memory_per_block = get_int("max_shared_memory_per_block", INT64_MAX);
int64_t max_thread_per_block = get_int("max_thread_per_block", INT64_MAX);
int64_t max_thread_x = get_int("max_thread_x", INT64_MAX);
int64_t max_thread_y = get_int("max_thread_y", INT64_MAX);
int64_t max_thread_z = get_int("max_thread_z", INT64_MAX);
return verifier.Verify(stmt,
max_local_memory_per_block,
max_shared_memory_per_block,
max_thread_per_block,
max_thread_x,
max_thread_y,
max_thread_z);
}
} // namespace ir
} // namespace tvm
......@@ -5,10 +5,12 @@
*/
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <cuda_runtime.h>
#include <tvm/container.h>
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include "./cuda_common.h"
namespace tvm {
......@@ -70,6 +72,20 @@ class CUDADeviceAPI final : public DeviceAPI {
&value, cudaDevAttrMultiProcessorCount, ctx.device_id));
break;
}
case kMaxThreadDimensions: {
int dims[3];
CUDA_CALL(cudaDeviceGetAttribute(
&dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id));
CUDA_CALL(cudaDeviceGetAttribute(
&dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id));
CUDA_CALL(cudaDeviceGetAttribute(
&dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));
std::stringstream ss;  // use a JSON string to return multiple int values
ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str();
return;
}
}
*rv = value;
}
......
......@@ -42,6 +42,7 @@ void MetalWorkspace::GetAttr(
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
case kExist: break;
}
}
......
......@@ -4,6 +4,9 @@
*/
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
#include <tvm/container.h>
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include "./opencl_common.h"
namespace tvm {
......@@ -30,6 +33,7 @@ void OpenCLWorkspace::GetAttr(
CHECK_LT(index, devices.size())
<< "Invalid device id " << index;
switch (kind) {
case kExist: break;
case kMaxThreadsPerBlock: {
size_t value;
OPENCL_CALL(clGetDeviceInfo(
......@@ -80,7 +84,16 @@ void OpenCLWorkspace::GetAttr(
*rv = static_cast<int32_t>(value);
break;
}
case kExist: break;
case kMaxThreadDimensions: {
size_t dims[3];
OPENCL_CALL(clGetDeviceInfo(
devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, nullptr));
std::stringstream ss;  // use a JSON string to return multiple int values
ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str();
break;
}
}
}
......
......@@ -97,6 +97,7 @@ void OpenGLWorkspace::GetAttr(
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
}
}
......
......@@ -52,6 +52,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
}
*rv = value;
}
......
......@@ -73,6 +73,7 @@ void VulkanWorkspace::GetAttr(
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kExist: break;
case kMaxThreadDimensions: break;
}
}
......
"""Test gpu code verifier"""
import tvm
def get_verify_pass(valid, **kwargs):
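    # Wrap VerifyGPUCode as a lowering pass: the result is written into the
    # mutable list `valid` so the caller can read it back after tvm.build runs.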
def verify_pass(stmt):
valid[0] = tvm.ir_pass.VerifyGPUCode(stmt, kwargs)
return stmt
return verify_pass
def test_shared_memory():
N = 1024
M = 128
A = tvm.placeholder((N,), name='A', dtype='float32')
B = tvm.compute((N, ), lambda i: A[i], name='B')
s = tvm.create_schedule([B.op])
AA = s.cache_read(A, "shared", [B])
o, i = s[B].split(s[B].op.axis[0], M)
s[AA].compute_at(s[B], o)
s[B].bind(o, tvm.thread_axis("blockIdx.x"))
s[B].bind(i, tvm.thread_axis("threadIdx.x"))
# shared memory usage: M * 4B
# thread usage: M
for target in ['opencl', 'cuda']:
if not tvm.context(target).exist:
continue
valid = [None]
with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid,
max_shared_memory_per_block=4 * M - 1,
max_thread_per_block=M))]}):
tvm.build(s, [A, B], target)
assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid,
max_shared_memory_per_block=4 * M,
max_thread_per_block=M))]}):
tvm.build(s, [A, B], target)
assert valid[0]
def test_local_memory():
N = 1024
M = 128
A = tvm.placeholder((N,), name='A', dtype='float32')
B = tvm.compute((N, ), lambda i: A[i], name='B')
s = tvm.create_schedule([B.op])
AA = s.cache_read(A, "local", [B])
o, i = s[B].split(s[B].op.axis[0], M)
s[AA].compute_at(s[B], o)
s[B].bind(o, tvm.thread_axis("blockIdx.x"))
# local memory usage: M * 4B
# thread usage: M
for target in ['opencl', 'cuda']:
if not tvm.context(target).exist:
continue
valid = [None]
with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid,
max_local_memory_per_block=4 * M - 1,
max_thread_per_block=1))]}):
tvm.build(s, [A, B], target)
assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid,
max_local_memory_per_block=4 * M,
max_thread_per_block=1))]}):
tvm.build(s, [A, B], target)
assert valid[0]
def test_num_thread():
N = 1024
M = 128
A = tvm.placeholder((N,), name='A', dtype='float32')
B = tvm.compute((N, ), lambda i: A[i], name='B')
s = tvm.create_schedule([B.op])
o, i = s[B].split(s[B].op.axis[0], M)
s[B].bind(o, tvm.thread_axis("threadIdx.x"))
s[B].bind(i, tvm.thread_axis("threadIdx.y"))
# shared memory usage: 0
# thread usage: N
for target in ['opencl', 'cuda']:
if not tvm.context(target).exist:
continue
valid = [None]
with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid,
max_shared_memory_per_block=0,
max_thread_per_block=N - 1))]}):
tvm.build(s, [A, B], target)
assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid,
max_shared_memory_per_block=0,
max_thread_per_block=N))]}):
tvm.build(s, [A, B], target)
assert valid[0]
with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid,
max_shared_memory_per_block=0,
max_thread_per_block=N,
max_thread_y=M-1))]}):
tvm.build(s, [A, B], target)
assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid,
max_shared_memory_per_block=0,
max_thread_per_block=N,
max_thread_y=M))]}):
tvm.build(s, [A, B], target)
assert valid[0]
def test_multiple_kernels():
N = 1024
A = tvm.placeholder((N, N), name='A')
B = tvm.compute((N, N), lambda i, j: A[i, j])
C = tvm.compute((N, N), lambda i, j: B[i, j])
s = tvm.create_schedule([C.op])
s[C].bind(s[C].op.axis[1], tvm.thread_axis("threadIdx.x"))
s[B].bind(s[B].op.axis[1], tvm.thread_axis("threadIdx.x"))
# shared memory usage: 0
# thread usage: N
for target in ['opencl', 'cuda']:
if not tvm.context(target).exist:
continue
valid = [None]
with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid,
max_shared_memory_per_block=0,
max_thread_per_block=N - 1))]}):
tvm.build(s, [A, C], target)
assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid,
max_shared_memory_per_block=0,
max_thread_per_block=N))]}):
tvm.build(s, [A, C], target)
assert valid[0]
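A natural follow-up (a hedged sketch, not part of this commit) is to derive the constraints from the device attributes introduced above, so the verifier checks against the actual hardware limits; `max_threads_per_block` is the pre-existing context property for device attribute 1:

def test_device_constraints():
    # Hedged sketch: feed real device limits into the verifier.
    N = 1024
    M = 128
    A = tvm.placeholder((N,), name='A', dtype='float32')
    B = tvm.compute((N,), lambda i: A[i], name='B')
    s = tvm.create_schedule([B.op])
    o, i = s[B].split(s[B].op.axis[0], M)
    s[B].bind(o, tvm.thread_axis("blockIdx.x"))
    s[B].bind(i, tvm.thread_axis("threadIdx.x"))
    ctx = tvm.gpu(0)
    if not ctx.exist:
        return
    dims = ctx.max_thread_dimensions
    valid = [None]
    with tvm.build_config(**{"add_lower_pass": [
        (2, get_verify_pass(valid,
                            max_thread_x=dims[0],
                            max_thread_y=dims[1],
                            max_thread_z=dims[2],
                            max_thread_per_block=ctx.max_threads_per_block))]}):
        tvm.build(s, [A, B], 'cuda')
    assert valid[0]  # M = 128 threads per block fits within real device limits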
if __name__ == "__main__":
test_local_memory()
test_shared_memory()
test_num_thread()
test_multiple_kernels()