Commit 88338826 by Tianqi Chen Committed by GitHub

[BUGFIX/TESTS] Bugfix of Tenso slicing. Union. (#66)

parent 3fb85796
......@@ -125,10 +125,13 @@ def compute(shape, fcompute, name="compute"):
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
ndim = len(shape)
arg_names = fcompute.__code__.co_varnames
code = fcompute.__code__
if fcompute.__code__.co_argcount == 0 and len(arg_names) == 1:
if fcompute.__code__.co_argcount == 0:
arg_names = ["i%d" % i for i in range(ndim)]
else:
arg_names = code.co_varnames[:code.co_argcount]
if ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
......
......@@ -16,6 +16,8 @@ class TensorSlice(SliceBase, _expr.ExprOp):
self.indices = indices
def __getitem__(self, indices):
if not isinstance(indices, tuple):
indices = (indices,)
return TensorSlice(self.tensor, self.indices + indices)
......
......@@ -168,14 +168,10 @@ IntSet Union(const Array<IntSet>& sets) {
for (size_t i = 1; i < sets.size(); ++i) {
IntSet s = sets[i].cover_interval();
const Interval& y = s.as<IntervalSet>()->i;
if (can_prove(x.max + 1 >= y.min)) {
x.max = y.max;
} else if (can_prove(y.max + 1 >= x.min)) {
x.min = y.min;
} else {
x.include(y);
}
x.include(y);
}
x.max = ir::Simplify(x.max);
x.min = ir::Simplify(x.min);
return IntervalSet::make(x);
}
......
......@@ -18,6 +18,21 @@ def test_tensor():
assert(d[T] == 1)
def test_conv1d():
n = tvm.Var('n')
A = tvm.placeholder((n+2), name='A')
def computeB(ii):
i = ii + 1
return A[i-1] + A[i] + A[i+1]
B = tvm.compute(n, computeB)
def test_tensor_slice():
n = tvm.Var('n')
A = tvm.compute((n, n), lambda i, j: 1)
B = tvm.compute((n,), lambda i: A[0][i] + A[0][i])
def test_tensor_reduce():
m = tvm.Var('m')
n = tvm.Var('n')
......@@ -44,8 +59,32 @@ def test_tensor_scan():
s)
assert tuple(res.shape) == (m, n)
def test_scan_multi_out():
m = tvm.Var("m")
n = tvm.Var("n")
x1 = tvm.placeholder((m, n))
s1 = tvm.placeholder((m, n))
x2 = tvm.placeholder((m, n))
s2 = tvm.placeholder((m, n))
s1_init = tvm.compute((1, n), lambda _, i: x1[0, i])
s2_init = tvm.compute((1, n), lambda _, i: x2[0, i])
s1_update = tvm.compute((m, n), lambda t, i: s1[t-1, i] + s2[t-1, i] + x1[t, i])
s2_update = tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t-1,i])
r0, r1 = tvm.scan([s1_init, s2_init],
[s1_update, s2_update],
[s1, s2])
assert(r0.value_index == 0)
assert(r1.value_index == 1)
json_str = tvm.save_json(r0.op)
zz = tvm.load_json(json_str)
assert isinstance(zz, tvm.tensor.ScanOp)
if __name__ == "__main__":
test_conv1d()
test_tensor_slice()
test_tensor()
test_tensor_reduce()
test_tensor_scan()
test_scan_multi_out()
......@@ -71,8 +71,38 @@ def test_bound_scan():
stmt = tvm.schedule.ScheduleOps(s, bounds)
assert bounds[XX.op.axis[1]].extent.value == 4
def test_bound_conv1d():
n = tvm.Var('n')
A = tvm.compute((n+2), lambda i: 1, name='A')
def computeB(ii):
i = ii + 1
return A[i-1] + A[i] + A[i+1]
B = tvm.compute(n, computeB, name='B')
s = tvm.Schedule(B.op)
s[A].compute_at(s[B], B.op.axis[0])
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[A.op.axis[0]].extent.value == 3)
def test_bound_blur():
n = tvm.convert(12)
A = tvm.compute((n, n), lambda i, j: 1, name='A')
def computeB(ii, jj):
# set the correct center
i = ii + 1
j = jj + 1
return A[i][j] + A[i-1][j] + A[i+1][j] + A[i][j+1] + A[i][j-1]
B = tvm.compute((n-2, n-2), computeB, name='B')
s = tvm.Schedule(B.op)
s[A].compute_at(s[B], B.op.axis[1])
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[A.op.axis[0]].extent.value == 3)
assert(bounds[A.op.axis[1]].extent.value == 3)
if __name__ == "__main__":
test_bound_blur()
test_bound_conv1d()
test_bound_scan()
test_bound3()
test_bound1()
......
......@@ -74,10 +74,29 @@ def test_scan_fix_point():
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
def test_scan5_multi_output():
m = tvm.Var("m")
n = tvm.Var("n")
x1 = tvm.placeholder((m, n))
s1 = tvm.placeholder((m, n))
x2 = tvm.placeholder((m, n))
s2 = tvm.placeholder((m, n))
s1_init = tvm.compute((1, n), lambda _, i: x1[0, i])
s2_init = tvm.compute((1, n), lambda _, i: x2[0, i])
s1_update = tvm.compute((m, n), lambda t, i: s1[t-1, i] + x1[t, i])
s2_update = tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t-1,i])
r0, r1 = tvm.scan([s1_init, s2_init],
[s1_update, s2_update],
[s1, s2])
body = tvm.schedule.ScanGetBody(r0.op)
fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op, body)
assert(fxpt[r1.op.spatial_axis_[0]].value == 1)
test_scan0()
test_scan1()
test_scan3_not_exact_reach()
test_scan4_reach_other()
test_scan5_multi_output()
def test_create_read_graph():
m = tvm.Var('m')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment