Commit 3b8ad0a2 by Tianqi Chen Committed by GitHub

[SCHEDULE] Normalize returns a new schedule (#94)

parent f9d604dd
...@@ -191,6 +191,11 @@ class Schedule : public NodeRef { ...@@ -191,6 +191,11 @@ class Schedule : public NodeRef {
*/ */
explicit Schedule(Array<Operation> ops); explicit Schedule(Array<Operation> ops);
/*! /*!
* \brief Get a copy of current schedule.
* \return The copied schedule.
*/
Schedule copy() const;
/*!
* \brief Get the stage corresponds to the op * \brief Get the stage corresponds to the op
* \param op The operation. * \param op The operation.
*/ */
...@@ -257,7 +262,7 @@ class Schedule : public NodeRef { ...@@ -257,7 +262,7 @@ class Schedule : public NodeRef {
* *
* \return A normalized schedule, can be same as current one. * \return A normalized schedule, can be same as current one.
*/ */
void normalize(); Schedule normalize();
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
......
...@@ -57,7 +57,7 @@ def lower(sch, ...@@ -57,7 +57,7 @@ def lower(sch,
else: else:
raise ValueError("args must be Tensor, Buffer or Var") raise ValueError("args must be Tensor, Buffer or Var")
# normalize schedule first # normalize schedule first
sch.normalize() sch = sch.normalize()
bounds = schedule.InferBound(sch) bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds) stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds) stmt = ir_pass.StorageFlatten(stmt, binds)
......
...@@ -78,12 +78,17 @@ class Schedule(NodeBase): ...@@ -78,12 +78,17 @@ class Schedule(NodeBase):
return self.stage_map[k] return self.stage_map[k]
def normalize(self): def normalize(self):
"""Build a normalized schedule. """Build a normalized schedule from the current schedule.
Insert necessary rebase to make certain iter var to start from 0. Insert necessary rebase to make certain iter var to start from 0.
This is needed before bound inference and followup step. This is needed before bound inference and followup step.
Returns
-------
sch : Schedule
The normalized schedule.
""" """
_api_internal._ScheduleNormalize(self) return _api_internal._ScheduleNormalize(self)
def create_group(self, outputs, inputs, include_inputs=False): def create_group(self, outputs, inputs, include_inputs=False):
"""Create stage group by giving output and input boundary. """Create stage group by giving output and input boundary.
...@@ -261,7 +266,7 @@ class Stage(NodeBase): ...@@ -261,7 +266,7 @@ class Stage(NodeBase):
threads : list of threads threads : list of threads
The threads to be launched. The threads to be launched.
""" """
if isinstance(threads, _collections.IterVar): if isinstance(threads, IterVar):
threads = [threads] threads = [threads]
_api_internal._StageEnvThreads(self, threads) _api_internal._StageEnvThreads(self, threads)
......
...@@ -311,7 +311,7 @@ TVM_REGISTER_API(_StageParallel) ...@@ -311,7 +311,7 @@ TVM_REGISTER_API(_StageParallel)
TVM_REGISTER_API(_ScheduleNormalize) TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule() *ret = args[0].operator Schedule()
.normalize(); .normalize();
}); });
......
...@@ -242,9 +242,11 @@ void InjectInline(ScheduleNode* sch) { ...@@ -242,9 +242,11 @@ void InjectInline(ScheduleNode* sch) {
ReplaceDataFlow(sch->stages, &repl); ReplaceDataFlow(sch->stages, &repl);
} }
void Schedule::normalize() { Schedule Schedule::normalize() {
InjectInline(operator->()); Schedule sn = copy();
RebaseNonZeroMinLoop(*this); InjectInline(sn.operator->());
RebaseNonZeroMinLoop(sn);
return sn;
} }
// Handle reduction factor. // Handle reduction factor.
......
...@@ -355,6 +355,52 @@ Schedule::Schedule(Array<Operation> ops) { ...@@ -355,6 +355,52 @@ Schedule::Schedule(Array<Operation> ops) {
} }
} }
/*!
 * \brief Clone a stage by value-copying its underlying StageNode.
 * \param s The stage to clone.
 * \return A new Stage handle that owns the copied node.
 */
Stage CopyStage(const Stage& s) {
  // Copy-construct a fresh node from the existing one, then wrap it
  // in a new handle so the clone is independent of the original.
  auto node_copy = std::make_shared<StageNode>(*s.operator->());
  return Stage(node_copy);
}
/*!
 * \brief Create a deep copy of this schedule's stage graph.
 *
 * Every Stage in `stages` and `groups` is cloned, and all internal
 * cross references (`stage_map`, `attach_stage`, `group`) are remapped
 * to point at the cloned nodes, so mutating the returned schedule does
 * not affect the original one.
 *
 * \return A new Schedule sharing no StageNode with the current one.
 */
