Commit 2f1d709f by Denis Khalikov, committed by Thierry Moreau

[PASS] InstrumentBoundCheckers pass (#2079)

This pass instruments bound checks before memory accesses (loads/stores),
which makes it possible to detect invalid (out-of-bounds) memory accesses
at run time.

The patch relates to this discussion:
https://discuss.tvm.ai/t/array-bounds-checking/944
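A minimal usage sketch, mirroring the new tests (it assumes an LLVM-enabled build): building under build_config(instrument_bound_checkers=True) makes any out-of-bounds load/store fail with an assertion at run time.

import numpy as np
import tvm

n = 1024
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute((n,), lambda i: A[i] + B[i], name='C')
s = tvm.create_schedule(C.op)

# Enable the new pass; out-of-bounds accesses in the built function
# will trigger an assertion when the function is invoked.
with tvm.build_config(instrument_bound_checkers=True):
    f = tvm.build(s, [A, B, C], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
f(a, b, c)  # in bounds: passes; an index like A[i + 1] would assert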
parent 2a5656bf
......@@ -220,6 +220,9 @@ class BuildConfigNode : public Node {
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false;
/*! \brief Whether to instrument loads and stores with out-of-bounds checks. */
bool instrument_bound_checkers = false;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
......@@ -232,6 +235,7 @@ class BuildConfigNode : public Node {
v->Visit("detect_global_barrier", &detect_global_barrier);
v->Visit("partition_const_loop", &partition_const_loop);
v->Visit("dump_pass_ir", &dump_pass_ir);
v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
}
static constexpr const char* _type_key = "BuildConfig";
......
......@@ -206,6 +206,8 @@ constexpr const char* scan_init_scope = "scan_init_scope";
* This gives hint to require stride of dim to be k * align + offset.
*/
constexpr const char* buffer_dim_align = "buffer_dim_align";
/*! \brief Mark stores/loads with their bounds. */
constexpr const char* buffer_bound = "buffer_bound";
/*!
* \brief Bind the buffer specification to the region of the op
* When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor]
......
......@@ -181,11 +181,13 @@ Stmt Inline(Stmt stmt,
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \param cache_line_size The size of CPU cache line.
* \param create_bound_attribute Whether to create bound attributes.
* \return Transformed stmt.
*/
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer,
int cache_line_size);
int cache_line_size,
bool create_bound_attribute = false);
/*!
* \brief Remove No Op from the Stmt.
......@@ -235,6 +237,13 @@ Stmt UnrollLoop(Stmt stmt,
Stmt VectorizeLoop(Stmt stmt);
/*!
 * \brief Instrument bound checkers.
 * \param stmt The statement to be instrumented.
* \return Instrumented Stmt.
*/
Stmt InstrumentBoundCheckers(Stmt stmt);
/*!
* \brief Inject virtual thread loops into stmt.
 * \param stmt The statement to be transformed.
* \return Transformed stmt.
......
......@@ -125,7 +125,8 @@ class BuildConfig(NodeBase):
"data_alignment": -1,
"restricted_func": True,
"double_buffer_split_loop": 1,
"dump_pass_ir": False
"dump_pass_ir": False,
"instrument_bound_checkers": False
}
_dump_ir = DumpIR()
......@@ -344,7 +345,7 @@ def lower(sch,
for f in lower_phase0:
stmt = f(stmt)
# Phase 1
stmt = ir_pass.StorageFlatten(stmt, binds, 64)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
......@@ -370,6 +371,9 @@ def lower(sch,
stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase3:
stmt = f(stmt)
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt)
if simple_mode:
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
......
......@@ -66,6 +66,14 @@ TVM_REGISTER_API("ir_pass.Equal")
}
});
TVM_REGISTER_API("ir_pass.StorageFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args.size() <= 3) {
*ret = StorageFlatten(args[0], args[1], args[2]);
} else {
*ret = StorageFlatten(args[0], args[1], args[2], args[3]);
}
});
TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
......@@ -126,7 +134,6 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(RewriteUnsafeSelect);
REGISTER_PASS4(Inline);
REGISTER_PASS3(StorageFlatten);
REGISTER_PASS4(IRTransform);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS5(UnrollLoop);
......@@ -155,5 +162,6 @@ REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory);
REGISTER_PASS2(VerifyGPUCode);
REGISTER_PASS1(DecorateDeviceScope);
REGISTER_PASS1(InstrumentBoundCheckers);
} // namespace ir
} // namespace tvm
......@@ -364,7 +364,8 @@ Stmt BuildStmt(Schedule sch,
stmt = ir::InjectPrefetch(stmt);
// Phase 1
stmt = ir::StorageFlatten(stmt, out_binds, 64);
stmt = ir::StorageFlatten(stmt, out_binds, 64,
config->instrument_bound_checkers);
stmt = ir::CanonicalSimplify(stmt);
if (loop_partition) {
stmt = ir::LoopPartition(stmt, config->partition_const_loop);
......@@ -382,6 +383,9 @@ Stmt BuildStmt(Schedule sch,
stmt = ir::RemoveNoOp(stmt);
stmt = ir::RewriteUnsafeSelect(stmt);
if (config->instrument_bound_checkers)
stmt = ir::InstrumentBoundCheckers(stmt);
return stmt;
}
......
/*!
* Copyright (c) 2018 by Contributors
* \file bounds_checker.cc
*/
// Instrument checkers for out-of-bounds accesses.
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <vector>
#include <unordered_map>
#include <utility>
namespace tvm {
namespace ir {
class BoundCollector : public IRVisitor {
public:
BoundCollector() {}
void Visit_(const AttrStmt *op) {
if (op->attr_key == ir::attr::buffer_bound) {
if (const Variable *key = op->node.as<Variable>()) {
mem_to_shape[key] = op->value;
}
}
IRVisitor::Visit_(op);
}
// Hashtable which maps buffer_var to shape.
std::unordered_map<const Variable *, Expr> mem_to_shape;
};
class BoundChecker : public IRMutator {
public:
explicit BoundChecker(
const std::unordered_map<const Variable *, Expr> &mem_to_shape)
: mem_to_shape_(mem_to_shape) {}
Stmt Mutate_(const Allocate *op, const Stmt &s) final {
    // If the shape was updated, we should update the hashtable.
if (UpdateIsNeeded(op->buffer_var)) {
Update(op->buffer_var, op->extents, op->type);
}
return IRMutator::Mutate_(op, s);
}
Expr Mutate_(const Call *op, const Expr &ex) final {
if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
unsafe_rewritten_ = true;
}
return IRMutator::Mutate_(op, ex);
}
Stmt Mutate_(const Store *op, const Stmt &s) final {
store_scope_bound_collector_.clear();
process_store_ = true;
unsafe_rewritten_ = false;
IRMutator::Mutate_(op, s);
process_store_ = false;
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
}
    // The collector should have at least one item.
if (store_scope_bound_collector_.size()) {
Expr condition = MakeCondition();
if (!condition.as<StringImm>()) {
Stmt nop = Evaluate::make(1);
Stmt then_case =
Store::make(op->buffer_var, op->value, op->index, op->predicate);
Stmt else_case =
AssertStmt::make(condition, StringImm::make(error_message_), nop);
Stmt body = IfThenElse::make(condition, then_case, else_case);
return body;
}
}
return s;
}
Expr Mutate_(const Load *op, const Expr &ex) final {
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
}
return IRMutator::Mutate_(op, ex);
}
private:
bool UpdateIsNeeded(const VarExpr &buffer_var) const {
return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
}
void Update(const VarExpr &buffer_var, const Array<Expr> &new_shape,
const Type &type) {
    // Sanity check first.
if (!new_shape.size()) {
return;
}
for (size_t i = 0; i < new_shape.size(); ++i) {
      if (!new_shape[i].defined() || !new_shape[i].type().is_scalar() ||
          is_negative_const(new_shape[i])) {
return;
}
}
// Scalarize the shape.
Expr shape = Mul::make(make_const(UInt(64), type.lanes()),
Cast::make(UInt(64), new_shape[0]));
for (size_t i = 1; i < new_shape.size(); ++i) {
      // Cast to unsigned to avoid integer overflow.
shape = Mul::make(shape, Mul::make(make_const(UInt(64), type.lanes()),
Cast::make(UInt(64), new_shape[i])));
}
mem_to_shape_[buffer_var.get()] = shape;
}
bool IndexIsValid(const Expr &index) const {
if (!index.defined()) {
return false;
}
if (const Ramp *ramp_index = index.as<Ramp>()) {
return ramp_index->base.defined() &&
ramp_index->base.type().is_scalar() &&
ramp_index->stride.defined() &&
ramp_index->stride.type().is_scalar() && (ramp_index->lanes > 0);
}
return true;
}
bool CanInstrument(const Expr &index, const VarExpr &buffer_var) const {
return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
IndexIsValid(index) && !unsafe_rewritten_;
}
void Collect(Expr index, VarExpr buffer_var) {
store_scope_bound_collector_.push_back(
std::make_pair(index, mem_to_shape_[buffer_var.get()]));
}
Expr MakeCondition() {
Expr condition;
for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) {
std::pair<Expr, Expr> buffer_to_mem = store_scope_bound_collector_[i];
Expr index = buffer_to_mem.first;
Expr upper_bound = buffer_to_mem.second;
if (const Ramp *ramp_index = index.as<Ramp>()) {
        // If the index is a Ramp (base + stride * i), check the last lane.
        // The upper bound is non-inclusive.
index = Add::make(
ramp_index->base,
Mul::make(ramp_index->stride, make_const(ramp_index->stride.type(),
ramp_index->lanes - 1)));
}
// Try to simplify index and bound.
index = ir::Simplify(index);
upper_bound = ir::Simplify(upper_bound);
      // Cast both to the same signed type so that the lower bound can be checked.
index = Cast::make(Int(64), index);
upper_bound = Cast::make(Int(64), upper_bound);
// Looks like a lower bound should always be zero after normalization.
Expr lower_bound = make_zero(Int(64));
Expr current_condition =
And::make(GE::make(index, lower_bound), LT::make(index, upper_bound));
condition =
!i ? current_condition : And::make(condition, current_condition);
}
return condition;
}
  // Whether we are recursively processing a store.
  bool process_store_{false};
  // Whether a tvm_if_then_else intrinsic was seen in the store's value.
  bool unsafe_rewritten_{false};
  // Pool which collects the pairs of index and shape for a specific store/load.
std::vector<std::pair<Expr, Expr>> store_scope_bound_collector_;
// Error message.
const char *const error_message_ = "OUT OF THE BOUNDS";
// Hashtable which maps buffer_var to shape.
std::unordered_map<const Variable *, Expr> mem_to_shape_;
};
Stmt InstrumentBoundCheckers(Stmt stmt) {
BoundCollector bound_collector;
  // First, walk the statement recursively and collect bound attributes.
bound_collector.Visit(stmt);
return BoundChecker(bound_collector.mem_to_shape).Mutate(stmt);
}
} // namespace ir
} // namespace tvm
......@@ -31,7 +31,8 @@ using intrinsic::tvm_address_of;
class StorageFlattener : public IRMutator {
public:
explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
int cache_line_size) {
int cache_line_size, bool create_bound_attributes)
: create_bound_attributes_(create_bound_attributes) {
for (auto kv : extern_buffer) {
BufferEntry e;
e.buffer = kv.second;
......@@ -101,6 +102,8 @@ class StorageFlattener : public IRMutator {
}
Stmt Mutate_(const Provide* op, const Stmt& s) final {
if (create_bound_attributes_)
shape_collector_.clear();
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Provide>();
TensorKey key{op->func, op->value_index};
......@@ -117,7 +120,20 @@ class StorageFlattener : public IRMutator {
{e.buffer->data, op->value},
Call::Intrinsic));
} else {
return e.buffer.vstore(e.RelIndex(op->args), op->value);
Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value);
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
shape_collector_.push_back(
std::make_pair(e.buffer->data, e.buffer->shape));
}
      // To create a bound attribute, the collector should have at least one item.
if (create_bound_attributes_ && shape_collector_.size()) {
for (size_t i = 0; i < shape_collector_.size(); ++i) {
body = AttrStmt::make(
shape_collector_[i].first, ir::attr::buffer_bound,
MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
}
}
return body;
}
}
......@@ -216,6 +232,11 @@ class StorageFlattener : public IRMutator {
ret = AttrStmt::make(
e.buffer->data, attr::storage_scope,
StringImm::make(e.buffer->scope), ret);
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
ret = AttrStmt::make(e.buffer->data, ir::attr::buffer_bound,
MakeBound(e.buffer->dtype, e.buffer->shape), ret);
}
return ret;
}
}
......@@ -254,6 +275,11 @@ class StorageFlattener : public IRMutator {
const BufferEntry& e = it->second;
CHECK(!e.released)
<< "Read a buffer that is already out of scope";
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
shape_collector_.push_back(
std::make_pair(e.buffer->data, e.buffer->shape));
}
return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype);
} else {
return expr;
......@@ -429,6 +455,31 @@ class StorageFlattener : public IRMutator {
}
}
};
bool ShapeIsValid(const Array<Expr> &shape) {
    // A zero-dimensional tensor does not need a bound check.
if (!shape.size())
return false;
for (size_t i = 0; i < shape.size(); ++i) {
if (!shape[i].defined() || !shape[i].type().is_scalar() ||
is_negative_const(shape[i])) {
return false;
}
}
return true;
}
Expr MakeBound(const Type &type, const Array<Expr> &shape) {
    // We have already checked that the shape size is greater than 0.
Expr bound = Mul::make(make_const(shape[0].type(), type.lanes()), shape[0]);
for (size_t i = 1; i < shape.size(); ++i) {
bound = Mul::make(
bound, Mul::make(make_const(bound.type(), type.lanes()), shape[i]));
}
return bound;
}
// The buffer assignment map
// Variable remap
std::unordered_map<const Variable*, Expr> var_remap_;
......@@ -440,16 +491,21 @@ class StorageFlattener : public IRMutator {
std::unordered_map<const Node*, std::string> storage_scope_;
// The current thread scope.
std::vector<ThreadScope> curr_thread_scope_;
// Collects shapes.
std::vector<std::pair<VarExpr, Array<Expr>>> shape_collector_;
// The size of cacheline
int cache_line_size_;
// The current stage is an OpenGL shader.
bool is_opengl_{false};
  // Whether to mark loads/stores with their bounds.
bool create_bound_attributes_{false};
};
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer,
int cache_line_size) {
stmt = StorageFlattener(extern_buffer, cache_line_size).Mutate(stmt);
Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
int cache_line_size, bool create_bound_attributes) {
stmt =
StorageFlattener(extern_buffer, cache_line_size, create_bound_attributes)
.Mutate(stmt);
return stmt;
}
......
......@@ -348,6 +348,30 @@ def test_rank_zero():
tvm.testing.assert_allclose(d.asnumpy(), d_np)
check_llvm(64)
def test_rank_zero_bound_checkers():
def check_llvm(n):
if not tvm.module.enabled("llvm"):
return
with tvm.build_config(instrument_bound_checkers=True):
A = tvm.placeholder((n, ), name='A')
scale = tvm.placeholder((), name='scale')
k = tvm.reduce_axis((0, n), name="k")
C = tvm.compute((), lambda : tvm.sum(A[k] * scale, axis=k), name="C")
D = tvm.compute((), lambda : C + 1)
s = tvm.create_schedule(D.op)
# build and invoke the kernel.
f = tvm.build(s, [A, scale, D], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
sc = tvm.nd.array(
np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
d = tvm.nd.empty((), D.dtype, ctx)
f(a, sc, d)
d_np = np.sum(a.asnumpy()) * sc.asnumpy() + 1
tvm.testing.assert_allclose(d.asnumpy(), d_np)
check_llvm(64)
def test_alignment():
n = tvm.convert(1024)
......@@ -367,6 +391,7 @@ if __name__ == "__main__":
test_llvm_import()
test_alignment()
test_rank_zero()
test_rank_zero_bound_checkers()
test_llvm_bool()
test_llvm_persist_parallel()
test_llvm_select()
......
from nose.tools import raises
import tvm
import numpy as np
def collect_visit(stmt, f):
ret = []
tvm.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
return ret
def lower(sch, args):
binds = {}
arg_list = []
for x in args:
if isinstance(x, tvm.tensor.Tensor):
buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
else:
raise ValueError("args must be Tensor, Buffer or Var")
sch = sch.normalize()
bounds = tvm.schedule.InferBound(sch)
stmt = tvm.schedule.ScheduleOps(sch, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64, True)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.VectorizeLoop(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
return stmt
@raises(Exception)
def test_out_of_bounds_llvm(index_a, index_b):
n = tvm.var("n")
A = tvm.placeholder ((n,), name='A')
B = tvm.placeholder ((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i + index_a] + B[i + index_b], name='C')
s = tvm.create_schedule (C.op)
tgt = "llvm"
tgt_host = "llvm"
stmt = tvm.lower (s, [A, B, C], simple_mode=True)
print (stmt)
fadd = tvm.build (s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
ctx = tvm.context(tgt, 0)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=1024).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(1024, dtype=C.dtype), ctx)
fadd (a, b, c)
def test_in_bounds_llvm():
n = tvm.var("n")
A = tvm.placeholder ((n,), name='A')
B = tvm.placeholder ((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
s = tvm.create_schedule (C.op)
tgt = "llvm"
tgt_host = "llvm"
stmt = tvm.lower (s, [A, B, C], simple_mode=True)
print (stmt)
fadd = tvm.build (s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
ctx = tvm.context(tgt, 0)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=1024).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(1024, dtype=C.dtype), ctx)
fadd (a, b, c)
@raises(Exception)
def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b):
n = tvm.convert(nn)
a = tvm.placeholder((n), name='a')
b = tvm.placeholder((n), name='b')
c = tvm.compute((n,), lambda i: a[i + index_a] + b[i + index_b], name='c')
s = tvm.create_schedule(c.op)
xo, xi = s[c].split(c.op.axis[0], factor=8)
s[c].parallel(xo)
s[c].vectorize(xi)
tgt = "llvm"
tgt_host = "llvm"
stmt = tvm.lower (s, [a, b, c], simple_mode=True)
print (stmt)
f = tvm.build(s, [a, b, c], tgt, target_host=tgt_host, name="myaddvec")
ctx = tvm.cpu(0)
n = nn
a = tvm.nd.array(np.random.uniform(size=(n)).astype(a.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(n)).astype(a.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=c.dtype), ctx)
f(a, b, c)
def test_in_bounds_vectorize_llvm():
n = 512
lanes = 2
A = tvm.placeholder((n,), name='A', dtype="float32x%d" % lanes)
B = tvm.compute((n,), lambda i: A[i], name='B')
C = tvm.compute((n,), lambda i: B[i] + tvm.const(1, A.dtype), name='C')
s = tvm.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], nparts=2)
_, xi = s[C].split(xi, factor=2)
s[C].parallel(xo)
s[C].vectorize(xi)
s[B].compute_at(s[C], xo)
xo, xi = s[B].split(B.op.axis[0], factor=2)
s[B].vectorize(xi)
# build and invoke the kernel.
lowered_func = tvm.lower (s, [A, C], "llvm", simple_mode=False)
print (lowered_func.body)
f = tvm.build(s, [A, C], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
a = tvm.nd.empty((n,), A.dtype).copyfrom(
np.random.uniform(size=(n, lanes)))
c = tvm.nd.empty((n,), C.dtype, ctx)
f(a, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
def test_in_bounds_loop_partition_basic_llvm():
n = tvm.var('n')
A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((n, ), name='B')
T = tvm.compute((n, ), lambda i: A[i]+B[i])
s = tvm.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
a = tvm.nd.array(np.random.uniform(size=(32,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(32,)).astype(B.dtype), ctx)
t = tvm.nd.empty((32,), T.dtype, ctx)
f(a, b, t)
@raises(Exception)
def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b):
n = tvm.var('n')
A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((n, ), name='B')
T = tvm.compute((n, ), lambda i: A[i + index_a]+B[i + index_b])
s = tvm.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
a = tvm.nd.array(np.random.uniform(size=(32,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(32,)).astype(B.dtype), ctx)
t = tvm.nd.empty((32,), T.dtype, ctx)
f(a, b, t)
def test_in_bounds_const_loop_partition_ir():
def check_attr_stmt (x):
if isinstance(x, tvm.stmt.AttrStmt) and x.attr_key == "buffer_bound" and str(x.value) == str(n):
return True
return False
def check_branch_stmt (x):
if isinstance(x, tvm.stmt.IfThenElse):
return True
return False
def assert_bound_instrumentation(stmt, f, nums):
count = 0
for i in collect_visit(stmt, f):
if i is True:
count = count + 1
assert (count == nums)
def collect_branch_stmt (x):
if isinstance(x, tvm.stmt.IfThenElse):
branch_collector.append(x)
n = 21
A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((n, ), name='B')
T = tvm.compute((n, ), lambda i: A[i]+B[i])
s = tvm.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
bounds = tvm.schedule.InferBound(s)
stmt = lower (s, [A, B, T])
# num_attributes = num_buffers * num_splits = 2 * 3
# before instrumentation
assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
assert_bound_instrumentation(stmt, check_branch_stmt, 0)
stmt = tvm.ir_pass.InstrumentBoundCheckers(stmt)
# after instrumentation
assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
assert_bound_instrumentation(stmt, check_branch_stmt, 2)
print (stmt)
branch_collector = list()
collect_visit(stmt, collect_branch_stmt)
assert(len(branch_collector) == 2)
print (branch_collector[0].condition)
print (branch_collector[1].condition)
def test_in_bounds_const_loop_partition_llvm():
with tvm.build_config(instrument_bound_checkers=True, partition_const_loop=True):
n = 21
A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((n, ), name='B')
T = tvm.compute((n, ), lambda i: A[i]+B[i])
s = tvm.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx)
t = tvm.nd.empty((n,), T.dtype, ctx)
f(a, b, t)
@raises(Exception)
def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b):
with tvm.build_config(instrument_bound_checkers=True, partition_const_loop=True):
n = 21
A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((n, ), name='B')
T = tvm.compute((n, ), lambda i: A[i + index_a]+B[i + index_b])
s = tvm.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx)
t = tvm.nd.empty((n,), T.dtype, ctx)
f(a, b, t)
def test_in_bounds_conv_llvm(loop_tiling=False):
HSTR = WSTR = 1
in_channel = 128
kernel_height = kernel_width = 3
out_channel = 64
batch_size = 1
in_height = in_width = 64
out_height = out_width = in_height - kernel_height + 1
data = tvm.placeholder((batch_size, in_channel, in_height, in_width), name='data')
kernel = tvm.placeholder((kernel_height, kernel_width, in_channel,
out_channel), name='kernel')
ic = tvm.reduce_axis((0, in_channel), name='ic')
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')
conv = tvm.compute((batch_size, out_channel, out_height, out_width),
lambda n, oc, oh, ow: tvm.sum(data[n, ic, oh*HSTR + kh, ow*WSTR + kw] *
kernel[kh, kw, ic, oc],
axis=[ic, kh, kw]),
name="conv2d")
s = tvm.create_schedule(conv.op)
n, oc, oh, ow = conv.op.axis
if loop_tiling:
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True)
print (lowered_func.body)
ctx = tvm.cpu (0)
f = tvm.build(s, [data, kernel, conv], "llvm")
data_input = tvm.nd.array(np.random.uniform(
size=(batch_size, in_channel, in_height, in_width)).astype(tvm.float32), ctx)
kernel_input = tvm.nd.array(np.random.uniform(
size=(kernel_height, kernel_width, in_channel, out_channel)).astype(tvm.float32), ctx)
conv_out = tvm.nd.empty ((batch_size, out_channel, out_height, out_width), tvm.float32, ctx)
f(data_input, kernel_input, conv_out)
@raises(Exception)
def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False):
HSTR = WSTR = 1
in_channel = 128
kernel_height = kernel_width = 3
out_channel = 64
batch_size = 1
in_height = in_width = 64
out_height = out_width = in_height - kernel_height + 1
data = tvm.placeholder((batch_size, in_channel, in_height, in_width), name='data')
kernel = tvm.placeholder((kernel_height, kernel_width, in_channel,
out_channel), name='kernel')
ic = tvm.reduce_axis((0, in_channel), name='ic')
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')
conv = tvm.compute((batch_size, out_channel, out_height, out_width),
lambda n, oc, oh, ow: tvm.sum(data[n + data_offsets[0],
ic + data_offsets[1],
oh*HSTR + kh + data_offsets[2],
ow*WSTR + kw + data_offsets[3]]
*
kernel[kh + kernel_offsets[0],
kw + kernel_offsets[1],
ic + kernel_offsets[2],
oc + kernel_offsets[3]],
axis=[ic, kh, kw]),
name="conv2d")
s = tvm.create_schedule(conv.op)
n, oc, oh, ow = conv.op.axis
if loop_tiling:
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True)
print (lowered_func.body)
ctx = tvm.cpu (0)
f = tvm.build(s, [data, kernel, conv], "llvm")
data_input = tvm.nd.array(np.random.uniform(
size=(batch_size, in_channel, in_height, in_width)).astype(tvm.float32), ctx)
kernel_input = tvm.nd.array(np.random.uniform(
size=(kernel_height, kernel_width, in_channel, out_channel)).astype(tvm.float32), ctx)
conv_out = tvm.nd.empty ((batch_size, out_channel, out_height, out_width), tvm.float32, ctx)
f(data_input, kernel_input, conv_out)
def test_in_bounds_tensors_with_same_shapes1D_llvm():
n = tvm.var('n')
k = tvm.var('k')
m = tvm.var('m')
A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((k, ), name='B')
T = tvm.compute((m, ), lambda i: A[i]*B[i])
s = tvm.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
a = tvm.nd.array(np.random.uniform(size=(32, )).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(32,)).astype(B.dtype), ctx)
t = tvm.nd.empty((32,), T.dtype, ctx)
f(a, b, t)
@raises(Exception)
def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape):
n = tvm.var('n')
k = tvm.var('k')
m = tvm.var('m')
A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((k, ), name='B')
T = tvm.compute((m, ), lambda i: A[i]*B[i])
s = tvm.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
a = tvm.nd.array(np.random.uniform(size=(a_shape,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(b_shape,)).astype(B.dtype), ctx)
t = tvm.nd.empty((c_shape,), T.dtype, ctx)
f(a, b, t)
def test_in_bounds_tensors_with_same_shapes2D_llvm():
n = tvm.var('n')
k = tvm.var('k')
m = tvm.var('m')
A = tvm.placeholder((n, n), name='A')
B = tvm.placeholder((k, k), name='B')
T = tvm.compute((m, m), lambda i, j: A[i][j]*B[i][j])
s = tvm.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
a = tvm.nd.array(np.random.uniform(size=(32, 32)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(32, 32)).astype(B.dtype), ctx)
t = tvm.nd.empty((32, 32), T.dtype, ctx)
f(a, b, t)
@raises(Exception)
def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape):
n = tvm.var('n')
k = tvm.var('k')
m = tvm.var('m')
A = tvm.placeholder((n, n), name='A')
B = tvm.placeholder((k, k), name='B')
T = tvm.compute((m, m), lambda i, j: A[i][j]*B[i][j])
s = tvm.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
a = tvm.nd.array(np.random.uniform(size=(a_shape[0],a_shape[1])).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(b_shape[0],b_shape[1])).astype(B.dtype), ctx)
t = tvm.nd.empty((c_shape[0],c_shape[1]), T.dtype, ctx)
f(a, b, t)
def test_in_bounds_tensors_with_same_shapes3D_llvm():
n = tvm.var('n')
k = tvm.var('k')
m = tvm.var('m')
A = tvm.placeholder((n, n, n), name='A')
B = tvm.placeholder((k, k, k), name='B')
T = tvm.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p])
s = tvm.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
a = tvm.nd.array(np.random.uniform(size=(32,32,32)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(32,32,32)).astype(B.dtype), ctx)
t = tvm.nd.empty((32, 32, 32), T.dtype, ctx)
f(a, b, t)
@raises(Exception)
def test_out_of_bounds_tensors_with_diff_shapes3D_llvm(a_shape, b_shape, c_shape):
n = tvm.var('n')
k = tvm.var('k')
m = tvm.var('m')
A = tvm.placeholder((n, n, n), name='A')
B = tvm.placeholder((k, k, k), name='B')
T = tvm.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p])
s = tvm.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
    a = tvm.nd.array(np.random.uniform(size=(a_shape[0], a_shape[1], a_shape[2])).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(b_shape[0],b_shape[1], b_shape[2])).astype(B.dtype), ctx)
t = tvm.nd.empty((c_shape[0],c_shape[1],c_shape[2]), T.dtype, ctx)
f(a, b, t)
@raises(Exception)
def test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm():
if not tvm.module.enabled("llvm"):
return
n = 64
A = tvm.placeholder((n, ), name='A')
scale = tvm.placeholder((), name='scale')
k = tvm.reduce_axis((0, n), name="k")
C = tvm.compute((), lambda : tvm.sum(A[k + k + k] * scale, axis=k), name="C")
D = tvm.compute((), lambda : C + 1)
s = tvm.create_schedule(D.op)
stmt = tvm.lower (s, [A, scale, D], simple_mode=True)
print (stmt)
# build and invoke the kernel.
f = tvm.build(s, [A, scale, D], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
sc = tvm.nd.array(
np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
d = tvm.nd.empty((), D.dtype, ctx)
f(a, sc, d)
d_np = np.sum(a.asnumpy()) * sc.asnumpy() + 1
tvm.testing.assert_allclose(d.asnumpy(), d_np)
if __name__ == "__main__":
with tvm.build_config(instrument_bound_checkers=True):
# zero scale
test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm()
# in bound
test_in_bounds_llvm()
# upper bound
test_out_of_bounds_llvm(1, 0)
test_out_of_bounds_llvm(0, 1)
test_out_of_bounds_llvm(1, 1)
test_out_of_bounds_llvm(10000, 0)
test_out_of_bounds_llvm(0, 10000)
test_out_of_bounds_llvm(10000, 10000)
# lower bound
test_out_of_bounds_llvm(-1, 0)
test_out_of_bounds_llvm(0, -1)
test_out_of_bounds_llvm(-1, -1)
test_out_of_bounds_llvm(-10000, 0)
test_out_of_bounds_llvm(0, -10000)
test_out_of_bounds_llvm(-10000, -10000)
# vectorize in bound
test_in_bounds_vectorize_llvm()
# vectorization upper bound
test_out_of_bounds_vectorize_llvm(1024, 1000, 0)
test_out_of_bounds_vectorize_llvm(1024, 0, 10000)
# vectorization lower bound
test_out_of_bounds_vectorize_llvm(1024, -1000, 0)
test_out_of_bounds_vectorize_llvm(1024, 0, -10000)
test_in_bounds_const_loop_partition_llvm()
test_out_of_bounds_const_loop_partition_llvm(1, 0)
test_out_of_bounds_const_loop_partition_llvm(0, 1)
test_out_of_bounds_const_loop_partition_llvm(-1, 0)
test_out_of_bounds_const_loop_partition_llvm(0, -1)
test_in_bounds_loop_partition_basic_llvm()
test_out_of_bounds_loop_partition_basic_llvm(32, 0)
test_out_of_bounds_loop_partition_basic_llvm(0, 32)
test_out_of_bounds_loop_partition_basic_llvm(-32, 0)
test_out_of_bounds_loop_partition_basic_llvm(0, -32)
# conv
test_in_bounds_conv_llvm()
test_out_of_bounds_conv_llvm([1, 0, 0, 0], [0, 0, 0, 0])
test_out_of_bounds_conv_llvm([0, 1, 0, 0], [0, 0, 0, 0])
test_out_of_bounds_conv_llvm([0, 0, 1, 0], [0, 0, 0, 0])
test_out_of_bounds_conv_llvm([0, 0, 0, 1], [0, 0, 0, 0])
test_out_of_bounds_conv_llvm([-1, 0, 0, 0], [0, 0, 0, 0])
test_out_of_bounds_conv_llvm([0, -1, 0, 0], [0, 0, 0, 0])
test_out_of_bounds_conv_llvm([0, 0, -1, 0], [0, 0, 0, 0])
test_out_of_bounds_conv_llvm([0, 0, 0, -1], [0, 0, 0, 0])
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [1, 0, 0, 0])
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 1, 0, 0])
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 1, 0])
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 0, 1])
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [-1, 0, 0, 0])
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, -1, 0, 0])
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, -1, 0])
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 0, -1])
# loop tiling
test_in_bounds_conv_llvm(True)
test_out_of_bounds_conv_llvm([1, 0, 0, 0], [0, 0, 0, 0], True)
test_out_of_bounds_conv_llvm([0, 1, 0, 0], [0, 0, 0, 0], True)
test_out_of_bounds_conv_llvm([0, 0, 1, 0], [0, 0, 0, 0], True)
test_out_of_bounds_conv_llvm([0, 0, 0, 1], [0, 0, 0, 0], True)
test_out_of_bounds_conv_llvm([-1, 0, 0, 0], [0, 0, 0, 0], True)
test_out_of_bounds_conv_llvm([0, -1, 0, 0], [0, 0, 0, 0], True)
test_out_of_bounds_conv_llvm([0, 0, -1, 0], [0, 0, 0, 0], True)
test_out_of_bounds_conv_llvm([0, 0, 0, -1], [0, 0, 0, 0], True)
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [1, 0, 0, 0], True)
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 1, 0, 0], True)
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 1, 0], True)
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 0, 1], True)
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [-1, 0, 0, 0], True)
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, -1, 0, 0], True)
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, -1, 0], True)
test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 0, -1], True)
# tensors with diff shapes basic operation such as mul
test_out_of_bounds_tensors_with_diff_shapes1D_llvm (32, 64, 64)
test_out_of_bounds_tensors_with_diff_shapes1D_llvm (64, 32, 64)
test_out_of_bounds_tensors_with_diff_shapes2D_llvm([64, 64], [32, 32], [64, 64])
test_out_of_bounds_tensors_with_diff_shapes2D_llvm([32, 32], [64, 64], [64, 64])
test_out_of_bounds_tensors_with_diff_shapes3D_llvm([64, 64, 64], [32, 32, 32], [64, 64, 64])
test_out_of_bounds_tensors_with_diff_shapes3D_llvm([32, 32, 32], [64, 64, 64], [64, 64, 64])
# check tensors with the same shapes
test_in_bounds_tensors_with_same_shapes1D_llvm()
test_in_bounds_tensors_with_same_shapes2D_llvm()
test_in_bounds_tensors_with_same_shapes3D_llvm()
# ir tests
test_in_bounds_const_loop_partition_ir()
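For reference, a minimal sketch of the IR-level flow the tests above exercise (assuming the same era of the TVM Python API used in this patch): StorageFlatten with the new create_bound_attribute flag attaches buffer_bound attributes to the flattened buffers, and InstrumentBoundCheckers then wraps each store in an IfThenElse whose else branch is an AssertStmt carrying the pass's error message.

import tvm

n = 16
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] + 1, name='B')
s = tvm.create_schedule(B.op)

s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
binds = {t: tvm.decl_buffer(t.shape, dtype=t.dtype, name=t.name) for t in (A, B)}
stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64, True)  # emit buffer_bound attributes
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.InstrumentBoundCheckers(stmt)          # insert the checks
print(stmt)  # stores now sit under if/else with an assert on the else branch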