Commit 00506a62 by Tianqi Chen Committed by GitHub

[IR] Add body to AssertStmt (#220)

* [IR] Add body to AssertStmt

* fix lint
parent c9da7254
Subproject commit 860199eea031a4ea694b8fce03ad0bf8127910ac
Subproject commit 36ecc1eec0898411ae70e98c315b03247d5fb4a0
......@@ -119,9 +119,9 @@ class IntSet : public NodeRef {
*/
struct ModularEntry {
/*! \brief The base */
int base;
int base{0};
/*! \brief linear co-efficient */
int coeff;
int coeff{1};
/*! \return entry represent everything */
static ModularEntry everything() {
......
......@@ -68,7 +68,10 @@ class Buffer : public NodeRef {
class BufferNode : public Node {
public:
// Data fields.
/*! \brief The pointer to the head of the data */
/*!
* \brief The pointer to the head of the data
* \sa data_alignment The alignment of data in bytes.
*/
Var data;
/*! \brief data type in the content of the tensor */
Type dtype;
......@@ -86,8 +89,13 @@ class BufferNode : public Node {
std::string name;
/*! \brief storage scope of the buffer, if other than global */
std::string scope;
/*! \brief Alignment multiple in terms of dtype elements (including lanes) */
int offset_alignment;
/*! \brief Alignment requirement of data pointer in bytes. */
int data_alignment;
/*!
* \brief Factor of elem_offset field,
* elem_offset is guaranteed to be multiple of offset_factor.
*/
int offset_factor;
/*! \brief constructor */
BufferNode() {}
......@@ -99,9 +107,12 @@ class BufferNode : public Node {
v->Visit("elem_offset", &elem_offset);
v->Visit("name", &name);
v->Visit("scope", &scope);
v->Visit("offset_alignment", &offset_alignment);
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
}
// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
static Buffer make(Var ptr,
Type dtype,
Array<Expr> shape,
......@@ -109,7 +120,8 @@ class BufferNode : public Node {
Expr byte_offset,
std::string name,
std::string scope,
int offset_alignment);
int data_alignment,
int offset_factor);
static constexpr const char* _type_key = "Buffer";
TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
......
......@@ -135,7 +135,7 @@ struct TensorKey {
}
};
/*! \brief namespace of possible attribute sin AttrStmt.type_key */
/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
namespace attr {
// The above attr does not pass to ir stage.
/*! \brief Mark launching extent of thread, used by device API. */
......
......@@ -390,7 +390,8 @@ def decl_buffer(shape,
strides=None,
elem_offset=None,
scope="",
offset_alignment=0):
data_alignment=0,
offset_factor=0):
"""Decleare a new symbolic buffer.
Normally buffer is created automatically during lower and build.
......@@ -423,8 +424,15 @@ def decl_buffer(shape,
The storage scope of the buffer, if not global.
If scope equals empty string, it means it is global memory.
offset_alignment: int, optional
The alignment of offset
data_alignment: int, optional
The alignment of data pointer in bytes.
If 0 is passed, the alignment will be set to TVM's internal default.
offset_factor: int, optional
The factor of elem_offset field, when set,
elem_offset is required to be multiple of offset_factor.
If 0 is pssed, the alignment will be set to 1.
if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None.
Returns
-------
......@@ -447,11 +455,14 @@ def decl_buffer(shape,
shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides
if offset_factor != 0 and elem_offset is None:
elem_offset = var('%s_elem_offset' % name, shape[0].dtype)
if data is None:
data = var(name, "handle")
return _api_internal._Buffer(
data, dtype, shape, strides, elem_offset, name, scope, offset_alignment)
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor)
def _IterVar(dom, name, iter_type, thread_tag=''):
......
......@@ -26,7 +26,8 @@ class BuildConfig(object):
'auto_unroll_max_step': 0,
'auto_unroll_min_depth': 1,
'unroll_explicit': True,
'detect_global_barrier': False
'detect_global_barrier': False,
'offset_factor': 0
}
def __init__(self, **kwargs):
self._old_scope = None
......@@ -76,6 +77,10 @@ def build_config(**kwargs):
detect_global_barrier: bool, default=True
Whether detect global barrier.
offset_factor: int, default=0
The factor used in default buffer declaration.
If specified as 0, offset field is not used.
Returns
-------
config: BuildConfig
......@@ -105,10 +110,12 @@ def get_binds(args, binds=None):
The list of symbolic buffers of arguments.
"""
binds = {} if binds is None else binds.copy()
offset_factor = BuildConfig.current.offset_factor
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
buf = api.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
buf = api.decl_buffer(x.shape, dtype=x.dtype, name=x.name,
offset_factor=offset_factor)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
......
......@@ -143,7 +143,7 @@ REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
......
......@@ -152,7 +152,8 @@ TVM_REGISTER_API("_Buffer")
args[4],
args[5],
args[6],
args[7]);
args[7],
args[8]);
});
TVM_REGISTER_API("_Tensor")
......
......@@ -724,6 +724,7 @@ void CodeGenC::VisitStmt_(const AssertStmt* op) {
} else {
stream << "assert(" << cond << ");\n";
}
this->PrintStmt(op->body);
}
void CodeGenC::VisitStmt_(const For* op) {
......
......@@ -1377,6 +1377,31 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
builder_->CreateRet(llvm::ConstantInt::getSigned(t_int32_, -1));
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
// Detect useful invariant pattern and use them to visit child.
// Pattern: Var % const == 0
// TODO(tqchen) move these pattern to a generic scope info visitor.
if (const EQ* eq = op->condition.as<EQ>()) {
const Mod* mod = eq->a.as<Mod>();
int64_t factor, offset;
if (mod && arith::GetConst(eq->b, &offset)) {
const Variable *var = mod->a.as<Variable>();
if (var && arith::GetConst(mod->b, &factor)) {
arith::ModularEntry old = align_map_[var];
if (factor > old.coeff) {
arith::ModularEntry e;
e.coeff = static_cast<int>(factor);
e.base = static_cast<int>(offset);
// new alignment info,
align_map_[var] = e;
this->VisitStmt(op->body);
// restore old info
align_map_[var] = old;
return;
}
}
}
}
this->VisitStmt(op->body);
}
void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
......
......@@ -456,6 +456,7 @@ void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
this->Push(op->condition);
this->PushOp(StackVM::ASSERT, sid);
}
this->Push(op->body);
}
void CodeGenStackVM::VisitStmt_(const AttrStmt *op) {
......
......@@ -3,6 +3,7 @@
* \file buffer.cc
*/
#include <tvm/buffer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
......@@ -26,7 +27,9 @@ Buffer decl_buffer(Array<Expr> shape,
shape,
Array<Expr>(),
Expr(),
name, "", 0);
name,
"",
0, 0);
}
// The buffer offset in convention of number of elements of
......@@ -124,6 +127,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
elem_offset,
n->name + "_slice",
n->scope,
n->data_alignment,
0);
}
......@@ -134,7 +138,8 @@ Buffer BufferNode::make(Var data,
Expr elem_offset,
std::string name,
std::string scope,
int offset_alignment) {
int data_alignment,
int offset_factor) {
auto n = std::make_shared<BufferNode>();
n->data = std::move(data);
n->dtype = dtype;
......@@ -145,11 +150,15 @@ Buffer BufferNode::make(Var data,
if (!elem_offset.defined()) {
elem_offset = make_const(n->shape[0].type(), 0);
}
if (offset_alignment == 0) {
offset_alignment = 1;
if (data_alignment == 0) {
data_alignment = runtime::kAllocAlignment;
}
if (offset_factor == 0) {
offset_factor = 1;
}
n->elem_offset = elem_offset;
n->offset_alignment = offset_alignment;
n->data_alignment = data_alignment;
n->offset_factor = offset_factor;
return Buffer(n);
}
......
......@@ -24,7 +24,7 @@ void BinderAddAssert(Expr cond,
if (!is_one(cond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint";
asserts->emplace_back(AssertStmt::make(cond, os.str()));
asserts->emplace_back(AssertStmt::make(cond, os.str(), Evaluate::make(0)));
}
}
......@@ -107,7 +107,14 @@ void ArgBinder::BindBuffer(const Buffer& arg,
this->BindArray(arg->shape, value->shape, arg_name + ".shape");
this->BindArray(arg->strides, value->strides, arg_name + ".strides");
}
this->Bind(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset");
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
if (arg->offset_factor > 1) {
Expr offset = value->elem_offset;
Expr factor = make_const(offset.type(), arg->offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
}
}
}
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
......@@ -117,7 +124,7 @@ inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
inline Stmt AssertNull(Var handle, std::string msg) {
return AssertStmt::make(Call::make(
Bool(1), intrinsic::tvm_handle_is_null,
{handle}, Call::PureIntrinsic), msg);
{handle}, Call::PureIntrinsic), msg, Evaluate::make(0));
}
void ArgBinder::BindDLTensor(const Buffer& buffer,
......@@ -136,7 +143,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
ndim_err_msg << arg_name
<< ".ndim is expected to equal "
<< buffer->shape.size();
asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str()));
asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
// type checks
Type dtype = buffer->dtype;
std::ostringstream type_err_msg;
......@@ -147,7 +154,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
UIntImm::make(UInt(8), dtype.bits()) &&
TVMArrayGet(UInt(16), handle, intrinsic::kArrTypeLanes) ==
UIntImm::make(UInt(16), dtype.lanes()));
asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str(), nop));
// data field
if (Bind_(buffer->data, TVMArrayGet(Handle(), handle, intrinsic::kArrData),
arg_name + ".data", true)) {
......@@ -156,7 +163,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
// mark alignment of external bufs
init_nest_.emplace_back(AttrStmt::make(
vptr, ir::attr::storage_alignment,
IntImm::make(Int(32), runtime::kAllocAlignment), nop));
IntImm::make(Int(32), buffer->data_alignment), nop));
}
Var v_shape(arg_name + ".shape", Handle());
......@@ -202,11 +209,18 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset),
arg_name + ".byte_offset", true);
} else {
Bind_(buffer->elem_offset,
cast(buffer->elem_offset.type(),
(TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset) /
make_const(UInt(64), data_bytes))),
arg_name + ".elem_offset", true);
if (Bind_(buffer->elem_offset,
cast(buffer->elem_offset.type(),
(TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset) /
make_const(UInt(64), data_bytes))),
arg_name + ".elem_offset", true)) {
if (buffer->offset_factor > 1) {
Expr offset = buffer->elem_offset;
Expr factor = make_const(offset.type(), buffer->offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
}
}
}
// device info.
Bind_(device_type,
......
......@@ -118,6 +118,7 @@ class IRDeepCompare :
const AssertStmt* rhs = other.as<AssertStmt>();
if (CompareExpr(op->condition, rhs->condition) != 0) return;
if (CompareExpr(op->message, rhs->message) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
}
void VisitStmt_(const ProducerConsumer* op, const Stmt& other) final {
......@@ -127,7 +128,6 @@ class IRDeepCompare :
if (CompareStmt(op->body, rhs->body) != 0) return;
}
void VisitStmt_(const Provide* op, const Stmt& other) final {
const Provide* rhs = other.as<Provide>();
if (CompareNodeRef(op->func, rhs->func) != 0) return;
......
......@@ -219,11 +219,14 @@ Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message);
Stmt body = this->Mutate(op->body);
if (condition.same_as(op->condition) && message.same_as(op->message)) {
if (condition.same_as(op->condition) &&
message.same_as(op->message) &&
body.same_as(op->body)) {
return s;
} else {
return AssertStmt::make(condition, message);
return AssertStmt::make(condition, message, body);
}
}
......
......@@ -34,7 +34,10 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
n->then_case = body;
body = Stmt(n);
} else if (s.as<AssertStmt>()) {
body = Block::make(s, body);
auto n = std::make_shared<AssertStmt>(*s.as<AssertStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<Allocate>()) {
auto n = std::make_shared<Allocate>(*s.as<Allocate>());
CHECK(is_no_op(n->body));
......
......@@ -162,6 +162,7 @@ void IRVisitor::Visit_(const Broadcast *op) {
void IRVisitor::Visit_(const AssertStmt *op) {
this->Visit(op->condition);
this->Visit(op->message);
this->Visit(op->body);
}
void IRVisitor::Visit_(const ProducerConsumer *op) {
......
......@@ -19,7 +19,7 @@ namespace tvm {
namespace ir {
inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
return AssertStmt::make(lhs == rhs, msg);
return AssertStmt::make(lhs == rhs, msg, Evaluate::make(0));
}
LoweredFunc MakeAPI(Stmt body,
......@@ -100,16 +100,16 @@ LoweredFunc MakeAPI(Stmt body,
seq_check.emplace_back(
AssertStmt::make(tcode == kHandle ||
tcode == kArrayHandle ||
tcode == kNull, msg.str()));
tcode == kNull, msg.str(), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << "Expect argument " << i << " to be int";
seq_check.emplace_back(AssertStmt::make(tcode == kInt, msg.str()));
seq_check.emplace_back(AssertStmt::make(tcode == kInt, msg.str(), nop));
} else {
CHECK(t.is_float());
std::ostringstream msg;
msg << "Expect argument " << i << " to be float";
seq_check.emplace_back(AssertStmt::make(tcode == kFloat, msg.str()));
seq_check.emplace_back(AssertStmt::make(tcode == kFloat, msg.str(), nop));
}
} else {
args.push_back(v_arg);
......
......@@ -16,10 +16,12 @@ def test_llvm_add_pipeline():
return
# Specifically allow offset to test codepath when offset is available
Ab = tvm.decl_buffer(
A.shape, A.dtype, elem_offset=tvm.var('Aoffset'),
A.shape, A.dtype,
elem_offset=tvm.var('Aoffset'),
offset_factor=8,
name='A')
binds = {A : Ab}
# build and invoke the kernel.
# BUILD and invoke the kernel.
f = tvm.build(s, [Ab, B, C], "llvm", binds=binds)
ctx = tvm.cpu(0)
# launch the kernel.
......@@ -31,7 +33,8 @@ def test_llvm_add_pipeline():
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
check_llvm()
with tvm.build_config(offset_factor=4):
check_llvm()
def test_llvm_flip_pipeline():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment