Commit 8bef595d by Yizhi Liu Committed by Tianqi Chen

[Bugfix] tvm.scan follow by tvm.compute segfault (#3723)

* [bugfix] tvm.scan follow by tvm.compute segfault

* more strict bound condition check

* access k + 1 -> k

* fix scan test
parent 9161efbc
......@@ -83,7 +83,7 @@ Operation ScanOpNode::make(std::string name,
<< "init.shape[0] need to match scan_axis.dom.min";
CHECK(prove_equal(
state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
<< "shate_placeholder.shape[0] need to match"
<< "state_placeholder.shape[0] need to match"
<< " scan_axis.dom.min + scan_axis.dom.extent";
CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
<< "The dimension of init need to match state_placeholder";
......@@ -242,7 +242,7 @@ void ScanOpNode::GatherBound(
CHECK(fix_pt.count(sp_ax));
if (fix_pt[sp_ax].as<ir::IntImm>()->value) {
// fix point, we can slice it.
(*out_dom_map)[sp_ax] = arith::Union(d.data[k + 1]).cover_range(sp_ax->dom);
(*out_dom_map)[sp_ax] = arith::Union(d.data[k]).cover_range(sp_ax->dom);
} else {
// not a fix point, need to include everything.
(*out_dom_map)[sp_ax] = sp_ax->dom;
......
......@@ -24,7 +24,9 @@ def test_scan():
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state)
scan = tvm.scan(s_init, s_update, s_state)
# test scan + compute case
res = tvm.compute((m, n), lambda i, j: scan[i, j])
# schedule
s = tvm.create_schedule(res.op)
......@@ -37,6 +39,9 @@ def test_scan():
xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_x)
s[s_update].bind(xi, thread_x)
xo, xi = s[res].split(res.op.axis[1], factor=num_thread)
s[res].bind(xo, block_x)
s[res].bind(xi, thread_x)
# one line to build the function.
def check_device(device):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment