Commit 00506a62 by Tianqi Chen Committed by GitHub

[IR] Add body to AssertStmt (#220)

* [IR] Add body to AssertStmt

* fix lint
parent c9da7254
Subproject commit 860199eea031a4ea694b8fce03ad0bf8127910ac Subproject commit 36ecc1eec0898411ae70e98c315b03247d5fb4a0
...@@ -119,9 +119,9 @@ class IntSet : public NodeRef { ...@@ -119,9 +119,9 @@ class IntSet : public NodeRef {
*/ */
struct ModularEntry { struct ModularEntry {
/*! \brief The base */ /*! \brief The base */
int base; int base{0};
/*! \brief linear co-efficient */ /*! \brief linear co-efficient */
int coeff; int coeff{1};
/*! \return entry represent everything */ /*! \return entry represent everything */
static ModularEntry everything() { static ModularEntry everything() {
......
...@@ -68,7 +68,10 @@ class Buffer : public NodeRef { ...@@ -68,7 +68,10 @@ class Buffer : public NodeRef {
class BufferNode : public Node { class BufferNode : public Node {
public: public:
// Data fields. // Data fields.
/*! \brief The pointer to the head of the data */ /*!
* \brief The pointer to the head of the data
* \sa data_alignment The alignment of data in bytes.
*/
Var data; Var data;
/*! \brief data type in the content of the tensor */ /*! \brief data type in the content of the tensor */
Type dtype; Type dtype;
...@@ -86,8 +89,13 @@ class BufferNode : public Node { ...@@ -86,8 +89,13 @@ class BufferNode : public Node {
std::string name; std::string name;
/*! \brief storage scope of the buffer, if other than global */ /*! \brief storage scope of the buffer, if other than global */
std::string scope; std::string scope;
/*! \brief Alignment multiple in terms of dtype elements (including lanes) */ /*! \brief Alignment requirement of data pointer in bytes. */
int offset_alignment; int data_alignment;
/*!
* \brief Factor of elem_offset field,
* elem_offset is guaranteed to be multiple of offset_factor.
*/
int offset_factor;
/*! \brief constructor */ /*! \brief constructor */
BufferNode() {} BufferNode() {}
...@@ -99,9 +107,12 @@ class BufferNode : public Node { ...@@ -99,9 +107,12 @@ class BufferNode : public Node {
v->Visit("elem_offset", &elem_offset); v->Visit("elem_offset", &elem_offset);
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("scope", &scope); v->Visit("scope", &scope);
v->Visit("offset_alignment", &offset_alignment); v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
} }
// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
static Buffer make(Var ptr, static Buffer make(Var ptr,
Type dtype, Type dtype,
Array<Expr> shape, Array<Expr> shape,
...@@ -109,7 +120,8 @@ class BufferNode : public Node { ...@@ -109,7 +120,8 @@ class BufferNode : public Node {
Expr byte_offset, Expr byte_offset,
std::string name, std::string name,
std::string scope, std::string scope,
int offset_alignment); int data_alignment,
int offset_factor);
static constexpr const char* _type_key = "Buffer"; static constexpr const char* _type_key = "Buffer";
TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node); TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
......
...@@ -135,7 +135,7 @@ struct TensorKey { ...@@ -135,7 +135,7 @@ struct TensorKey {
} }
}; };
/*! \brief namespace of possible attribute sin AttrStmt.type_key */ /*! \brief namespace of possible attribute sin AttrStmt.attr_key */
namespace attr { namespace attr {
// The above attr does not pass to ir stage. // The above attr does not pass to ir stage.
/*! \brief Mark launching extent of thread, used by device API. */ /*! \brief Mark launching extent of thread, used by device API. */
......
...@@ -390,7 +390,8 @@ def decl_buffer(shape, ...@@ -390,7 +390,8 @@ def decl_buffer(shape,
strides=None, strides=None,
elem_offset=None, elem_offset=None,
scope="", scope="",
offset_alignment=0): data_alignment=0,
offset_factor=0):
"""Decleare a new symbolic buffer. """Decleare a new symbolic buffer.
Normally buffer is created automatically during lower and build. Normally buffer is created automatically during lower and build.
...@@ -423,8 +424,15 @@ def decl_buffer(shape, ...@@ -423,8 +424,15 @@ def decl_buffer(shape,
The storage scope of the buffer, if not global. The storage scope of the buffer, if not global.
If scope equals empty string, it means it is global memory. If scope equals empty string, it means it is global memory.
offset_alignment: int, optional data_alignment: int, optional
The alignment of offset The alignment of data pointer in bytes.
If 0 is passed, the alignment will be set to TVM's internal default.
offset_factor: int, optional
The factor of elem_offset field, when set,
elem_offset is required to be multiple of offset_factor.
If 0 is pssed, the alignment will be set to 1.
if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None.
Returns Returns
------- -------
...@@ -447,11 +455,14 @@ def decl_buffer(shape, ...@@ -447,11 +455,14 @@ def decl_buffer(shape,
shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
dtype = float32 if dtype is None else dtype dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides strides = () if strides is None else strides
if offset_factor != 0 and elem_offset is None:
elem_offset = var('%s_elem_offset' % name, shape[0].dtype)
if data is None: if data is None:
data = var(name, "handle") data = var(name, "handle")
return _api_internal._Buffer( return _api_internal._Buffer(
data, dtype, shape, strides, elem_offset, name, scope, offset_alignment) data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor)
def _IterVar(dom, name, iter_type, thread_tag=''): def _IterVar(dom, name, iter_type, thread_tag=''):
......
...@@ -26,7 +26,8 @@ class BuildConfig(object): ...@@ -26,7 +26,8 @@ class BuildConfig(object):
'auto_unroll_max_step': 0, 'auto_unroll_max_step': 0,
'auto_unroll_min_depth': 1, 'auto_unroll_min_depth': 1,
'unroll_explicit': True, 'unroll_explicit': True,
'detect_global_barrier': False 'detect_global_barrier': False,
'offset_factor': 0
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._old_scope = None self._old_scope = None
...@@ -76,6 +77,10 @@ def build_config(**kwargs): ...@@ -76,6 +77,10 @@ def build_config(**kwargs):
detect_global_barrier: bool, default=True detect_global_barrier: bool, default=True
Whether detect global barrier. Whether detect global barrier.
offset_factor: int, default=0
The factor used in default buffer declaration.
If specified as 0, offset field is not used.
Returns Returns
------- -------
config: BuildConfig config: BuildConfig
...@@ -105,10 +110,12 @@ def get_binds(args, binds=None): ...@@ -105,10 +110,12 @@ def get_binds(args, binds=None):
The list of symbolic buffers of arguments. The list of symbolic buffers of arguments.
""" """
binds = {} if binds is None else binds.copy() binds = {} if binds is None else binds.copy()
offset_factor = BuildConfig.current.offset_factor
arg_list = [] arg_list = []
for x in args: for x in args:
if isinstance(x, tensor.Tensor): if isinstance(x, tensor.Tensor):
buf = api.decl_buffer(x.shape, dtype=x.dtype, name=x.name) buf = api.decl_buffer(x.shape, dtype=x.dtype, name=x.name,
offset_factor=offset_factor)
assert x not in binds assert x not in binds
binds[x] = buf binds[x] = buf
arg_list.append(buf) arg_list.append(buf)
......
...@@ -143,7 +143,7 @@ REGISTER_MAKE2(Cast); ...@@ -143,7 +143,7 @@ REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast); REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let); REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt); REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt); REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer); REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate); REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide); REGISTER_MAKE4(Provide);
......
...@@ -152,7 +152,8 @@ TVM_REGISTER_API("_Buffer") ...@@ -152,7 +152,8 @@ TVM_REGISTER_API("_Buffer")
args[4], args[4],
args[5], args[5],
args[6], args[6],
args[7]); args[7],
args[8]);
}); });
TVM_REGISTER_API("_Tensor") TVM_REGISTER_API("_Tensor")
......
...@@ -724,6 +724,7 @@ void CodeGenC::VisitStmt_(const AssertStmt* op) { ...@@ -724,6 +724,7 @@ void CodeGenC::VisitStmt_(const AssertStmt* op) {
} else { } else {
stream << "assert(" << cond << ");\n"; stream << "assert(" << cond << ");\n";
} }
this->PrintStmt(op->body);
} }
void CodeGenC::VisitStmt_(const For* op) { void CodeGenC::VisitStmt_(const For* op) {
......
...@@ -1377,6 +1377,31 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) { ...@@ -1377,6 +1377,31 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
builder_->CreateRet(llvm::ConstantInt::getSigned(t_int32_, -1)); builder_->CreateRet(llvm::ConstantInt::getSigned(t_int32_, -1));
// otherwise set it to be new end. // otherwise set it to be new end.
builder_->SetInsertPoint(end_block); builder_->SetInsertPoint(end_block);
// Detect useful invariant pattern and use them to visit child.
// Pattern: Var % const == 0
// TODO(tqchen) move these pattern to a generic scope info visitor.
if (const EQ* eq = op->condition.as<EQ>()) {
const Mod* mod = eq->a.as<Mod>();
int64_t factor, offset;
if (mod && arith::GetConst(eq->b, &offset)) {
const Variable *var = mod->a.as<Variable>();
if (var && arith::GetConst(mod->b, &factor)) {
arith::ModularEntry old = align_map_[var];
if (factor > old.coeff) {
arith::ModularEntry e;
e.coeff = static_cast<int>(factor);
e.base = static_cast<int>(offset);
// new alignment info,
align_map_[var] = e;
this->VisitStmt(op->body);
// restore old info
align_map_[var] = old;
return;
}
}
}
}
this->VisitStmt(op->body);
} }
void CodeGenLLVM::VisitStmt_(const LetStmt* op) { void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
......
...@@ -456,6 +456,7 @@ void CodeGenStackVM::VisitStmt_(const AssertStmt *op) { ...@@ -456,6 +456,7 @@ void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
this->Push(op->condition); this->Push(op->condition);
this->PushOp(StackVM::ASSERT, sid); this->PushOp(StackVM::ASSERT, sid);
} }
this->Push(op->body);
} }
void CodeGenStackVM::VisitStmt_(const AttrStmt *op) { void CodeGenStackVM::VisitStmt_(const AttrStmt *op) {
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \file buffer.cc * \file buffer.cc
*/ */
#include <tvm/buffer.h> #include <tvm/buffer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
...@@ -26,7 +27,9 @@ Buffer decl_buffer(Array<Expr> shape, ...@@ -26,7 +27,9 @@ Buffer decl_buffer(Array<Expr> shape,
shape, shape,
Array<Expr>(), Array<Expr>(),
Expr(), Expr(),
name, "", 0); name,
"",
0, 0);
} }
// The buffer offset in convention of number of elements of // The buffer offset in convention of number of elements of
...@@ -124,6 +127,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const { ...@@ -124,6 +127,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
elem_offset, elem_offset,
n->name + "_slice", n->name + "_slice",
n->scope, n->scope,
n->data_alignment,
0); 0);
} }
...@@ -134,7 +138,8 @@ Buffer BufferNode::make(Var data, ...@@ -134,7 +138,8 @@ Buffer BufferNode::make(Var data,
Expr elem_offset, Expr elem_offset,
std::string name, std::string name,
std::string scope, std::string scope,
int offset_alignment) { int data_alignment,
int offset_factor) {
auto n = std::make_shared<BufferNode>(); auto n = std::make_shared<BufferNode>();
n->data = std::move(data); n->data = std::move(data);
n->dtype = dtype; n->dtype = dtype;
...@@ -145,11 +150,15 @@ Buffer BufferNode::make(Var data, ...@@ -145,11 +150,15 @@ Buffer BufferNode::make(Var data,
if (!elem_offset.defined()) { if (!elem_offset.defined()) {
elem_offset = make_const(n->shape[0].type(), 0); elem_offset = make_const(n->shape[0].type(), 0);
} }
if (offset_alignment == 0) { if (data_alignment == 0) {
offset_alignment = 1; data_alignment = runtime::kAllocAlignment;
}
if (offset_factor == 0) {
offset_factor = 1;
} }
n->elem_offset = elem_offset; n->elem_offset = elem_offset;
n->offset_alignment = offset_alignment; n->data_alignment = data_alignment;
n->offset_factor = offset_factor;
return Buffer(n); return Buffer(n);
} }
......
...@@ -24,7 +24,7 @@ void BinderAddAssert(Expr cond, ...@@ -24,7 +24,7 @@ void BinderAddAssert(Expr cond,
if (!is_one(cond)) { if (!is_one(cond)) {
std::ostringstream os; std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint"; os << "Argument " << arg_name << " has an unsatisfied constraint";
asserts->emplace_back(AssertStmt::make(cond, os.str())); asserts->emplace_back(AssertStmt::make(cond, os.str(), Evaluate::make(0)));
} }
} }
...@@ -107,7 +107,14 @@ void ArgBinder::BindBuffer(const Buffer& arg, ...@@ -107,7 +107,14 @@ void ArgBinder::BindBuffer(const Buffer& arg,
this->BindArray(arg->shape, value->shape, arg_name + ".shape"); this->BindArray(arg->shape, value->shape, arg_name + ".shape");
this->BindArray(arg->strides, value->strides, arg_name + ".strides"); this->BindArray(arg->strides, value->strides, arg_name + ".strides");
} }
this->Bind(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset"); if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
if (arg->offset_factor > 1) {
Expr offset = value->elem_offset;
Expr factor = make_const(offset.type(), arg->offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
}
}
} }
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) { inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
...@@ -117,7 +124,7 @@ inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) { ...@@ -117,7 +124,7 @@ inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
inline Stmt AssertNull(Var handle, std::string msg) { inline Stmt AssertNull(Var handle, std::string msg) {
return AssertStmt::make(Call::make( return AssertStmt::make(Call::make(
Bool(1), intrinsic::tvm_handle_is_null, Bool(1), intrinsic::tvm_handle_is_null,
{handle}, Call::PureIntrinsic), msg); {handle}, Call::PureIntrinsic), msg, Evaluate::make(0));
} }
void ArgBinder::BindDLTensor(const Buffer& buffer, void ArgBinder::BindDLTensor(const Buffer& buffer,
...@@ -136,7 +143,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, ...@@ -136,7 +143,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
ndim_err_msg << arg_name ndim_err_msg << arg_name
<< ".ndim is expected to equal " << ".ndim is expected to equal "
<< buffer->shape.size(); << buffer->shape.size();
asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str())); asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
// type checks // type checks
Type dtype = buffer->dtype; Type dtype = buffer->dtype;
std::ostringstream type_err_msg; std::ostringstream type_err_msg;
...@@ -147,7 +154,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, ...@@ -147,7 +154,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
UIntImm::make(UInt(8), dtype.bits()) && UIntImm::make(UInt(8), dtype.bits()) &&
TVMArrayGet(UInt(16), handle, intrinsic::kArrTypeLanes) == TVMArrayGet(UInt(16), handle, intrinsic::kArrTypeLanes) ==
UIntImm::make(UInt(16), dtype.lanes())); UIntImm::make(UInt(16), dtype.lanes()));
asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str())); asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str(), nop));
// data field // data field
if (Bind_(buffer->data, TVMArrayGet(Handle(), handle, intrinsic::kArrData), if (Bind_(buffer->data, TVMArrayGet(Handle(), handle, intrinsic::kArrData),
arg_name + ".data", true)) { arg_name + ".data", true)) {
...@@ -156,7 +163,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, ...@@ -156,7 +163,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
// mark alignment of external bufs // mark alignment of external bufs
init_nest_.emplace_back(AttrStmt::make( init_nest_.emplace_back(AttrStmt::make(
vptr, ir::attr::storage_alignment, vptr, ir::attr::storage_alignment,
IntImm::make(Int(32), runtime::kAllocAlignment), nop)); IntImm::make(Int(32), buffer->data_alignment), nop));
} }
Var v_shape(arg_name + ".shape", Handle()); Var v_shape(arg_name + ".shape", Handle());
...@@ -202,11 +209,18 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, ...@@ -202,11 +209,18 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset), TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset),
arg_name + ".byte_offset", true); arg_name + ".byte_offset", true);
} else { } else {
Bind_(buffer->elem_offset, if (Bind_(buffer->elem_offset,
cast(buffer->elem_offset.type(), cast(buffer->elem_offset.type(),
(TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset) / (TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset) /
make_const(UInt(64), data_bytes))), make_const(UInt(64), data_bytes))),
arg_name + ".elem_offset", true); arg_name + ".elem_offset", true)) {
if (buffer->offset_factor > 1) {
Expr offset = buffer->elem_offset;
Expr factor = make_const(offset.type(), buffer->offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
}
}
} }
// device info. // device info.
Bind_(device_type, Bind_(device_type,
......
...@@ -118,6 +118,7 @@ class IRDeepCompare : ...@@ -118,6 +118,7 @@ class IRDeepCompare :
const AssertStmt* rhs = other.as<AssertStmt>(); const AssertStmt* rhs = other.as<AssertStmt>();
if (CompareExpr(op->condition, rhs->condition) != 0) return; if (CompareExpr(op->condition, rhs->condition) != 0) return;
if (CompareExpr(op->message, rhs->message) != 0) return; if (CompareExpr(op->message, rhs->message) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
} }
void VisitStmt_(const ProducerConsumer* op, const Stmt& other) final { void VisitStmt_(const ProducerConsumer* op, const Stmt& other) final {
...@@ -127,7 +128,6 @@ class IRDeepCompare : ...@@ -127,7 +128,6 @@ class IRDeepCompare :
if (CompareStmt(op->body, rhs->body) != 0) return; if (CompareStmt(op->body, rhs->body) != 0) return;
} }
void VisitStmt_(const Provide* op, const Stmt& other) final { void VisitStmt_(const Provide* op, const Stmt& other) final {
const Provide* rhs = other.as<Provide>(); const Provide* rhs = other.as<Provide>();
if (CompareNodeRef(op->func, rhs->func) != 0) return; if (CompareNodeRef(op->func, rhs->func) != 0) return;
......
...@@ -219,11 +219,14 @@ Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) { ...@@ -219,11 +219,14 @@ Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) { Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) {
Expr condition = this->Mutate(op->condition); Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message); Expr message = this->Mutate(op->message);
Stmt body = this->Mutate(op->body);
if (condition.same_as(op->condition) && message.same_as(op->message)) { if (condition.same_as(op->condition) &&
message.same_as(op->message) &&
body.same_as(op->body)) {
return s; return s;
} else { } else {
return AssertStmt::make(condition, message); return AssertStmt::make(condition, message, body);
} }
} }
......
...@@ -34,7 +34,10 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) { ...@@ -34,7 +34,10 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
n->then_case = body; n->then_case = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<AssertStmt>()) { } else if (s.as<AssertStmt>()) {
body = Block::make(s, body); auto n = std::make_shared<AssertStmt>(*s.as<AssertStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<Allocate>()) { } else if (s.as<Allocate>()) {
auto n = std::make_shared<Allocate>(*s.as<Allocate>()); auto n = std::make_shared<Allocate>(*s.as<Allocate>());
CHECK(is_no_op(n->body)); CHECK(is_no_op(n->body));
......
...@@ -162,6 +162,7 @@ void IRVisitor::Visit_(const Broadcast *op) { ...@@ -162,6 +162,7 @@ void IRVisitor::Visit_(const Broadcast *op) {
void IRVisitor::Visit_(const AssertStmt *op) { void IRVisitor::Visit_(const AssertStmt *op) {
this->Visit(op->condition); this->Visit(op->condition);
this->Visit(op->message); this->Visit(op->message);
this->Visit(op->body);
} }
void IRVisitor::Visit_(const ProducerConsumer *op) { void IRVisitor::Visit_(const ProducerConsumer *op) {
......
...@@ -19,7 +19,7 @@ namespace tvm { ...@@ -19,7 +19,7 @@ namespace tvm {
namespace ir { namespace ir {
inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) { inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
return AssertStmt::make(lhs == rhs, msg); return AssertStmt::make(lhs == rhs, msg, Evaluate::make(0));
} }
LoweredFunc MakeAPI(Stmt body, LoweredFunc MakeAPI(Stmt body,
...@@ -100,16 +100,16 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -100,16 +100,16 @@ LoweredFunc MakeAPI(Stmt body,
seq_check.emplace_back( seq_check.emplace_back(
AssertStmt::make(tcode == kHandle || AssertStmt::make(tcode == kHandle ||
tcode == kArrayHandle || tcode == kArrayHandle ||
tcode == kNull, msg.str())); tcode == kNull, msg.str(), nop));
} else if (t.is_int() || t.is_uint()) { } else if (t.is_int() || t.is_uint()) {
std::ostringstream msg; std::ostringstream msg;
msg << "Expect argument " << i << " to be int"; msg << "Expect argument " << i << " to be int";
seq_check.emplace_back(AssertStmt::make(tcode == kInt, msg.str())); seq_check.emplace_back(AssertStmt::make(tcode == kInt, msg.str(), nop));
} else { } else {
CHECK(t.is_float()); CHECK(t.is_float());
std::ostringstream msg; std::ostringstream msg;
msg << "Expect argument " << i << " to be float"; msg << "Expect argument " << i << " to be float";
seq_check.emplace_back(AssertStmt::make(tcode == kFloat, msg.str())); seq_check.emplace_back(AssertStmt::make(tcode == kFloat, msg.str(), nop));
} }
} else { } else {
args.push_back(v_arg); args.push_back(v_arg);
......
...@@ -16,10 +16,12 @@ def test_llvm_add_pipeline(): ...@@ -16,10 +16,12 @@ def test_llvm_add_pipeline():
return return
# Specifically allow offset to test codepath when offset is available # Specifically allow offset to test codepath when offset is available
Ab = tvm.decl_buffer( Ab = tvm.decl_buffer(
A.shape, A.dtype, elem_offset=tvm.var('Aoffset'), A.shape, A.dtype,
elem_offset=tvm.var('Aoffset'),
offset_factor=8,
name='A') name='A')
binds = {A : Ab} binds = {A : Ab}
# build and invoke the kernel. # BUILD and invoke the kernel.
f = tvm.build(s, [Ab, B, C], "llvm", binds=binds) f = tvm.build(s, [Ab, B, C], "llvm", binds=binds)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
# launch the kernel. # launch the kernel.
...@@ -31,7 +33,8 @@ def test_llvm_add_pipeline(): ...@@ -31,7 +33,8 @@ def test_llvm_add_pipeline():
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy()) c.asnumpy(), a.asnumpy() + b.asnumpy())
check_llvm() with tvm.build_config(offset_factor=4):
check_llvm()
def test_llvm_flip_pipeline(): def test_llvm_flip_pipeline():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment