Commit ff26cd68 by tqchen

Fix Tile, add a few more test cases on bound inference

parent 0f693212
......@@ -154,7 +154,7 @@ Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent,
Expr x_factor, Expr y_factor) { // NOLINT(*)
split(x_parent, p_x_outer, p_x_inner, x_factor);
split(y_parent, p_y_outer, p_y_inner, y_factor);
reorder(Array<IterVar>({*p_x_inner, *p_y_inner, *p_x_outer, *p_y_outer}));
reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
return *this;
}
......
......@@ -165,8 +165,15 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
{"shared", 1},
{"local", 2}
};
return scope_rank.at(scope) <= scope_rank.at(iv->thread_tag);
static std::unordered_map<std::string, int> thread_tag_rank{
{"gridIdx.x", 0},
{"gridIdx.y", 0},
{"gridIdx.z", 0},
{"threadIdx.x", 1},
{"threadIdx.y", 1},
{"threadIdx.z", 1}
};
return scope_rank.at(scope) <= thread_tag_rank.at(iv->thread_tag);
}
void InferBound(
......
......@@ -220,6 +220,8 @@ void PassUp(const SplitNode* s,
*parent = IntSet::make_range(dom_map.at(s->parent));
return;
}
CHECK(outer.defined());
CHECK(inner.defined());
// copy construct
auto n = std::make_shared<IntSetNode>(*(inner.operator->()));
......@@ -228,7 +230,6 @@ void PassUp(const SplitNode* s,
n->base = Range::make_with_min_extent(
AsNumber(outer) * s->factor + inner->base->min,
inner->base->extent);
*parent = IntSet(n);
} else {
// default use all domains in the data.
n->domain.push_back(outer->base);
......@@ -238,6 +239,7 @@ void PassUp(const SplitNode* s,
n->stride.push_back(outer->stride[i] * s->factor);
}
}
*parent = IntSet(n);
}
void PassUp(const FuseNode* s,
......
import tvm
def test_bound_inference():
def test_bound1():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
......@@ -12,8 +12,42 @@ def test_bound_inference():
sA1.compute_at(sA2, xo)
bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map)
print(bounds[A1.op.dim_var[0]])
print(bounds[A1.op.dim_var[1]])
assert(bounds[A1.op.dim_var[0]].extent.value == 8)
def test_bound2():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op)
xo, yo, xi, yi = sA2.tile(A2.op.dim_var[0], A2.op.dim_var[1], 8, 8)
sA1.compute_at(sA2, yo)
bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map)
assert(bounds[A1.op.dim_var[0]].extent.value == 8)
assert(bounds[A1.op.dim_var[1]].extent.value == 8)
def test_bound3():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op, scope="shared")
sA2 = tvm.Schedule(A2.op)
thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x")
xo, xi = sA2.split(A2.op.dim_var[0], 32)
xi0, xi1 = sA2.split(xi, outer=thread_x)
yo, yi = sA2.split(A2.op.dim_var[1], 16)
sA2.reorder(xo, xi0, yo, xi1, yi)
sA1.compute_at(sA2, yo)
bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map)
assert(bounds[A1.op.dim_var[0]].extent.value==32)
assert(bounds[A1.op.dim_var[1]].extent.value==16)
def test_create_read_graph():
......@@ -31,5 +65,7 @@ def test_create_read_graph():
if __name__ == "__main__":
test_bound_inference()
test_bound3()
test_bound1()
test_bound2()
test_create_read_graph()
......@@ -34,6 +34,16 @@ def test_reorder():
sch_T.reorder(*order)
assert tuple(sch_T.leaf_iter_vars) == order
def test_split():
m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i: A[i])
sT = tvm.Schedule(T.op)
xo, xi = sT.split(T.op.dim_var[0], factor=10)
assert tuple(sT.leaf_iter_vars) == (xo, xi)
def test_tile():
m = tvm.Var('m')
n = tvm.Var('n')
......@@ -42,9 +52,10 @@ def test_tile():
sch_T = tvm.Schedule(T.op, scope="shared")
xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5)
assert tuple(sch_T.leaf_iter_vars) == (xi, yi, xo, yo)
assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi)
if __name__ == "__main__":
test_schedule_create()
test_reorder()
test_tile()
test_split()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment