Commit 88338826 by Tianqi Chen, committed by GitHub

[BUGFIX/TESTS] Bugfix of Tensor slicing. Union. (#66)

parent 3fb85796
@@ -125,10 +125,13 @@ def compute(shape, fcompute, name="compute"):
     """
     shape = (shape,) if isinstance(shape, _expr.Expr) else shape
     ndim = len(shape)
-    arg_names = fcompute.__code__.co_varnames
-    if fcompute.__code__.co_argcount == 0 and len(arg_names) == 1:
+    code = fcompute.__code__
+    if fcompute.__code__.co_argcount == 0:
         arg_names = ["i%d" % i for i in range(ndim)]
+    else:
+        arg_names = code.co_varnames[:code.co_argcount]
     if ndim != len(arg_names):
         raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
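Note on the fix above: `co_varnames` lists a function's local variables as well as its parameters, so a compute function that introduces locals (like the `computeB` helpers added in the tests below) used to fail the `ndim != len(arg_names)` check. Slicing to `co_argcount` keeps only the actual parameters. A minimal illustration using plain CPython introspection (the function here is just an example, not part of the patch):

def fcompute(ii):              # one real parameter ...
    i = ii + 1                 # ... plus one local variable
    return i

code = fcompute.__code__
print(code.co_varnames)                     # ('ii', 'i') -- locals included
print(code.co_varnames[:code.co_argcount])  # ('ii',)     -- parameters only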
......
@@ -16,6 +16,8 @@ class TensorSlice(SliceBase, _expr.ExprOp):
         self.indices = indices

     def __getitem__(self, indices):
+        if not isinstance(indices, tuple):
+            indices = (indices,)
         return TensorSlice(self.tensor, self.indices + indices)
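Note on the fix above: chained indexing such as `A[0][i]` first yields a `TensorSlice` whose `indices` is a tuple, and the second subscript then arrives as a bare expression, so `self.indices + indices` tried to concatenate a tuple with a non-tuple. Wrapping scalar indices in a one-element tuple makes the concatenation well defined. A minimal sketch of the failure mode in plain Python (illustrative names, not the TVM API):

stored = (0,)       # indices collected by the first subscript
new = 1             # a second, scalar subscript
# stored + new      # TypeError: can only concatenate tuple (not "int") to tuple
stored + (new,)     # (0, 1) -- wrap the scalar first, as the fix does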
......
@@ -168,14 +168,10 @@ IntSet Union(const Array<IntSet>& sets) {
   for (size_t i = 1; i < sets.size(); ++i) {
     IntSet s = sets[i].cover_interval();
     const Interval& y = s.as<IntervalSet>()->i;
-    if (can_prove(x.max + 1 >= y.min)) {
-      x.max = y.max;
-    } else if (can_prove(y.max + 1 >= x.min)) {
-      x.min = y.min;
-    } else {
-      x.include(y);
-    }
+    x.include(y);
   }
+  x.max = ir::Simplify(x.max);
+  x.min = ir::Simplify(x.min);
   return IntervalSet::make(x);
 }
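Note on the fix above: the removed branches tried to grow the running interval by overwriting a single endpoint, but `x.max = y.max` (or `x.min = y.min`) can shrink the union whenever the new interval lies inside the one accumulated so far. Always calling `x.include(y)`, which expands `x` to cover `y`, is the conservative choice; the added `Simplify` calls just keep the accumulated bound expressions compact. A toy illustration of the shrinking bug with plain integer intervals (not the actual TVM classes):

def buggy_merge(x, y):
    # mirrors the removed first branch: assumes y extends x to the right
    if x[1] + 1 >= y[0]:
        return (x[0], y[1])
    return (min(x[0], y[0]), max(x[1], y[1]))

def include(x, y):
    # conservative cover: min of the mins, max of the maxes
    return (min(x[0], y[0]), max(x[1], y[1]))

print(buggy_merge((0, 10), (2, 3)))  # (0, 3)  -- the union lost [4, 10]
print(include((0, 10), (2, 3)))      # (0, 10) -- correct cover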
......
@@ -18,6 +18,21 @@ def test_tensor():
     assert(d[T] == 1)

+
+def test_conv1d():
+    n = tvm.Var('n')
+    A = tvm.placeholder((n+2), name='A')
+    def computeB(ii):
+        i = ii + 1
+        return A[i-1] + A[i] + A[i+1]
+    B = tvm.compute(n, computeB)
+
+
+def test_tensor_slice():
+    n = tvm.Var('n')
+    A = tvm.compute((n, n), lambda i, j: 1)
+    B = tvm.compute((n,), lambda i: A[0][i] + A[0][i])
+
 def test_tensor_reduce():
     m = tvm.Var('m')
     n = tvm.Var('n')
@@ -44,8 +59,32 @@ def test_tensor_scan():
                    s)
     assert tuple(res.shape) == (m, n)

+
+def test_scan_multi_out():
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    x1 = tvm.placeholder((m, n))
+    s1 = tvm.placeholder((m, n))
+    x2 = tvm.placeholder((m, n))
+    s2 = tvm.placeholder((m, n))
+    s1_init = tvm.compute((1, n), lambda _, i: x1[0, i])
+    s2_init = tvm.compute((1, n), lambda _, i: x2[0, i])
+    s1_update = tvm.compute((m, n), lambda t, i: s1[t-1, i] + s2[t-1, i] + x1[t, i])
+    s2_update = tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t-1,i])
+    r0, r1 = tvm.scan([s1_init, s2_init],
+                      [s1_update, s2_update],
+                      [s1, s2])
+    assert(r0.value_index == 0)
+    assert(r1.value_index == 1)
+    json_str = tvm.save_json(r0.op)
+    zz = tvm.load_json(json_str)
+    assert isinstance(zz, tvm.tensor.ScanOp)
+
 if __name__ == "__main__":
+    test_conv1d()
+    test_tensor_slice()
     test_tensor()
     test_tensor_reduce()
     test_tensor_scan()
+    test_scan_multi_out()
@@ -71,8 +71,38 @@ def test_bound_scan():
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     assert bounds[XX.op.axis[1]].extent.value == 4

+def test_bound_conv1d():
+    n = tvm.Var('n')
+    A = tvm.compute((n+2), lambda i: 1, name='A')
+    def computeB(ii):
+        i = ii + 1
+        return A[i-1] + A[i] + A[i+1]
+    B = tvm.compute(n, computeB, name='B')
+    s = tvm.Schedule(B.op)
+    s[A].compute_at(s[B], B.op.axis[0])
+    s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    assert(bounds[A.op.axis[0]].extent.value == 3)
+
+def test_bound_blur():
+    n = tvm.convert(12)
+    A = tvm.compute((n, n), lambda i, j: 1, name='A')
+    def computeB(ii, jj):
+        # set the correct center
+        i = ii + 1
+        j = jj + 1
+        return A[i][j] + A[i-1][j] + A[i+1][j] + A[i][j+1] + A[i][j-1]
+    B = tvm.compute((n-2, n-2), computeB, name='B')
+    s = tvm.Schedule(B.op)
+    s[A].compute_at(s[B], B.op.axis[1])
+    s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    assert(bounds[A.op.axis[0]].extent.value == 3)
+    assert(bounds[A.op.axis[1]].extent.value == 3)
+
 if __name__ == "__main__":
+    test_bound_blur()
+    test_bound_conv1d()
     test_bound_scan()
     test_bound3()
     test_bound1()
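The asserted extents follow from the stencil widths: after the `+1` recentering, each output element reads three consecutive elements of `A` along every stencil axis, so when `A` is computed at an axis of `B`, InferBound only needs an extent of 3 for the corresponding axis of `A`. A quick sanity check of that window size (plain Python arithmetic, not the InferBound algorithm):

offsets = [-1, 0, 1]                 # taps used by computeB around the recentered index
ii = 0                               # any fixed iteration of B's axis
touched = [ii + 1 + d for d in offsets]
extent = max(touched) - min(touched) + 1
print(extent)                        # 3 -- matches the asserted bounds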
......
@@ -74,10 +74,29 @@ def test_scan_fix_point():
         assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
         assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)

+    def test_scan5_multi_output():
+        m = tvm.Var("m")
+        n = tvm.Var("n")
+        x1 = tvm.placeholder((m, n))
+        s1 = tvm.placeholder((m, n))
+        x2 = tvm.placeholder((m, n))
+        s2 = tvm.placeholder((m, n))
+        s1_init = tvm.compute((1, n), lambda _, i: x1[0, i])
+        s2_init = tvm.compute((1, n), lambda _, i: x2[0, i])
+        s1_update = tvm.compute((m, n), lambda t, i: s1[t-1, i] + x1[t, i])
+        s2_update = tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t-1,i])
+        r0, r1 = tvm.scan([s1_init, s2_init],
+                          [s1_update, s2_update],
+                          [s1, s2])
+        body = tvm.schedule.ScanGetBody(r0.op)
+        fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op, body)
+        assert(fxpt[r1.op.spatial_axis_[0]].value == 1)
+
     test_scan0()
     test_scan1()
     test_scan3_not_exact_reach()
     test_scan4_reach_other()
+    test_scan5_multi_output()

 def test_create_read_graph():
     m = tvm.Var('m')
......