Schedule Schedule::copy() const {
  const ScheduleNode* self = operator->();
  // Maps each original stage to its clone so references can be remapped.
  std::unordered_map<Stage, Stage, NodeHash, NodeEqual> smap;
  std::shared_ptr<ScheduleNode> n = std::make_shared<ScheduleNode>();
  n->outputs = self->outputs;
  // Clone every stage and every group stage.
  for (const Stage& s : self->stages) {
    Stage scopy = CopyStage(s);
    smap[s] = scopy;
    n->stages.push_back(scopy);
  }
  for (const Stage& g : self->groups) {
    Stage gcopy = CopyStage(g);
    smap[g] = gcopy;
    n->groups.push_back(gcopy);
  }
  // Remap op->stage lookups to the cloned stages.
  // Iterate by const reference to avoid copying each map entry.
  for (const auto& kv : self->stage_map) {
    n->stage_map.Set(kv.first, smap.at(kv.second));
  }
  // Redirect attach_stage/group references inside the clones.
  // One shared helper replaces the previously duplicated loop bodies.
  auto remap_refs = [&smap](Stage s) {
    if (s->attach_stage.defined()) {
      s->attach_stage = smap.at(s->attach_stage);
    }
    if (s->group.defined()) {
      s->group = smap.at(s->group);
    }
  };
  for (Stage s : n->stages) {
    remap_refs(s);
  }
  for (Stage g : n->groups) {
    remap_refs(g);
  }
  return Schedule(n);
}
Stage Schedule::operator[](const Operation& op) { Stage Schedule::operator[](const Operation& op) {
auto it = (*this)->stage_map.find(op); auto it = (*this)->stage_map.find(op);
CHECK(it != (*this)->stage_map.end()) CHECK(it != (*this)->stage_map.end())
......
...@@ -10,7 +10,7 @@ def lower(s, args, name="mydot"): ...@@ -10,7 +10,7 @@ def lower(s, args, name="mydot"):
buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name) buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
binds[x] = buf binds[x] = buf
arg_list.append(buf) arg_list.append(buf)
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds) stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
......
...@@ -60,7 +60,7 @@ def test_gemm(): ...@@ -60,7 +60,7 @@ def test_gemm():
max_auto_unroll_step = 0 max_auto_unroll_step = 0
# lowering test # lowering test
s.normalize() s = s.normalize()
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="stackvm"):
......
...@@ -16,6 +16,7 @@ def test_add_pipeline(): ...@@ -16,6 +16,7 @@ def test_add_pipeline():
s[C].bind(xi, tvm.thread_axis("blockIdx.x")) s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
# compile to IR # compile to IR
s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
......
...@@ -22,6 +22,8 @@ def test_bound2(): ...@@ -22,6 +22,8 @@ def test_bound2():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.create_schedule(A2.op) s = tvm.create_schedule(A2.op)
xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8) xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
# test normalize not affecting schedule
_ = s.normalize()
s[A1].compute_at(s[A2], yo) s[A1].compute_at(s[A2], yo)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
...@@ -41,6 +43,8 @@ def test_bound3(): ...@@ -41,6 +43,8 @@ def test_bound3():
xi0, xi1 = s[A2].split(xi, nparts=16) xi0, xi1 = s[A2].split(xi, nparts=16)
s[A2].bind(xi0, tvm.thread_axis("threadIdx.x")) s[A2].bind(xi0, tvm.thread_axis("threadIdx.x"))
yo, yi = s[A2].split(A2.op.axis[1], 16) yo, yi = s[A2].split(A2.op.axis[1], 16)
# test normalize not affecting schedule
_ = s.normalize()
s[A2].reorder(xo, xi0, yo, xi1, yi) s[A2].reorder(xo, xi0, yo, xi1, yi)
s[A1].compute_at(s[A2], yo) s[A1].compute_at(s[A2], yo)
...@@ -63,7 +67,7 @@ def test_bound_scan(): ...@@ -63,7 +67,7 @@ def test_bound_scan():
XX = s.cache_read(X, "local", s_update) XX = s.cache_read(X, "local", s_update)
xo, xi = s[s_update].split(s_update.op.axis[1], factor=4) xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
s[XX].compute_at(s[s_update], xo) s[XX].compute_at(s[s_update], xo)
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
assert bounds[XX.op.axis[1]].extent.value == 4 assert bounds[XX.op.axis[1]].extent.value == 4
...@@ -77,7 +81,7 @@ def test_bound_conv1d(): ...@@ -77,7 +81,7 @@ def test_bound_conv1d():
B = tvm.compute(n, computeB, name='B') B = tvm.compute(n, computeB, name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
s[A].compute_at(s[B], B.op.axis[0]) s[A].compute_at(s[B], B.op.axis[0])
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert(bounds[A.op.axis[0]].extent.value == 3) assert(bounds[A.op.axis[0]].extent.value == 3)
...@@ -92,7 +96,7 @@ def test_bound_blur(): ...@@ -92,7 +96,7 @@ def test_bound_blur():
B = tvm.compute((n-2, n-2), computeB, name='B') B = tvm.compute((n-2, n-2), computeB, name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
s[A].compute_at(s[B], B.op.axis[1]) s[A].compute_at(s[B], B.op.axis[1])
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert(bounds[A.op.axis[0]].extent.value == 3) assert(bounds[A.op.axis[0]].extent.value == 3)
assert(bounds[A.op.axis[1]].extent.value == 3) assert(bounds[A.op.axis[1]].extent.value == 3)
...@@ -106,7 +110,7 @@ def test_bound_rfactor(): ...@@ -106,7 +110,7 @@ def test_bound_rfactor():
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
kf, ki = s[B].split(k, nparts=4) kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf) BF = s.rfactor(B, kf)
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert(bounds[BF.op.axis[0]].extent.value == 4) assert(bounds[BF.op.axis[0]].extent.value == 4)
...@@ -123,7 +127,7 @@ def test_bound_group_schedule(): ...@@ -123,7 +127,7 @@ def test_bound_group_schedule():
g.compute_at(s[x2], x2.op.axis[0]) g.compute_at(s[x2], x2.op.axis[0])
assert s[x1].group == g assert s[x1].group == g
assert s[x].group == g assert s[x].group == g
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert bounds[x.op.axis[0]].extent.value == 1 assert bounds[x.op.axis[0]].extent.value == 1
assert bounds[x.op.axis[1]].extent == n assert bounds[x.op.axis[1]].extent == n
...@@ -141,7 +145,7 @@ def test_bound_nest_group(): ...@@ -141,7 +145,7 @@ def test_bound_nest_group():
assert s[x1].group == g2 assert s[x1].group == g2
g2.compute_at(s[x2], x2.op.axis[0]) g2.compute_at(s[x2], x2.op.axis[0])
g1.compute_at(s[x1], s[x1].op.axis[1]) g1.compute_at(s[x1], s[x1].op.axis[1])
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert bounds[x.op.axis[0]].extent.value == 1 assert bounds[x.op.axis[0]].extent.value == 1
assert bounds[x.op.axis[1]].extent.value == 1 assert bounds[x.op.axis[1]].extent.value == 1
...@@ -169,7 +173,7 @@ def test_bound_nest_thread(): ...@@ -169,7 +173,7 @@ def test_bound_nest_thread():
_, xi = s[A2].split(A2.op.axis[0], nparts=1) _, xi = s[A2].split(A2.op.axis[0], nparts=1)
s[A2].bind(xi, thread_x) s[A2].bind(xi, thread_x)
s[A1].compute_at(s[A3], tx) s[A1].compute_at(s[A3], tx)
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert(bounds[A1.op.axis[0]].extent.value==1) assert(bounds[A1.op.axis[0]].extent.value==1)
assert(bounds[A2.op.axis[0]].extent.value==32) assert(bounds[A2.op.axis[0]].extent.value==32)
...@@ -225,7 +229,7 @@ def test_gemm_bound(): ...@@ -225,7 +229,7 @@ def test_gemm_bound():
tx, xi = s[BB].split(xi, nparts=num_thread) tx, xi = s[BB].split(xi, nparts=num_thread)
s[BB].bind(ty, thread_y) s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x) s[BB].bind(tx, thread_x)
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert(bounds[BB.op.axis[0]].extent.value==64) assert(bounds[BB.op.axis[0]].extent.value==64)
assert(bounds[AA.op.axis[0]].extent.value==64) assert(bounds[AA.op.axis[0]].extent.value==64)
......
...@@ -51,7 +51,7 @@ def test_schedule_scan(): ...@@ -51,7 +51,7 @@ def test_schedule_scan():
assert tuple(res.shape) == (m, n) assert tuple(res.shape) == (m, n)
s = tvm.create_schedule(res.op) s = tvm.create_schedule(res.op)
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert(bounds[res.op.scan_axis].min.value == 1) assert(bounds[res.op.scan_axis].min.value == 1)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
...@@ -68,7 +68,7 @@ def test_auto_inline(): ...@@ -68,7 +68,7 @@ def test_auto_inline():
s = tvm.create_schedule(T2.op) s = tvm.create_schedule(T2.op)
tvm.schedule.AutoInlineElemWise(s) tvm.schedule.AutoInlineElemWise(s)
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
...@@ -83,7 +83,7 @@ def test_inline_mixed(): ...@@ -83,7 +83,7 @@ def test_inline_mixed():
xo, xi = s[C].split(C.op.axis[0], factor=8) xo, xi = s[C].split(C.op.axis[0], factor=8)
s[A1].compute_at(s[C], xo) s[A1].compute_at(s[C], xo)
s[A2].compute_inline() s[A2].compute_inline()
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt) print(stmt)
......
...@@ -11,7 +11,7 @@ def lower(s, args, name): ...@@ -11,7 +11,7 @@ def lower(s, args, name):
buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name) buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
binds[x] = buf binds[x] = buf
arg_list.append(buf) arg_list.append(buf)
s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds) stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment