Commit b55361b4 by Tianqi Chen (committed by GitHub)

[PASS] Allow compact checking when strides is available (#669)

* [PASS] Allow compact checking when strides is available

* remove assert compact
parent 70ccc8b6
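
Context for the change: ArgBinder::BindDLTensor previously required the strides field of an incoming DLTensor to be a null pointer ("expected to be nullptr for contiguous array") and rejected everything else. After this commit, a non-null strides array is also accepted, as long as it describes a compact (C-contiguous) layout: the innermost stride must be 1, and each outer stride must equal the running product of the inner extents. A minimal NumPy sketch of that condition (illustration only, not TVM code; is_compact is a hypothetical helper mirroring the emitted check):

import numpy as np

def is_compact(shape, strides):
    # strides are in elements; walk from the innermost dimension outward,
    # tracking the stride a compact layout would have at each step.
    expect = 1
    for k in reversed(range(len(shape))):
        if strides[k] != expect:
            return False
        expect *= shape[k]
    return True

a = np.empty((2, 3, 4))
elem_strides = [s // a.itemsize for s in a.strides]   # bytes -> elements
assert is_compact(a.shape, elem_strides)              # contiguous: passes
assert not is_compact((2, 3, 4), (24, 8, 2))          # strided view: fails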
@@ -362,6 +362,7 @@ void CodeGenStackVM::VisitExpr_(const Or *op) {
 }
 
 void CodeGenStackVM::VisitExpr_(const Not* op) {
+  this->Push(op->a);
   this->PushOp(StackVM::NOT);
 }
...
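
The first hunk is a separate StackVM fix: the Not visitor emitted the NOT opcode without first pushing its operand, so NOT consumed whatever value happened to be on the stack. A toy Python model of the stack discipline the fix restores (a sketch under simplified assumptions, not the real StackVM):

stack = []

def push(v):
    stack.append(v)

def push_op_not():
    stack.append(not stack.pop())   # NOT pops one operand, pushes the result

push(True)      # this->Push(op->a) -- the line this hunk adds
push_op_not()   # this->PushOp(StackVM::NOT)
assert stack == [False]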
@@ -136,12 +136,6 @@ inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
   return TVMStructGet(t, arr, 0, kind);
 }
 
-inline Stmt AssertNull(Var handle, std::string msg) {
-  return AssertStmt::make(Call::make(
-      Bool(1), intrinsic::tvm_handle_is_null,
-      {handle}, Call::PureIntrinsic), msg, Evaluate::make(0));
-}
-
 void ArgBinder::BindDLTensor(const Buffer& buffer,
                              const Expr& device_type,
                              const Expr& device_id,
@@ -201,10 +195,30 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       v_strides, TVMArrayGet(Handle(), handle, intrinsic::kArrStrides),
       nop));
   if (buffer->strides.size() == 0) {
+    // Assert the buffer is compact
+    Type stype = buffer->shape[0].type();
+    Expr expect_stride = make_const(stype, 1);
+    Array<Expr> conds;
+    for (size_t i = buffer->shape.size(); i != 0; --i) {
+      size_t k = i - 1;
+      Expr svalue = cast(
+          stype,
+          Load::make(tvm_shape_type, v_strides,
+                     IntImm::make(Int(32), k), const_true(1)));
+      conds.push_back(expect_stride == svalue);
+      expect_stride = expect_stride * buffer->shape[k];
+    }
     std::ostringstream stride_err_msg;
     stride_err_msg << arg_name << ".strides:"
-                   << " expected to be nullptr for contiguous array";
-    init_nest_.emplace_back(AssertNull(v_strides, stride_err_msg.str()));
+                   << " expected to be compact array";
+    Stmt check =
+        AssertStmt::make(arith::ComputeReduce<ir::And>(conds),
+                         stride_err_msg.str(), Evaluate::make(0));
+    Expr is_null = Call::make(
+        Bool(1), intrinsic::tvm_handle_is_null,
+        {v_strides}, Call::PureIntrinsic);
+    check = IfThenElse::make(Not::make(is_null), check, Stmt());
+    init_nest_.emplace_back(Block::make(check, Evaluate::make(0)));
   } else {
     for (size_t k = 0; k < buffer->strides.size(); ++k) {
       std::ostringstream field_name;
...
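
Two details of this hunk are worth spelling out. The per-dimension equalities are folded into a single predicate with arith::ComputeReduce<ir::And>, and the resulting AssertStmt is wrapped in an IfThenElse guarded by tvm_handle_is_null, so a null strides pointer (compact by convention) skips the check entirely. Roughly, in Python terms (a hypothetical rendering, not the generated code):

def strides_are_acceptable(shape, strides):
    # `strides` models v_strides; None models a null handle.
    if strides is None:
        return True                         # null strides: no check needed
    expect = 1
    conds = []
    for k in reversed(range(len(shape))):   # innermost dimension first
        conds.append(strides[k] == expect)
        expect *= shape[k]
    return all(conds)                       # ComputeReduce<ir::And> over conds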
@@ -33,6 +33,11 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
       CHECK(!n->else_case.defined());
       n->then_case = body;
       body = Stmt(n);
+    } else if (s.as<Block>()) {
+      auto n = std::make_shared<Block>(*s.as<Block>());
+      CHECK(is_no_op(n->rest));
+      n->rest = body;
+      body = Stmt(n);
     } else if (s.as<AssertStmt>()) {
       auto n = std::make_shared<AssertStmt>(*s.as<AssertStmt>());
       CHECK(is_no_op(n->body));
...
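
This MergeNest case exists to support the statement emitted above: the compact check is appended to init_nest_ as Block(check, Evaluate(0)), so MergeNest must now know how to splice the remaining initialization body into that no-op rest slot. A simplified Python model of the threading (a toy IR, not TVM's):

NOP = ()  # stands in for Evaluate(0)

class Block:
    def __init__(self, first, rest=NOP):
        self.first, self.rest = first, rest

def merge_nest(nest, body):
    # Fold from the innermost nest element outward, threading `body`
    # through each element's no-op hole.
    for s in reversed(nest):
        if isinstance(s, Block):
            assert s.rest is NOP          # CHECK(is_no_op(n->rest))
            body = Block(s.first, body)   # n->rest = body
        # AssertStmt, IfThenElse, For, ... are threaded analogously
    return body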
@@ -225,7 +225,8 @@ def _im2col_pack(data, kernel, stride, padding, out_dtype):
     wk = tvm.reduce_axis((0, KW), name='wk')
     conv = tvm.compute(ovshape, lambda n, co, im, vim, vco: \
-        tvm.sum(data_vec[n][im][ci][hk][wk][vim] * kernel_vec[co][ci][hk][wk][vco],
+        tvm.sum(data_vec[n][im][ci][hk][wk][vim].astype(out_dtype) *
+                kernel_vec[co][ci][hk][wk][vco].astype(out_dtype),
                 axis=[ci, hk, wk]), name='conv')
     output = tvm.compute(oshape, lambda n, co, h, w: \
...
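
Finally, the topi hunk casts both operands to out_dtype before the multiplication, so the reduction accumulates in the output type rather than the possibly narrower input type. A NumPy illustration of the failure mode this avoids (assuming int8 data with an int32 accumulator):

import numpy as np

a = np.full(64, 100, dtype=np.int8)
b = np.full(64, 100, dtype=np.int8)
bad = int(np.sum(a * b))             # int8 products wrap: 100*100 -> 16
good = int(np.sum(a.astype(np.int32) * b.astype(np.int32)))
assert good == 64 * 100 * 100        # 640000, exact
assert bad != good                   # bad is 1024 on this input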