Commit 34e31c44 by Ding Committed by Tianqi Chen

[PASS] Add VerifyMemory pass and test cases (#410) (#993)

parent 75b93d30
......@@ -440,6 +440,20 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
* \return Transformed function.
*/
LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
* In the case that tgt is cuda, if not all workload is bound with
* threads, CPU code is generated that tries to access GPU memory,
* which is illegal. This pass performs verification for this case.
*
* \param func The function to be verified.
* \param device_type The target device type.
* \return Success of memory verification.
*/
bool VerifyMemory(LoweredFunc func, int device_type);
} // namespace ir
} // namespace tvm
......
......@@ -424,10 +424,15 @@ def build(sch,
target = _target.current_target() if target is None else target
target = _target.create(target) if target else _target.create("llvm")
device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
fdevice = []
for func in flist:
if not ir_pass.VerifyMemory(func, device_type):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
if func.func_type == container.LoweredFunc.MixedFunc:
if BuildConfig.current.detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
......@@ -449,7 +454,6 @@ def build(sch,
warnings.warn(
"Specified target %s, but cannot find device code, did you do bind?" % target)
device_type = ndarray.context(target.target_name, 0).device_type
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
......
......@@ -128,5 +128,6 @@ REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory);
} // namespace ir
} // namespace tvm
......@@ -269,7 +269,11 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
Array<LoweredFunc> fhost;
Array<LoweredFunc> fdevice;
for (const auto &x : funcs) {
for (const auto& x : funcs) {
CHECK(ir::VerifyMemory(x, target.device_type))
<< "Direct host side access to device memory is detected in " << x->func_name()
<< ". Did you forget to bind?";
if (x->func_type == kMixedFunc) {
auto func = x;
if (config->detect_global_barrier) {
......
/*!
* Copyright (c) 2018 by Contributors
* \file verify_memory.cc
* \brief Pass to check if memory accesses are legal.
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
namespace tvm {
namespace ir {
namespace {
/*!
* \brief Verify if memory accesses are legal.
*
* In the case that tgt is cuda, if workload is not bound with
* threads, CPU code is generated that tries to access GPU memory,
* which is illegal.
*
* This pass performs such verification by checking if all Producer/Consumer
* with memory accesses are bound with threads when device type is GPU.
*/
class MemoryAccessVerifier final : protected IRVisitor {
public:
/// Special member functions
//@{
explicit MemoryAccessVerifier(LoweredFunc f, int device_type)
: func_(f), dev_type_(device_type) {}
virtual ~MemoryAccessVerifier() = default;
MemoryAccessVerifier(const MemoryAccessVerifier &) = delete;
MemoryAccessVerifier(MemoryAccessVerifier &&) = delete;
MemoryAccessVerifier &operator=(const MemoryAccessVerifier &) = delete;
MemoryAccessVerifier &operator=(MemoryAccessVerifier &&) = delete;
//@}
/// Interface to perform memory access verification
void Run() {
if (!IsGPUDevice(dev_type_)) return;
IRVisitor::Visit(func_->body);
}
/// Verification result
bool Failed() const { return failure_; }
protected:
/// Visitor implementation
//@{
void Visit(const NodeRef &n) final {
if (Failed()) return;
IRVisitor::Visit(n);
}
void Visit_(const LetStmt *op) final {
// Book keep definitions
defs_[op->var.get()] = op->value;
return IRVisitor::Visit_(op);
}
void Visit_(const AttrStmt *op) final {
if (!InThreadEnv() && (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pipeline_exec_scope)) {
EnterThreadEnv();
IRVisitor::Visit_(op);
ExitThreadEnv();
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const ProducerConsumer *op) final {
EnterProducerConsumer(op);
IRVisitor::Visit_(op);
ExitProducerConsumer();
}
void Visit_(const Load *op) final {
HandleLoadStoreToVariable(op->buffer_var);
return IRVisitor::Visit_(op);
}
void Visit_(const Store *op) final {
HandleLoadStoreToVariable(op->buffer_var);
return IRVisitor::Visit_(op);
}
//@}
/// Check if the value of a Variable comes from function argument.
bool IsFromFunctionArgs(const Variable *var) const {
const Variable *V = var;
while (true) {
CHECK(V) << "Invalid Variable\n";
// Variable is from function args. Return true.
if (V == func_->args[0].node_.get()) return true;
// The value is expected to come from a tvm_struct_get Call.
// Get the first argument of tvm_struct_get, and continue.
const auto &iter = defs_.find(V);
if (iter == defs_.end()) return false;
const Call *C = iter->second.as<const Call>();
if (!C || C->name != intrinsic::tvm_struct_get) return false;
V = C->args[0].as<Variable>();
}
return false;
}
/// Handle memory access to a Variable
void HandleLoadStoreToVariable(const VarExpr &var) {
// We skip the access within thread env.
if (InThreadEnv()) return;
// We only check access within a producer/consumer.
// Because for load/store out side of producer/consumer,
// they don't have to be in thread env to stay legal (e.g. Load of args).
if (!InProducerConsumer()) return;
// We only handle the variable from function argument.
// If it does not come from args, then it could be allocated internally,
// it may possibly be in host or device address space.
// We do not handle this case, and skip it conservatively.
if (!IsFromFunctionArgs(var.get())) return;
// The verification fails in this case.
SetFailure();
}
/// Status getter/setter
//@{
bool InThreadEnv() const { return in_thread_env_; }
void EnterThreadEnv() { in_thread_env_ = true; }
void ExitThreadEnv() { in_thread_env_ = false; }
bool InProducerConsumer() const { return pc_ != nullptr; }
const ProducerConsumer *GetCurrentProducerConsumer() const { return pc_; }
void EnterProducerConsumer(const ProducerConsumer *pc) { this->pc_ = pc; }
void ExitProducerConsumer() { pc_ = nullptr; }
void SetFailure() { failure_ = true; }
//@}
/// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
static bool IsGPUDevice(int dev_type) {
return kDLGPU == dev_type || kDLOpenCL == dev_type ||
kDLVulkan == dev_type || kDLMetal == dev_type ||
kDLROCM == dev_type || kOpenGL == dev_type;
}
private:
/// Status of visitor
//@{
bool in_thread_env_{false};
const ProducerConsumer *pc_{nullptr};
bool failure_{false}; ///< If the verification fails (i.e. has illegal access)
//@}
LoweredFunc func_{nullptr}; ///< Function to be verified.
int dev_type_{kDLCPU}; ///< Device type
std::unordered_map<const Variable *, Expr> defs_; ///< Variable definitions
};
} // namespace
/// Interface of VerifyMemory pass
bool VerifyMemory(LoweredFunc func, int device_type) {
MemoryAccessVerifier v(func, device_type);
v.Run();
return !v.Failed();
}
} // namespace ir
} // namespace tvm
import tvm
# The following DLDeviceType/TVMDeviceExtType values
# are originally defined in dlpack.h and c_runtime_api.h.
gpu_devices = [2, 4, 7, 8, 10, 11]
other_devices = [1, 3, 9, 12]
def lower(sch, args):
binds = {}
arg_list = []
for x in args:
if isinstance(x, tvm.tensor.Tensor):
buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
else:
raise ValueError("args must be Tensor, Buffer or Var")
sch = sch.normalize()
bounds = tvm.schedule.InferBound(sch)
stmt = tvm.schedule.ScheduleOps(sch, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64)
func = tvm.ir_pass.MakeAPI(stmt, "myadd", arg_list, 0, True)
return func
# All computations are bound.
# So VerifyMemory pass is expected to succeed.
#
def test_verify_memory_all_bind():
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda i: A[i] + 1.0, name="B")
# B is bound to threads.
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
func = lower(s, [A, B])
for dev_type in gpu_devices + other_devices:
assert tvm.ir_pass.VerifyMemory(func, dev_type)
# Computations are not bound.
# So VerifyMemory pass fails when device type is GPU.
#
def test_verify_memory_not_bind():
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda i: A[i] + 1.0, name="B")
# B is not bound to threads.
s = tvm.create_schedule(B.op)
func = lower(s, [A, B])
for dev_type in gpu_devices:
assert not tvm.ir_pass.VerifyMemory(func, dev_type)
for dev_type in other_devices:
assert tvm.ir_pass.VerifyMemory(func, dev_type)
# Computations are partially bound.
# So VerifyMemory pass fails when device type is GPU.
#
def test_verify_memory_partially_bind():
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda i: A[i] + 1.0, name="B")
C = tvm.compute(B.shape, lambda i: B[i] + 2.0, name="C")
D = tvm.compute(C.shape, lambda i: C[i] + 2.0, name="D")
# C is bound to threads, but B and D are not.
s = tvm.create_schedule([B.op, C.op, D.op])
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
func = lower(s, [A, B, C, D])
for dev_type in gpu_devices:
assert not tvm.ir_pass.VerifyMemory(func, dev_type)
for dev_type in other_devices:
assert tvm.ir_pass.VerifyMemory(func, dev_type)
if __name__ == "__main__":
test_verify_memory_all_bind()
test_verify_memory_not_bind()
test_verify_memory_partially_bind()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment