Commit f0ae8e31 by Sergei Grechanik Committed by ziheng

[TVM][BUGFIX] Fix missing reduction init predicates (#2495)

* [TVM][BUGFIX] Fix reductions in split axes

* A test case for the problem

* Fix the fix: skip loops that are related to reduction AND are unrelated to axis
parent 389fbb5c
...@@ -457,11 +457,11 @@ ComputeLoopNest ComputeLoopNest::make( ...@@ -457,11 +457,11 @@ ComputeLoopNest ComputeLoopNest::make(
ret.init_vmap[iv] = ret.main_vmap.at(iv); ret.init_vmap[iv] = ret.main_vmap.at(iv);
} }
ret.num_common_loop = begin_loop; ret.num_common_loop = begin_loop;
// skip loops that does not relates to axis. // skip loops that are related to reduction and are unrelated to axis.
std::unordered_set<IterVar> skip_iter; std::unordered_set<IterVar> skip_iter;
for (auto kv : update_state) { for (auto kv : update_state) {
int flag = kv.second; int flag = kv.second;
if ((flag & 1) == 0) skip_iter.insert(kv.first); if (flag == 2) skip_iter.insert(kv.first);
} }
ret.init_nest = op::MakeLoopNest( ret.init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true, stage, dom_map, begin_loop, true,
......
...@@ -215,11 +215,11 @@ ComputeLoopNest MakeLoopNest( ...@@ -215,11 +215,11 @@ ComputeLoopNest MakeLoopNest(
ret.init_vmap[iv] = ret.main_vmap.at(iv); ret.init_vmap[iv] = ret.main_vmap.at(iv);
} }
ret.num_common_loop = begin_loop; ret.num_common_loop = begin_loop;
// skip loops that does not relates to axis. // skip loops that are related to reduction and are unrelated to axis.
std::unordered_set<IterVar> skip_iter; std::unordered_set<IterVar> skip_iter;
for (auto kv : update_state) { for (auto kv : update_state) {
int flag = kv.second; int flag = kv.second;
if ((flag & 1) == 0) skip_iter.insert(kv.first); if (flag == 2) skip_iter.insert(kv.first);
} }
ret.init_nest = op::MakeLoopNest( ret.init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true, stage, dom_map, begin_loop, true,
......
import tvm import tvm
import numpy as np
def test_schedule0(): def test_schedule0():
m = tvm.var('m') m = tvm.var('m')
...@@ -432,6 +432,32 @@ def test_loop_dep_reduce_cache_write(): ...@@ -432,6 +432,32 @@ def test_loop_dep_reduce_cache_write():
s.cache_write(Y, 'local') s.cache_write(Y, 'local')
f = tvm.build(s, [X, Y]) f = tvm.build(s, [X, Y])
def test_reduction_and_dummy_fuse_split():
n = 10
X = tvm.placeholder(shape=(n,), dtype='int32', name="X")
k = tvm.reduce_axis((0, n))
Y = tvm.compute((), lambda: tvm.sum(X[k], k), name="Y")
s = tvm.create_schedule([Y.op])
ax = s[Y.op].fuse(*Y.op.axis)
axo, axi = s[Y.op].split(ax, nparts=20)
f = tvm.build(s, [Y, X])
args = [tvm.nd.empty((), 'int32')] + [tvm.ndarray.array(np.ones((n,), dtype='int32'))]
f(*args)
assert args[0].asnumpy() == n
n = 10
X = tvm.placeholder(shape=(n,), dtype='int32', name="X")
k = tvm.reduce_axis((0, n))
Y = tvm.compute((n,), lambda i: tvm.sum(X[k], k), name="Y")
s = tvm.create_schedule([Y.op])
ax = s[Y.op].fuse(*(list(Y.op.axis) + list(Y.op.reduce_axis)))
f = tvm.build(s, [Y, X])
args = [tvm.ndarray.array(np.ones((n,), dtype='int32'))] + \
[tvm.ndarray.array(np.ones((n,), dtype='int32'))]
f(*args)
assert np.all(args[0].asnumpy() == n)
if __name__ == "__main__": if __name__ == "__main__":
test_loop_dep_reduce() test_loop_dep_reduce()
...@@ -456,3 +482,4 @@ if __name__ == "__main__": ...@@ -456,3 +482,4 @@ if __name__ == "__main__":
test_schedule_tensor_compute1() test_schedule_tensor_compute1()
test_schedule_tensor_compute2() test_schedule_tensor_compute2()
test_schedule_tensor_compute3() test_schedule_tensor_compute3()
test_reduction_and_dummy_fuse_split()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment