Commit f631fb43 by Tianqi Chen Committed by GitHub

[SCHEDULE] Fix inline with multiple outputs (#507)

parent af8cbdde
...@@ -98,7 +98,7 @@ Array<Tensor> compute(Array<Expr> shape, ...@@ -98,7 +98,7 @@ Array<Tensor> compute(Array<Expr> shape,
return outputs; return outputs;
} }
bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) { inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
return (a->combiner.same_as(b->combiner)) && return (a->combiner.same_as(b->combiner)) &&
(a->source.same_as(b->source)) && (a->source.same_as(b->source)) &&
(a->axis.same_as(b->axis)) && (a->axis.same_as(b->axis)) &&
......
...@@ -275,10 +275,17 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { ...@@ -275,10 +275,17 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
} }
} }
// Whether two Reduce nodes agree on every attribute except value_index,
// i.e. they describe the same underlying reduction with different outputs.
inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
  if (!a->combiner.same_as(b->combiner)) return false;
  if (!a->source.same_as(b->source)) return false;
  if (!a->axis.same_as(b->axis)) return false;
  return a->condition.same_as(b->condition);
}
void InjectInline(ScheduleNode* sch) { void InjectInline(ScheduleNode* sch) {
sch->InvalidateCache(); sch->InvalidateCache();
std::vector<Array<Expr>> new_body(sch->stages.size()); std::vector<Array<Expr> > new_body(sch->stages.size());
std::vector<bool> changed(sch->stages.size(), false); std::vector<bool> changed(sch->stages.size(), false);
// inline all the ops // inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
...@@ -286,7 +293,7 @@ void InjectInline(ScheduleNode* sch) { ...@@ -286,7 +293,7 @@ void InjectInline(ScheduleNode* sch) {
if (stage->attach_type == kInline) { if (stage->attach_type == kInline) {
stage->attach_type = kInlinedAlready; stage->attach_type = kInlinedAlready;
Array<Var> args; Array<Var> args;
Array<Expr> body; Expr body;
{ {
// setup args // setup args
const ComputeOpNode* compute = stage->op.as<ComputeOpNode>(); const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
...@@ -295,7 +302,9 @@ void InjectInline(ScheduleNode* sch) { ...@@ -295,7 +302,9 @@ void InjectInline(ScheduleNode* sch) {
for (auto iv : compute->axis) { for (auto iv : compute->axis) {
args.push_back(iv->var); args.push_back(iv->var);
} }
body = compute->body; CHECK_EQ(compute->body.size(), 1U)
<< "can only inline compute op with 1 output";
body = compute->body[0];
} }
for (size_t j = i; j < sch->stages.size(); ++j) { for (size_t j = i; j < sch->stages.size(); ++j) {
Stage s = sch->stages[j]; Stage s = sch->stages[j];
...@@ -304,10 +313,39 @@ void InjectInline(ScheduleNode* sch) { ...@@ -304,10 +313,39 @@ void InjectInline(ScheduleNode* sch) {
if (!new_body[j].size()) { if (!new_body[j].size()) {
new_body[j] = s->op.as<ComputeOpNode>()->body; new_body[j] = s->op.as<ComputeOpNode>()->body;
} }
for (size_t k = 0; k < body.size(); ++k) { if (new_body[j][0]->is_type<ir::Reduce>()) {
changed[j] = true; // specially handle reduction inline for multiple reductions.
new_body[j].Set(k, ir::Inline(ir::Evaluate::make(new_body[j][k]), const ir::Reduce* reduce = new_body[j][0].as<ir::Reduce>();
stage->op, args, body[k]).as<ir::Evaluate>()->value); for (size_t k = 1; k < new_body[j].size(); ++k) {
const ir::Reduce* reduce_ = new_body[j][k].as<ir::Reduce>();
CHECK(reduce_);
CHECK(ReduceEqual(reduce_, reduce))
<< "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}
Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][0]),
stage->op, args, body).as<ir::Evaluate>()->value;
if (!new_value.same_as(new_body[j][0])) {
changed[j] = true;
const ir::Reduce* r = new_value.as<ir::Reduce>();
CHECK_EQ(new_body[j].size(), r->source.size());
CHECK(r != nullptr);
for (size_t k = 0; k < new_body[j].size(); ++k) {
std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r);
n->value_index = static_cast<int>(k);
n->type = r->source[k].type();
new_body[j].Set(k, Expr(n));
}
}
} else {
for (size_t k = 0; k < new_body[j].size(); ++k) {
Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][k]),
stage->op, args, body).as<ir::Evaluate>()->value;
if (!new_value.same_as(new_body[j][k])) {
new_body[j].Set(k, new_value);
changed[j] = true;
}
}
} }
} }
} }
......
...@@ -56,6 +56,28 @@ def test_schedule_scan(): ...@@ -56,6 +56,28 @@ def test_schedule_scan():
assert(bounds[res.op.scan_axis].min.value == 1) assert(bounds[res.op.scan_axis].min.value == 1)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_inline_multi_reduce():
    # Regression test: inlining a producer stage into a multi-output
    # (argmax-style) commutative reduction must still schedule cleanly.
    def argmax_comp(lhs, rhs):
        # Keep the (index, value) pair with the larger value.
        best_idx = tvm.select((lhs[1] >= rhs[1]), lhs[0], rhs[0])
        best_val = tvm.select((lhs[1] >= rhs[1]), lhs[1], rhs[1])
        return best_idx, best_val
    def argmax_init(idx_typ, val_typ):
        # Identity element: invalid index paired with the smallest value.
        return tvm.const(-1, idx_typ), tvm.min_value(val_typ)
    argmax = tvm.comm_reducer(argmax_comp, argmax_init, name='argmax')
    m = tvm.var('m')
    n = tvm.var('n')
    val = tvm.placeholder((m, n), name='val', dtype='float32')
    val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val[i, j]), name='val2')
    k = tvm.reduce_axis((0, n), 'k')
    T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')
    sch = tvm.create_schedule(T_idx.op)
    # Inline the exp stage into the two-output reduction.
    sch[val2].compute_inline()
    sch = sch.normalize()
    bounds = tvm.schedule.InferBound(sch)
    stmt = tvm.schedule.ScheduleOps(sch, bounds)
def test_auto_inline(): def test_auto_inline():
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
...@@ -207,6 +229,7 @@ def test_schedule_cache_relayout3(): ...@@ -207,6 +229,7 @@ def test_schedule_cache_relayout3():
if __name__ == "__main__": if __name__ == "__main__":
test_inline_multi_reduce()
test_schedule_cache_relayout3() test_schedule_cache_relayout3()
test_schedule_cache_relayout2() test_schedule_cache_relayout2()
test_schedule_cache_relayout1() test_schedule_cache_relayout1()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment