Commit 7e68d63f by Salem Derisavi Committed by ziheng

1) fixed a functional bug in loop partitioning algorithm that is exposed when…

1) fixed a functional bug in loop partitioning algorithm that is exposed when double splitting with indivisible factors 2) added a testcase (#2956)
parent 8b5b180a
...@@ -15,12 +15,21 @@ ...@@ -15,12 +15,21 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import tvm import tvm
import numpy
def collect_visit(stmt, f): def collect_visit(stmt, f):
ret = [] ret = []
tvm.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x))) tvm.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x)))
return ret return ret
def find_top_produce(stmt):
def f(x, ret):
if isinstance(x, tvm.stmt.ProducerConsumer):
ret.append(x)
ret = []
tvm.ir_pass.PostOrderVisit(stmt, lambda x : f(x, ret))
return ret[-1]
def lower(sch, args): def lower(sch, args):
binds = {} binds = {}
arg_list = [] arg_list = []
...@@ -344,6 +353,37 @@ def test_conv_tiling(): ...@@ -344,6 +353,37 @@ def test_conv_tiling():
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_double_splitting_with_indivisible_factors():
m = 48
dtype="float32"
A = tvm.placeholder((m,), name='A', dtype=dtype)
C = tvm.compute((m,), lambda i: A[i], name='C')
D = tvm.compute((m,), lambda i: C[i], name='D')
s = tvm.create_schedule(D.op)
co, ci = s[C].split(C.op.axis[0], factor=10)
do, di = s[D].split(D.op.axis[0], 32)
s[C].compute_at(s[D], do)
target = 'llvm'
with tvm.build_config(partition_const_loop=True):
f = tvm.lower(s, [A, C, D], name="fadd1", simple_mode=False)
func = tvm.build(f, target=target)
# Find the beginning of the Halide IR corresponding to kernel code
# and make sure it doesn't have an if statements left
top_produce = find_top_produce(f.body)
assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
# check functional correctness of generated code
ctx = tvm.context(target, 0)
a = tvm.nd.array(numpy.ones(m,).astype(dtype), ctx)
c = tvm.nd.array(numpy.zeros(m,).astype(dtype), ctx)
d = tvm.nd.array(numpy.zeros(m,).astype(dtype), ctx)
func(a, c, d)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy(), rtol=1e-5)
tvm.testing.assert_allclose(d.asnumpy(), a.asnumpy(), rtol=1e-5)
if __name__ == "__main__": if __name__ == "__main__":
test_basic() test_basic()
test_const_loop() test_const_loop()
...@@ -361,3 +401,4 @@ if __name__ == "__main__": ...@@ -361,3 +401,4 @@ if __name__ == "__main__":
test_cce_loop_2() test_cce_loop_2()
test_cce_loop_3() test_cce_loop_3()
test_conv_tiling() test_conv_tiling()
test_double_splitting_with_indivisible_factors()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment