Commit ff26cd68 by tqchen

Fix Tile, add a few more test cases on bound inference

parent 0f693212
@@ -154,7 +154,7 @@ Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent,
                          Expr x_factor, Expr y_factor) {  // NOLINT(*)
   split(x_parent, p_x_outer, p_x_inner, x_factor);
   split(y_parent, p_y_outer, p_y_inner, y_factor);
-  reorder(Array<IterVar>({*p_x_inner, *p_y_inner, *p_x_outer, *p_y_outer}));
+  reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
   return *this;
 }
...
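Note on the fix above: tile is implemented as two splits followed by a reorder, and the reorder previously placed the inner IterVars first. A minimal sketch of the corrected behavior, assuming the tvm Python API exercised by the tests in this commit:

import tvm

m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j])
s = tvm.Schedule(T.op)
# tile = split(x) + split(y) + reorder; after the fix the outer
# loops come first in the leaf iteration order.
xo, yo, xi, yi = s.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5)
assert tuple(s.leaf_iter_vars) == (xo, yo, xi, yi)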
@@ -165,8 +165,15 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
     {"shared", 1},
     {"local", 2}
   };
-  return scope_rank.at(scope) <= scope_rank.at(iv->thread_tag);
+  static std::unordered_map<std::string, int> thread_tag_rank{
+    {"gridIdx.x", 0},
+    {"gridIdx.y", 0},
+    {"gridIdx.z", 0},
+    {"threadIdx.x", 1},
+    {"threadIdx.y", 1},
+    {"threadIdx.z", 1}
+  };
+  return scope_rank.at(scope) <= thread_tag_rank.at(iv->thread_tag);
 }

 void InferBound(
...
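The new thread_tag_rank table separates storage-scope ranks from thread-tag ranks, so a tag like "gridIdx.x" no longer needs an entry in scope_rank. A hedged Python paraphrase of the check (the "global" entry and the helper name are assumptions for illustration; the table contents come from the diff):

SCOPE_RANK = {"global": 0, "shared": 1, "local": 2}  # "global": 0 is assumed
THREAD_TAG_RANK = {
    "gridIdx.x": 0, "gridIdx.y": 0, "gridIdx.z": 0,
    "threadIdx.x": 1, "threadIdx.y": 1, "threadIdx.z": 1,
}

def scope_relax(thread_tag, scope):
    # Relax an IterVar's bound when the storage scope sits at or below
    # the level where the thread tag binds.
    return SCOPE_RANK[scope] <= THREAD_TAG_RANK[thread_tag]

assert scope_relax("threadIdx.x", "shared")    # shared memory spans the threads of a block
assert not scope_relax("gridIdx.x", "shared")  # but not the blocks of the grid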
@@ -220,6 +220,8 @@ void PassUp(const SplitNode* s,
     *parent = IntSet::make_range(dom_map.at(s->parent));
     return;
   }
+  CHECK(outer.defined());
+  CHECK(inner.defined());
   // copy construct
   auto n = std::make_shared<IntSetNode>(*(inner.operator->()));
@@ -228,7 +230,6 @@ void PassUp(const SplitNode* s,
     n->base = Range::make_with_min_extent(
         AsNumber(outer) * s->factor + inner->base->min,
         inner->base->extent);
-    *parent = IntSet(n);
   } else {
     // default use all domains in the data.
     n->domain.push_back(outer->base);
@@ -238,6 +239,7 @@ void PassUp(const SplitNode* s,
       n->stride.push_back(outer->stride[i] * s->factor);
     }
   }
+  *parent = IntSet(n);
 }

 void PassUp(const FuseNode* s,
...
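The moved `*parent = IntSet(n);` ensures the parent set is assigned on both branches, not only the default one. When the outer set is a single value, the parent range is rebuilt as outer * factor + inner; a rough Python sketch of that arithmetic (names are illustrative, not the C++ API):

def pass_up_split(outer_value, factor, inner_min, inner_extent):
    # The parent axis was split as parent = outer * factor + inner, so a
    # fixed outer plus an inner range [inner_min, inner_min + inner_extent)
    # maps back to a parent range with the same extent.
    parent_min = outer_value * factor + inner_min
    return (parent_min, inner_extent)

# outer fixed at 2, factor 8, inner covering [0, 8) -> parent covers [16, 24)
assert pass_up_split(2, 8, 0, 8) == (16, 8)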
 import tvm

-def test_bound_inference():
+def test_bound1():
     m = tvm.Var('m')
     l = tvm.Var('l')
     A = tvm.placeholder((m, l), name='A')
@@ -12,8 +12,42 @@ def test_bound_inference():
     sA1.compute_at(sA2, xo)
     bounds = tvm.schedule.InferBound(sA2)
     assert isinstance(bounds, tvm.collections.Map)
-    print(bounds[A1.op.dim_var[0]])
-    print(bounds[A1.op.dim_var[1]])
+    assert(bounds[A1.op.dim_var[0]].extent.value == 8)
+
+def test_bound2():
+    m = tvm.Var('m')
+    l = tvm.Var('l')
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    sA1 = tvm.Schedule(A1.op)
+    sA2 = tvm.Schedule(A2.op)
+    xo, yo, xi, yi = sA2.tile(A2.op.dim_var[0], A2.op.dim_var[1], 8, 8)
+    sA1.compute_at(sA2, yo)
+    bounds = tvm.schedule.InferBound(sA2)
+    assert isinstance(bounds, tvm.collections.Map)
+    assert(bounds[A1.op.dim_var[0]].extent.value == 8)
+    assert(bounds[A1.op.dim_var[1]].extent.value == 8)
+
+def test_bound3():
+    m = tvm.Var('m')
+    l = tvm.Var('l')
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    sA1 = tvm.Schedule(A1.op, scope="shared")
+    sA2 = tvm.Schedule(A2.op)
+    thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x")
+    xo, xi = sA2.split(A2.op.dim_var[0], 32)
+    xi0, xi1 = sA2.split(xi, outer=thread_x)
+    yo, yi = sA2.split(A2.op.dim_var[1], 16)
+    sA2.reorder(xo, xi0, yo, xi1, yi)
+    sA1.compute_at(sA2, yo)
+    bounds = tvm.schedule.InferBound(sA2)
+    assert isinstance(bounds, tvm.collections.Map)
+    assert(bounds[A1.op.dim_var[0]].extent.value == 32)
+    assert(bounds[A1.op.dim_var[1]].extent.value == 16)

 def test_create_read_graph():
@@ -31,5 +65,7 @@ def test_create_read_graph():

 if __name__ == "__main__":
-    test_bound_inference()
+    test_bound3()
+    test_bound1()
+    test_bound2()
     test_create_read_graph()
...
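test_bound3 ties the two C++ changes together: sA1 is scheduled in "shared" scope, so the threadIdx.x axis is relaxed when bounds pass up, and A1 must cover the whole 32-wide x tile rather than the two elements a single thread touches. A quick check of that arithmetic (plain Python, not the tvm API):

x_tile = 32                      # xo, xi = split(x, 32)
threads = 16                     # xi0 is bound to threadIdx.x over [0, 16)
per_thread = x_tile // threads   # extent of xi1 as seen by one thread
assert per_thread == 2
# With scope="shared", threadIdx.x is relaxed, so the inferred extent for
# A1's first axis is the full tile, matching the assert in the test.
assert per_thread * threads == 32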
@@ -34,6 +34,16 @@ def test_reorder():
     sch_T.reorder(*order)
     assert tuple(sch_T.leaf_iter_vars) == order

+def test_split():
+    m = tvm.Var('m')
+    A = tvm.placeholder((m,), name='A')
+    T = tvm.compute((m,), lambda i: A[i])
+    sT = tvm.Schedule(T.op)
+    xo, xi = sT.split(T.op.dim_var[0], factor=10)
+    assert tuple(sT.leaf_iter_vars) == (xo, xi)
+
 def test_tile():
     m = tvm.Var('m')
     n = tvm.Var('n')
@@ -42,9 +52,10 @@ def test_tile():
     sch_T = tvm.Schedule(T.op, scope="shared")
     xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5)
-    assert tuple(sch_T.leaf_iter_vars) == (xi, yi, xo, yo)
+    assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi)

 if __name__ == "__main__":
     test_schedule_create()
     test_reorder()
     test_tile()
+    test_split()