Unverified Commit f9d8d063 by Wuwei Lin Committed by GitHub

Fix ArgBinder assert order (#3794)

parent b76b627b
...@@ -239,7 +239,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, ...@@ -239,7 +239,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
AssertStmt::make(arith::ComputeReduce<ir::And>(conds, Expr()), AssertStmt::make(arith::ComputeReduce<ir::And>(conds, Expr()),
stride_err_msg.str(), Evaluate::make(0)); stride_err_msg.str(), Evaluate::make(0));
check = IfThenElse::make(Not::make(is_null), check, Stmt()); check = IfThenElse::make(Not::make(is_null), check, Stmt());
init_nest_.emplace_back(Block::make(check, Evaluate::make(0))); asserts_.emplace_back(Block::make(check, Evaluate::make(0)));
} }
} else if (buffer->buffer_type == kAutoBroadcast) { } else if (buffer->buffer_type == kAutoBroadcast) {
Type stype = buffer->DefaultIndexType(); Type stype = buffer->DefaultIndexType();
......
...@@ -32,5 +32,13 @@ def test_lower_rfactor(): ...@@ -32,5 +32,13 @@ def test_lower_rfactor():
s[BF].compute_at(s[B], s[B].op.reduce_axis[0]) s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
fapi = tvm.lower(s, [A, B]) fapi = tvm.lower(s, [A, B])
def test_dependent_output_shape():
n, m, x = tvm.var('n'), tvm.var('m'), tvm.var('x')
A = tvm.placeholder((n, m))
B = tvm.compute((m, n/x), lambda i, j: A[i,j] , name='B')
s = tvm.create_schedule(B.op)
mod = tvm.build(s, [A, B, x])
if __name__ == "__main__": if __name__ == "__main__":
test_lower_rfactor() test_lower_rfactor()
test_dependent_output_shape()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment