Commit 2f1d709f by Denis Khalikov Committed by Thierry Moreau

[PASS] InstrumentBoundCheckers pass (#2079)

The pass which instruments checkers before
memory accesses (load/store).
This allows to handle invalid memory accesses.

The patch is related to issue:
https://discuss.tvm.ai/t/array-bounds-checking/944
parent 2a5656bf
...@@ -220,6 +220,9 @@ class BuildConfigNode : public Node { ...@@ -220,6 +220,9 @@ class BuildConfigNode : public Node {
/*! \brief Whether to dump the IR of each pass (only when building from python) */ /*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false; bool dump_pass_ir = false;
/*! \brief Whether to instrument loads and stores with check for out of the bounds. */
bool instrument_bound_checkers = false;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment); v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor); v->Visit("offset_factor", &offset_factor);
...@@ -232,6 +235,7 @@ class BuildConfigNode : public Node { ...@@ -232,6 +235,7 @@ class BuildConfigNode : public Node {
v->Visit("detect_global_barrier", &detect_global_barrier); v->Visit("detect_global_barrier", &detect_global_barrier);
v->Visit("partition_const_loop", &partition_const_loop); v->Visit("partition_const_loop", &partition_const_loop);
v->Visit("dump_pass_ir", &dump_pass_ir); v->Visit("dump_pass_ir", &dump_pass_ir);
v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
} }
static constexpr const char* _type_key = "BuildConfig"; static constexpr const char* _type_key = "BuildConfig";
......
...@@ -206,6 +206,8 @@ constexpr const char* scan_init_scope = "scan_init_scope"; ...@@ -206,6 +206,8 @@ constexpr const char* scan_init_scope = "scan_init_scope";
* This gives hint to require stride of dim to be k * align + offset. * This gives hint to require stride of dim to be k * align + offset.
*/ */
constexpr const char* buffer_dim_align = "buffer_dim_align"; constexpr const char* buffer_dim_align = "buffer_dim_align";
/*! \brief Mark stores/loads with theirs bounds. */
constexpr const char* buffer_bound = "buffer_bound";
/*! /*!
* \brief Bind the buffer specification to the region of the op * \brief Bind the buffer specification to the region of the op
* When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor] * When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor]
......
...@@ -181,11 +181,13 @@ Stmt Inline(Stmt stmt, ...@@ -181,11 +181,13 @@ Stmt Inline(Stmt stmt,
* \param extern_buffer Map specifies external * \param extern_buffer Map specifies external
* buffer assignment of input and outputs. * buffer assignment of input and outputs.
* \param cache_line_size The size of CPU cache line. * \param cache_line_size The size of CPU cache line.
* \param create_bound_attribute Whether to create bound attributes.
* \return Transformed stmt. * \return Transformed stmt.
*/ */
Stmt StorageFlatten(Stmt stmt, Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer, Map<Tensor, Buffer> extern_buffer,
int cache_line_size); int cache_line_size,
bool create_bound_attribute = false);
/*! /*!
* \brief Remove No Op from the Stmt. * \brief Remove No Op from the Stmt.
...@@ -235,6 +237,13 @@ Stmt UnrollLoop(Stmt stmt, ...@@ -235,6 +237,13 @@ Stmt UnrollLoop(Stmt stmt,
Stmt VectorizeLoop(Stmt stmt); Stmt VectorizeLoop(Stmt stmt);
/*! /*!
* \brief instruments bound checkers.
* \param stmt The statment to be instrumented.
* \return Instrumented Stmt.
*/
Stmt InstrumentBoundCheckers(Stmt stmt);
/*!
* \brief Inject virtual thread loops into stmt. * \brief Inject virtual thread loops into stmt.
* \param stmt The statment to be transformed. * \param stmt The statment to be transformed.
* \return Transformed stmt. * \return Transformed stmt.
......
...@@ -125,7 +125,8 @@ class BuildConfig(NodeBase): ...@@ -125,7 +125,8 @@ class BuildConfig(NodeBase):
"data_alignment": -1, "data_alignment": -1,
"restricted_func": True, "restricted_func": True,
"double_buffer_split_loop": 1, "double_buffer_split_loop": 1,
"dump_pass_ir": False "dump_pass_ir": False,
"instrument_bound_checkers": False
} }
_dump_ir = DumpIR() _dump_ir = DumpIR()
...@@ -344,7 +345,7 @@ def lower(sch, ...@@ -344,7 +345,7 @@ def lower(sch,
for f in lower_phase0: for f in lower_phase0:
stmt = f(stmt) stmt = f(stmt)
# Phase 1 # Phase 1
stmt = ir_pass.StorageFlatten(stmt, binds, 64) stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.CanonicalSimplify(stmt) stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1: for f in lower_phase1:
stmt = f(stmt) stmt = f(stmt)
...@@ -370,6 +371,9 @@ def lower(sch, ...@@ -370,6 +371,9 @@ def lower(sch,
stmt = ir_pass.RewriteUnsafeSelect(stmt) stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase3: for f in lower_phase3:
stmt = f(stmt) stmt = f(stmt)
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt)
if simple_mode: if simple_mode:
return stmt return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func) return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
......
...@@ -66,6 +66,14 @@ TVM_REGISTER_API("ir_pass.Equal") ...@@ -66,6 +66,14 @@ TVM_REGISTER_API("ir_pass.Equal")
} }
}); });
TVM_REGISTER_API("ir_pass.StorageFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args.size() <= 3) {
*ret = StorageFlatten(args[0], args[1], args[2]);
} else {
*ret = StorageFlatten(args[0], args[1], args[2], args[3]);
}
});
TVM_REGISTER_API("ir_pass.AttrsEqual") TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) { .set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
...@@ -126,7 +134,6 @@ REGISTER_PASS1(ConvertSSA); ...@@ -126,7 +134,6 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA); REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(RewriteUnsafeSelect); REGISTER_PASS1(RewriteUnsafeSelect);
REGISTER_PASS4(Inline); REGISTER_PASS4(Inline);
REGISTER_PASS3(StorageFlatten);
REGISTER_PASS4(IRTransform); REGISTER_PASS4(IRTransform);
REGISTER_PASS1(VectorizeLoop); REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS5(UnrollLoop); REGISTER_PASS5(UnrollLoop);
...@@ -155,5 +162,6 @@ REGISTER_PASS1(CombineContextCall); ...@@ -155,5 +162,6 @@ REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory); REGISTER_PASS2(VerifyMemory);
REGISTER_PASS2(VerifyGPUCode); REGISTER_PASS2(VerifyGPUCode);
REGISTER_PASS1(DecorateDeviceScope); REGISTER_PASS1(DecorateDeviceScope);
REGISTER_PASS1(InstrumentBoundCheckers);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -364,7 +364,8 @@ Stmt BuildStmt(Schedule sch, ...@@ -364,7 +364,8 @@ Stmt BuildStmt(Schedule sch,
stmt = ir::InjectPrefetch(stmt); stmt = ir::InjectPrefetch(stmt);
// Phase 1 // Phase 1
stmt = ir::StorageFlatten(stmt, out_binds, 64); stmt = ir::StorageFlatten(stmt, out_binds, 64,
config->instrument_bound_checkers);
stmt = ir::CanonicalSimplify(stmt); stmt = ir::CanonicalSimplify(stmt);
if (loop_partition) { if (loop_partition) {
stmt = ir::LoopPartition(stmt, config->partition_const_loop); stmt = ir::LoopPartition(stmt, config->partition_const_loop);
...@@ -382,6 +383,9 @@ Stmt BuildStmt(Schedule sch, ...@@ -382,6 +383,9 @@ Stmt BuildStmt(Schedule sch,
stmt = ir::RemoveNoOp(stmt); stmt = ir::RemoveNoOp(stmt);
stmt = ir::RewriteUnsafeSelect(stmt); stmt = ir::RewriteUnsafeSelect(stmt);
if (config->instrument_bound_checkers)
stmt = ir::InstrumentBoundCheckers(stmt);
return stmt; return stmt;
} }
......
/*!
* Copyright (c) 2018 by Contributors
* \file bounds_checker.cc
*/
// Instrument checkers for out of the bounds access.
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <vector>
#include <unordered_map>
#include <utility>
namespace tvm {
namespace ir {
class BoundCollector : public IRVisitor {
public:
BoundCollector() {}
void Visit_(const AttrStmt *op) {
if (op->attr_key == ir::attr::buffer_bound) {
if (const Variable *key = op->node.as<Variable>()) {
mem_to_shape[key] = op->value;
}
}
IRVisitor::Visit_(op);
}
// Hashtable which maps buffer_var to shape.
std::unordered_map<const Variable *, Expr> mem_to_shape;
};
class BoundChecker : public IRMutator {
public:
explicit BoundChecker(
const std::unordered_map<const Variable *, Expr> &mem_to_shape)
: mem_to_shape_(mem_to_shape) {}
Stmt Mutate_(const Allocate *op, const Stmt &s) final {
// If the shape was updated we should update the hashtable.
if (UpdateIsNeeded(op->buffer_var)) {
Update(op->buffer_var, op->extents, op->type);
}
return IRMutator::Mutate_(op, s);
}
Expr Mutate_(const Call *op, const Expr &ex) final {
if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
unsafe_rewritten_ = true;
}
return IRMutator::Mutate_(op, ex);
}
Stmt Mutate_(const Store *op, const Stmt &s) final {
store_scope_bound_collector_.clear();
process_store_ = true;
unsafe_rewritten_ = false;
IRMutator::Mutate_(op, s);
process_store_ = false;
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
}
// The collector should has at least one item.
if (store_scope_bound_collector_.size()) {
Expr condition = MakeCondition();
if (!condition.as<StringImm>()) {
Stmt nop = Evaluate::make(1);
Stmt then_case =
Store::make(op->buffer_var, op->value, op->index, op->predicate);
Stmt else_case =
AssertStmt::make(condition, StringImm::make(error_message_), nop);
Stmt body = IfThenElse::make(condition, then_case, else_case);
return body;
}
}
return s;
}
Expr Mutate_(const Load *op, const Expr &ex) final {
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
}
return IRMutator::Mutate_(op, ex);
}
private:
bool UpdateIsNeeded(const VarExpr &buffer_var) const {
return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
}
void Update(const VarExpr &buffer_var, const Array<Expr> &new_shape,
const Type &type) {
// Sanity check at first.
if (!new_shape.size()) {
return;
}
for (size_t i = 0; i < new_shape.size(); ++i) {
if (!new_shape[0].defined() || !new_shape[i].type().is_scalar() ||
is_negative_const(new_shape[i])) {
return;
}
}
// Scalarize the shape.
Expr shape = Mul::make(make_const(UInt(64), type.lanes()),
Cast::make(UInt(64), new_shape[0]));
for (size_t i = 1; i < new_shape.size(); ++i) {
// Cast to unsigned to avoid integer overlow at frist.
shape = Mul::make(shape, Mul::make(make_const(UInt(64), type.lanes()),
Cast::make(UInt(64), new_shape[i])));
}
mem_to_shape_[buffer_var.get()] = shape;
}
bool IndexIsValid(const Expr &index) const {
if (!index.defined()) {
return false;
}
if (const Ramp *ramp_index = index.as<Ramp>()) {
return ramp_index->base.defined() &&
ramp_index->base.type().is_scalar() &&
ramp_index->stride.defined() &&
ramp_index->stride.type().is_scalar() && (ramp_index->lanes > 0);
}
return true;
}
bool CanInstrument(const Expr &index, const VarExpr &buffer_var) const {
return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
IndexIsValid(index) && !unsafe_rewritten_;
}
void Collect(Expr index, VarExpr buffer_var) {
store_scope_bound_collector_.push_back(
std::make_pair(index, mem_to_shape_[buffer_var.get()]));
}
Expr MakeCondition() {
Expr condition;
for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) {
std::pair<Expr, Expr> buffer_to_mem = store_scope_bound_collector_[i];
Expr index = buffer_to_mem.first;
Expr upper_bound = buffer_to_mem.second;
if (const Ramp *ramp_index = index.as<Ramp>()) {
// In case index is base + stride * i.
// Non inclusive range.
index = Add::make(
ramp_index->base,
Mul::make(ramp_index->stride, make_const(ramp_index->stride.type(),
ramp_index->lanes - 1)));
}
// Try to simplify index and bound.
index = ir::Simplify(index);
upper_bound = ir::Simplify(upper_bound);
// Cast to the same type - signed, to be able to check lower bound.
index = Cast::make(Int(64), index);
upper_bound = Cast::make(Int(64), upper_bound);
// Looks like a lower bound should always be zero after normalization.
Expr lower_bound = make_zero(Int(64));
Expr current_condition =
And::make(GE::make(index, lower_bound), LT::make(index, upper_bound));
condition =
!i ? current_condition : And::make(condition, current_condition);
}
return condition;
}
// Whether we process store value recursively.
bool process_store_{false};
// Whether we face tvm_if_then_else intrinsic.
bool unsafe_rewritten_{false};
// Pool which collects the pair of index and shape for specific store/load.
std::vector<std::pair<Expr, Expr>> store_scope_bound_collector_;
// Error message.
const char *const error_message_ = "OUT OF THE BOUNDS";
// Hashtable which maps buffer_var to shape.
std::unordered_map<const Variable *, Expr> mem_to_shape_;
};
Stmt InstrumentBoundCheckers(Stmt stmt) {
BoundCollector bound_collector;
// At first walk recursively and collect bound attributes.
bound_collector.Visit(stmt);
return BoundChecker(bound_collector.mem_to_shape).Mutate(stmt);
}
} // namespace ir
} // namespace tvm
...@@ -31,7 +31,8 @@ using intrinsic::tvm_address_of; ...@@ -31,7 +31,8 @@ using intrinsic::tvm_address_of;
class StorageFlattener : public IRMutator { class StorageFlattener : public IRMutator {
public: public:
explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer, explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
int cache_line_size) { int cache_line_size, bool create_bound_attributes)
: create_bound_attributes_(create_bound_attributes) {
for (auto kv : extern_buffer) { for (auto kv : extern_buffer) {
BufferEntry e; BufferEntry e;
e.buffer = kv.second; e.buffer = kv.second;
...@@ -101,6 +102,8 @@ class StorageFlattener : public IRMutator { ...@@ -101,6 +102,8 @@ class StorageFlattener : public IRMutator {
} }
Stmt Mutate_(const Provide* op, const Stmt& s) final { Stmt Mutate_(const Provide* op, const Stmt& s) final {
if (create_bound_attributes_)
shape_collector_.clear();
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Provide>(); op = stmt.as<Provide>();
TensorKey key{op->func, op->value_index}; TensorKey key{op->func, op->value_index};
...@@ -117,7 +120,20 @@ class StorageFlattener : public IRMutator { ...@@ -117,7 +120,20 @@ class StorageFlattener : public IRMutator {
{e.buffer->data, op->value}, {e.buffer->data, op->value},
Call::Intrinsic)); Call::Intrinsic));
} else { } else {
return e.buffer.vstore(e.RelIndex(op->args), op->value); Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value);
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
shape_collector_.push_back(
std::make_pair(e.buffer->data, e.buffer->shape));
}
// To create bound attribute collector should has at least one item.
if (create_bound_attributes_ && shape_collector_.size()) {
for (size_t i = 0; i < shape_collector_.size(); ++i) {
body = AttrStmt::make(
shape_collector_[i].first, ir::attr::buffer_bound,
MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
}
}
return body;
} }
} }
...@@ -216,6 +232,11 @@ class StorageFlattener : public IRMutator { ...@@ -216,6 +232,11 @@ class StorageFlattener : public IRMutator {
ret = AttrStmt::make( ret = AttrStmt::make(
e.buffer->data, attr::storage_scope, e.buffer->data, attr::storage_scope,
StringImm::make(e.buffer->scope), ret); StringImm::make(e.buffer->scope), ret);
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
ret = AttrStmt::make(e.buffer->data, ir::attr::buffer_bound,
MakeBound(e.buffer->dtype, e.buffer->shape), ret);
}
return ret; return ret;
} }
} }
...@@ -254,6 +275,11 @@ class StorageFlattener : public IRMutator { ...@@ -254,6 +275,11 @@ class StorageFlattener : public IRMutator {
const BufferEntry& e = it->second; const BufferEntry& e = it->second;
CHECK(!e.released) CHECK(!e.released)
<< "Read a buffer that is already out of scope"; << "Read a buffer that is already out of scope";
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
shape_collector_.push_back(
std::make_pair(e.buffer->data, e.buffer->shape));
}
return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype); return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype);
} else { } else {
return expr; return expr;
...@@ -429,6 +455,31 @@ class StorageFlattener : public IRMutator { ...@@ -429,6 +455,31 @@ class StorageFlattener : public IRMutator {
} }
} }
}; };
bool ShapeIsValid(const Array<Expr> &shape) {
// Zero-dimensional tensor does not need boundary check.
if (!shape.size())
return false;
for (size_t i = 0; i < shape.size(); ++i) {
if (!shape[i].defined() || !shape[i].type().is_scalar() ||
is_negative_const(shape[i])) {
return false;
}
}
return true;
}
Expr MakeBound(const Type &type, const Array<Expr> &shape) {
// We have already checked the shape size to be greater then 0.
Expr bound = Mul::make(make_const(shape[0].type(), type.lanes()), shape[0]);
for (size_t i = 1; i < shape.size(); ++i) {
bound = Mul::make(
bound, Mul::make(make_const(bound.type(), type.lanes()), shape[i]));
}
return bound;
}
// The buffer assignment map // The buffer assignment map
// Variable remap // Variable remap
std::unordered_map<const Variable*, Expr> var_remap_; std::unordered_map<const Variable*, Expr> var_remap_;
...@@ -440,16 +491,21 @@ class StorageFlattener : public IRMutator { ...@@ -440,16 +491,21 @@ class StorageFlattener : public IRMutator {
std::unordered_map<const Node*, std::string> storage_scope_; std::unordered_map<const Node*, std::string> storage_scope_;
// The current thread scope. // The current thread scope.
std::vector<ThreadScope> curr_thread_scope_; std::vector<ThreadScope> curr_thread_scope_;
// Collects shapes.
std::vector<std::pair<VarExpr, Array<Expr>>> shape_collector_;
// The size of cacheline // The size of cacheline
int cache_line_size_; int cache_line_size_;
// The current stage is an OpenGL shader. // The current stage is an OpenGL shader.
bool is_opengl_{false}; bool is_opengl_{false};
// Whether to mark load/store with theirs bounds.
bool create_bound_attributes_{false};
}; };
Stmt StorageFlatten(Stmt stmt, Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
Map<Tensor, Buffer> extern_buffer, int cache_line_size, bool create_bound_attributes) {
int cache_line_size) { stmt =
stmt = StorageFlattener(extern_buffer, cache_line_size).Mutate(stmt); StorageFlattener(extern_buffer, cache_line_size, create_bound_attributes)
.Mutate(stmt);
return stmt; return stmt;
} }
......
...@@ -348,6 +348,30 @@ def test_rank_zero(): ...@@ -348,6 +348,30 @@ def test_rank_zero():
tvm.testing.assert_allclose(d.asnumpy(), d_np) tvm.testing.assert_allclose(d.asnumpy(), d_np)
check_llvm(64) check_llvm(64)
def test_rank_zero_bound_checkers():
def check_llvm(n):
if not tvm.module.enabled("llvm"):
return
with tvm.build_config(instrument_bound_checkers=True):
A = tvm.placeholder((n, ), name='A')
scale = tvm.placeholder((), name='scale')
k = tvm.reduce_axis((0, n), name="k")
C = tvm.compute((), lambda : tvm.sum(A[k] * scale, axis=k), name="C")
D = tvm.compute((), lambda : C + 1)
s = tvm.create_schedule(D.op)
# build and invoke the kernel.
f = tvm.build(s, [A, scale, D], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
sc = tvm.nd.array(
np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
d = tvm.nd.empty((), D.dtype, ctx)
f(a, sc, d)
d_np = np.sum(a.asnumpy()) * sc.asnumpy() + 1
tvm.testing.assert_allclose(d.asnumpy(), d_np)
check_llvm(64)
def test_alignment(): def test_alignment():
n = tvm.convert(1024) n = tvm.convert(1024)
...@@ -367,6 +391,7 @@ if __name__ == "__main__": ...@@ -367,6 +391,7 @@ if __name__ == "__main__":
test_llvm_import() test_llvm_import()
test_alignment() test_alignment()
test_rank_zero() test_rank_zero()
test_rank_zero_bound_checkers()
test_llvm_bool() test_llvm_bool()
test_llvm_persist_parallel() test_llvm_persist_parallel()
test_llvm_select() test_llvm_select()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment