Commit 3b8ad0a2 authored by Tianqi Chen, committed by GitHub

[SCHEDULE] Normalize returns a new schedule (#94)

parent f9d604dd
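The change below turns `Schedule::normalize()` from an in-place mutation into a method that returns a normalized copy, leaving the original schedule untouched. A minimal sketch of the call-site migration, using the TVM Python API of this era (the placeholder/compute setup here is illustrative, not taken from the diff):

```python
import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")

s = tvm.create_schedule(B.op)
# Before this commit: s.normalize() mutated s in place.
# After it: normalize() returns a new Schedule, so the result must be rebound.
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
```

Every call site updated in this diff follows the same `s = s.normalize()` pattern; a bare `s.normalize()` would now silently discard the normalized copy.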
......@@ -191,6 +191,11 @@ class Schedule : public NodeRef {
*/
explicit Schedule(Array<Operation> ops);
/*!
* \brief Get a copy of the current schedule.
* \return The copied schedule.
*/
Schedule copy() const;
/*!
* \brief Get the stage that corresponds to the op.
* \param op The operation.
*/
......@@ -257,7 +262,7 @@ class Schedule : public NodeRef {
*
* \return A normalized schedule; it may be the same as the current one.
*/
void normalize();
Schedule normalize();
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......
......@@ -57,7 +57,7 @@ def lower(sch,
else:
raise ValueError("args must be Tensor, Buffer or Var")
# normalize schedule first
sch.normalize()
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
......
......@@ -78,12 +78,17 @@ class Schedule(NodeBase):
return self.stage_map[k]
def normalize(self):
"""Build a normalized schedule.
"""Build a normalized schedule from the current schedule.
Insert the necessary rebases so that certain iter vars start from 0.
This is needed before bound inference and the follow-up steps.
Returns
-------
sch : Schedule
The normalized schedule.
"""
_api_internal._ScheduleNormalize(self)
return _api_internal._ScheduleNormalize(self)
def create_group(self, outputs, inputs, include_inputs=False):
"""Create stage group by giving output and input boundary.
......@@ -261,7 +266,7 @@ class Stage(NodeBase):
threads : list of threads
The threads to be launched.
"""
if isinstance(threads, _collections.IterVar):
if isinstance(threads, IterVar):
threads = [threads]
_api_internal._StageEnvThreads(self, threads)
......
......@@ -311,7 +311,7 @@ TVM_REGISTER_API(_StageParallel)
TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule()
*ret = args[0].operator Schedule()
.normalize();
});
......
......@@ -242,9 +242,11 @@ void InjectInline(ScheduleNode* sch) {
ReplaceDataFlow(sch->stages, &repl);
}
void Schedule::normalize() {
InjectInline(operator->());
RebaseNonZeroMinLoop(*this);
Schedule Schedule::normalize() {
Schedule sn = copy();
InjectInline(sn.operator->());
RebaseNonZeroMinLoop(sn);
return sn;
}
// Handle reduction factor.
......
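The new implementation works on a duplicate: `copy()` clones the schedule, then `InjectInline` and `RebaseNonZeroMinLoop` mutate only that clone before it is returned. A small sketch of the contract this establishes (the assertion is a hypothetical check, not from the test suite):

```python
import tvm

m = tvm.var("m")
A = tvm.placeholder((m,), name="A")
B = tvm.compute((m,), lambda i: A[i] * 2, name="B")

s = tvm.create_schedule(B.op)
sn = s.normalize()   # inline injection and rebasing happen on the copy
assert sn is not s   # s keeps its original stages and iter vars
```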
......@@ -355,6 +355,52 @@ Schedule::Schedule(Array<Operation> ops) {
}
}
Stage CopyStage(const Stage& s) {
std::shared_ptr<StageNode> n =
std::make_shared<StageNode>(*s.operator->());
return Stage(n);
}
Schedule Schedule::copy() const {
// Map from each original stage to its copy.
const ScheduleNode* self = operator->();
std::unordered_map<Stage, Stage, NodeHash, NodeEqual> smap;
std::shared_ptr<ScheduleNode> n = std::make_shared<ScheduleNode>();
n->outputs = self->outputs;
// Copy the stages.
for (Stage s : self->stages) {
Stage scopy = CopyStage(s);
smap[s] = scopy;
n->stages.push_back(scopy);
}
for (Stage g : self->groups) {
Stage gcopy = CopyStage(g);
smap[g] = gcopy;
n->groups.push_back(gcopy);
}
// Remap the reference relations to point into the copy.
for (auto kv : self->stage_map) {
n->stage_map.Set(kv.first, smap.at(kv.second));
}
for (Stage s : n->stages) {
if (s->attach_stage.defined()) {
s->attach_stage = smap.at(s->attach_stage);
}
if (s->group.defined()) {
s->group = smap.at(s->group);
}
}
for (Stage s : n->groups) {
if (s->attach_stage.defined()) {
s->attach_stage = smap.at(s->attach_stage);
}
if (s->group.defined()) {
s->group = smap.at(s->group);
}
}
return Schedule(n);
}
Stage Schedule::operator[](const Operation& op) {
auto it = (*this)->stage_map.find(op);
CHECK(it != (*this)->stage_map.end())
......
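`Schedule::copy()` above uses a two-pass clone: first duplicate every stage and group while recording an old-to-new map, then walk the copies and remap their `attach_stage` and `group` references through that map, so nothing in the copy points back into the original schedule. A Python analogue of the same pattern (a sketch only; `copy_schedule` and the stage fields are stand-ins for the C++ above):

```python
import copy

def copy_schedule(stages):
    smap = {}                     # original stage -> its clone (pass 1)
    clones = []
    for s in stages:
        c = copy.copy(s)          # shallow-clone the stage node
        smap[id(s)] = c
        clones.append(c)
    for c in clones:              # pass 2: remap cross-references via smap
        if c.attach_stage is not None:
            c.attach_stage = smap[id(c.attach_stage)]
        if c.group is not None:
            c.group = smap[id(c.group)]
    return clones
```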
......@@ -10,7 +10,7 @@ def lower(s, args, name="mydot"):
buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
binds[x] = buf
arg_list.append(buf)
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
......
......@@ -60,7 +60,7 @@ def test_gemm():
max_auto_unroll_step = 0
# lowering test
s.normalize()
s = s.normalize()
# one line to build the function.
def check_device(device, host="stackvm"):
......
......@@ -16,6 +16,7 @@ def test_add_pipeline():
s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
# compile to IR
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
......
......@@ -22,6 +22,8 @@ def test_bound2():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.create_schedule(A2.op)
xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
# check that normalize does not affect the schedule
_ = s.normalize()
s[A1].compute_at(s[A2], yo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
......@@ -41,6 +43,8 @@ def test_bound3():
xi0, xi1 = s[A2].split(xi, nparts=16)
s[A2].bind(xi0, tvm.thread_axis("threadIdx.x"))
yo, yi = s[A2].split(A2.op.axis[1], 16)
# check that normalize does not affect the schedule
_ = s.normalize()
s[A2].reorder(xo, xi0, yo, xi1, yi)
s[A1].compute_at(s[A2], yo)
......@@ -63,7 +67,7 @@ def test_bound_scan():
XX = s.cache_read(X, "local", s_update)
xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
s[XX].compute_at(s[s_update], xo)
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
assert bounds[XX.op.axis[1]].extent.value == 4
......@@ -77,7 +81,7 @@ def test_bound_conv1d():
B = tvm.compute(n, computeB, name='B')
s = tvm.create_schedule(B.op)
s[A].compute_at(s[B], B.op.axis[0])
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[A.op.axis[0]].extent.value == 3)
......@@ -92,7 +96,7 @@ def test_bound_blur():
B = tvm.compute((n-2, n-2), computeB, name='B')
s = tvm.create_schedule(B.op)
s[A].compute_at(s[B], B.op.axis[1])
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[A.op.axis[0]].extent.value == 3)
assert(bounds[A.op.axis[1]].extent.value == 3)
......@@ -106,7 +110,7 @@ def test_bound_rfactor():
s = tvm.create_schedule(B.op)
kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf)
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[BF.op.axis[0]].extent.value == 4)
......@@ -123,7 +127,7 @@ def test_bound_group_schedule():
g.compute_at(s[x2], x2.op.axis[0])
assert s[x1].group == g
assert s[x].group == g
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
assert bounds[x.op.axis[0]].extent.value == 1
assert bounds[x.op.axis[1]].extent == n
......@@ -141,7 +145,7 @@ def test_bound_nest_group():
assert s[x1].group == g2
g2.compute_at(s[x2], x2.op.axis[0])
g1.compute_at(s[x1], s[x1].op.axis[1])
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
assert bounds[x.op.axis[0]].extent.value == 1
assert bounds[x.op.axis[1]].extent.value == 1
......@@ -169,7 +173,7 @@ def test_bound_nest_thread():
_, xi = s[A2].split(A2.op.axis[0], nparts=1)
s[A2].bind(xi, thread_x)
s[A1].compute_at(s[A3], tx)
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[A1.op.axis[0]].extent.value==1)
assert(bounds[A2.op.axis[0]].extent.value==32)
......@@ -225,7 +229,7 @@ def test_gemm_bound():
tx, xi = s[BB].split(xi, nparts=num_thread)
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[BB.op.axis[0]].extent.value==64)
assert(bounds[AA.op.axis[0]].extent.value==64)
......
......@@ -51,7 +51,7 @@ def test_schedule_scan():
assert tuple(res.shape) == (m, n)
s = tvm.create_schedule(res.op)
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[res.op.scan_axis].min.value == 1)
stmt = tvm.schedule.ScheduleOps(s, bounds)
......@@ -68,7 +68,7 @@ def test_auto_inline():
s = tvm.create_schedule(T2.op)
tvm.schedule.AutoInlineElemWise(s)
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
......@@ -83,7 +83,7 @@ def test_inline_mixed():
xo, xi = s[C].split(C.op.axis[0], factor=8)
s[A1].compute_at(s[C], xo)
s[A2].compute_inline()
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
......
......@@ -11,7 +11,7 @@ def lower(s, args, name):
buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
binds[x] = buf
arg_list.append(buf)
s.normalize()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
......