Commit 8bef595d by Yizhi Liu Committed by Tianqi Chen

[Bugfix] tvm.scan follow by tvm.compute segfault (#3723)

* [bugfix] tvm.scan follow by tvm.compute segfault

* more strict bound condition check

* access k + 1 -> k

* fix scan test
parent 9161efbc
...@@ -83,7 +83,7 @@ Operation ScanOpNode::make(std::string name, ...@@ -83,7 +83,7 @@ Operation ScanOpNode::make(std::string name,
<< "init.shape[0] need to match scan_axis.dom.min"; << "init.shape[0] need to match scan_axis.dom.min";
CHECK(prove_equal( CHECK(prove_equal(
state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
<< "shate_placeholder.shape[0] need to match" << "state_placeholder.shape[0] need to match"
<< " scan_axis.dom.min + scan_axis.dom.extent"; << " scan_axis.dom.min + scan_axis.dom.extent";
CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim()) CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
<< "The dimension of init need to match state_placeholder"; << "The dimension of init need to match state_placeholder";
...@@ -242,7 +242,7 @@ void ScanOpNode::GatherBound( ...@@ -242,7 +242,7 @@ void ScanOpNode::GatherBound(
CHECK(fix_pt.count(sp_ax)); CHECK(fix_pt.count(sp_ax));
if (fix_pt[sp_ax].as<ir::IntImm>()->value) { if (fix_pt[sp_ax].as<ir::IntImm>()->value) {
// fix point, we can slice it. // fix point, we can slice it.
(*out_dom_map)[sp_ax] = arith::Union(d.data[k + 1]).cover_range(sp_ax->dom); (*out_dom_map)[sp_ax] = arith::Union(d.data[k]).cover_range(sp_ax->dom);
} else { } else {
// not a fix point, need to include everything. // not a fix point, need to include everything.
(*out_dom_map)[sp_ax] = sp_ax->dom; (*out_dom_map)[sp_ax] = sp_ax->dom;
......
...@@ -24,7 +24,9 @@ def test_scan(): ...@@ -24,7 +24,9 @@ def test_scan():
s_state = tvm.placeholder((m, n)) s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i]) s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i]) s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state) scan = tvm.scan(s_init, s_update, s_state)
# test scan + compute case
res = tvm.compute((m, n), lambda i, j: scan[i, j])
# schedule # schedule
s = tvm.create_schedule(res.op) s = tvm.create_schedule(res.op)
...@@ -37,6 +39,9 @@ def test_scan(): ...@@ -37,6 +39,9 @@ def test_scan():
xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread) xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_x) s[s_update].bind(xo, block_x)
s[s_update].bind(xi, thread_x) s[s_update].bind(xi, thread_x)
xo, xi = s[res].split(res.op.axis[1], factor=num_thread)
s[res].bind(xo, block_x)
s[res].bind(xi, thread_x)
# one line to build the function. # one line to build the function.
def check_device(device): def check_device(device):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment