Commit f631fb43 by Tianqi Chen, committed by GitHub

[SCHEDULE] Fix inline with multiple outputs (#507)

parent af8cbdde
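
What this commit changes, in brief: InjectInline previously looped over all outputs of the inlined op and substituted them element-wise, which broke consumers whose body is a multi-output Reduce, since every output of such a ComputeOp is a Reduce node sharing one combiner, source, axis, and condition and differing only in value_index. The fix (a) requires the inlined op itself to have exactly one output, and (b) when the consumer is a multi-output reduction, inlines into the first Reduce expression only and rebuilds the sibling outputs from the result. ReduceEqual, which checks that shared-payload invariant, is marked inline because the second hunk duplicates its definition in the schedule-rewrite translation unit, keeping the two identical definitions ODR-safe. A condensed sketch of the user-visible pattern, adapted from the regression test added below (2017-era tvm 0.x API; not part of the diff):

import tvm

# Two-output argmax reducer: one combiner yields (index, value).
argmax = tvm.comm_reducer(
    lambda x, y: (tvm.select(x[1] >= y[1], x[0], y[0]),
                  tvm.select(x[1] >= y[1], x[1], y[1])),
    lambda idx_t, val_t: (tvm.const(-1, idx_t), tvm.min_value(val_t)),
    name='argmax')

m, n = tvm.var('m'), tvm.var('n')
val = tvm.placeholder((m, n), name='val', dtype='float32')
val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val[i, j]), name='val2')
k = tvm.reduce_axis((0, n), 'k')
# T_idx and T_val come from a single ComputeOp whose body holds two
# Reduce nodes differing only in value_index.
T_idx, T_val = tvm.compute((m,), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')

s = tvm.create_schedule(T_idx.op)
s[val2].compute_inline()  # before this commit, this corrupted the paired Reduce bodies
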
@@ -98,7 +98,7 @@ Array<Tensor> compute(Array<Expr> shape,
   return outputs;
 }
 
-bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
   return (a->combiner.same_as(b->combiner)) &&
          (a->source.same_as(b->source)) &&
          (a->axis.same_as(b->axis)) &&
......
@@ -275,10 +275,17 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
   }
 }
 
+inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+  return (a->combiner.same_as(b->combiner)) &&
+         (a->source.same_as(b->source)) &&
+         (a->axis.same_as(b->axis)) &&
+         (a->condition.same_as(b->condition));
+}
+
 void InjectInline(ScheduleNode* sch) {
   sch->InvalidateCache();
-  std::vector<Array<Expr>> new_body(sch->stages.size());
+  std::vector<Array<Expr> > new_body(sch->stages.size());
   std::vector<bool> changed(sch->stages.size(), false);
   // inline all the ops
   for (size_t i = sch->stages.size(); i != 0; --i) {
@@ -286,7 +293,7 @@ void InjectInline(ScheduleNode* sch) {
     if (stage->attach_type == kInline) {
      stage->attach_type = kInlinedAlready;
       Array<Var> args;
-      Array<Expr> body;
+      Expr body;
       {
         // setup args
         const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
@@ -295,7 +302,9 @@ void InjectInline(ScheduleNode* sch) {
         for (auto iv : compute->axis) {
           args.push_back(iv->var);
         }
-        body = compute->body;
+        CHECK_EQ(compute->body.size(), 1U)
+            << "can only inline compute op with 1 output";
+        body = compute->body[0];
       }
       for (size_t j = i; j < sch->stages.size(); ++j) {
         Stage s = sch->stages[j];
@@ -304,10 +313,39 @@ void InjectInline(ScheduleNode* sch) {
         if (!new_body[j].size()) {
           new_body[j] = s->op.as<ComputeOpNode>()->body;
         }
-        for (size_t k = 0; k < body.size(); ++k) {
-          changed[j] = true;
-          new_body[j].Set(k, ir::Inline(ir::Evaluate::make(new_body[j][k]),
-                                        stage->op, args, body[k]).as<ir::Evaluate>()->value);
-        }
+        if (new_body[j][0]->is_type<ir::Reduce>()) {
+          // specially handle inlining into multi-output reductions.
+          const ir::Reduce* reduce = new_body[j][0].as<ir::Reduce>();
+          for (size_t k = 1; k < new_body[j].size(); ++k) {
+            const ir::Reduce* reduce_ = new_body[j][k].as<ir::Reduce>();
+            CHECK(reduce_);
+            CHECK(ReduceEqual(reduce_, reduce))
+                << "The Reduce inputs of ComputeOp should "
+                << "have the same attributes except value_index";
+          }
+          Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][0]),
+                                      stage->op, args, body).as<ir::Evaluate>()->value;
+          if (!new_value.same_as(new_body[j][0])) {
+            changed[j] = true;
+            const ir::Reduce* r = new_value.as<ir::Reduce>();
+            CHECK(r != nullptr);
+            CHECK_EQ(new_body[j].size(), r->source.size());
+            for (size_t k = 0; k < new_body[j].size(); ++k) {
+              std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r);
+              n->value_index = static_cast<int>(k);
+              n->type = r->source[k].type();
+              new_body[j].Set(k, Expr(n));
+            }
+          }
+        } else {
+          for (size_t k = 0; k < new_body[j].size(); ++k) {
+            Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][k]),
+                                        stage->op, args, body).as<ir::Evaluate>()->value;
+            if (!new_value.same_as(new_body[j][k])) {
+              new_body[j].Set(k, new_value);
+              changed[j] = true;
+            }
+          }
+        }
       }
     }
......
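
Design note on the new inline path above: ir::Inline works on statements, so each expression is wrapped in an Evaluate node and unwrapped again after substitution. Because every output of a multi-output reduction shares the same Reduce payload, substitution is performed once, on output 0; each remaining output is then recreated as a copy of the inlined node with only value_index and the element type adjusted, which keeps the whole tuple consistent. (The std::vector<Array<Expr> > spelling merely avoids the >> token for pre-C++11 parsers.) A minimal, self-contained Python analogue of the copy-and-retag step, with hypothetical names chosen purely for illustration:

import copy

class Reduce(object):
    # Hypothetical stand-in for ir::Reduce: only the fields the
    # rebuild step touches.
    def __init__(self, source_types, value_index=0):
        self.source_types = source_types   # element type per output
        self.value_index = value_index
        self.type = source_types[value_index]

def rebuild_outputs(inlined):
    # Mirror of the C++ loop: one shallow copy per output, retagged
    # with its own value_index and element type; everything else
    # (combiner, source, axis, condition) is shared unchanged.
    outs = []
    for k, ty in enumerate(inlined.source_types):
        n = copy.copy(inlined)
        n.value_index = k
        n.type = ty
        outs.append(n)
    return outs

outs = rebuild_outputs(Reduce(['int32', 'float32']))
assert [o.value_index for o in outs] == [0, 1]
assert [o.type for o in outs] == ['int32', 'float32']
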
@@ -56,6 +56,28 @@ def test_schedule_scan():
     assert(bounds[res.op.scan_axis].min.value == 1)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
+def test_inline_multi_reduce():
+    def argmax_comp(x, y):
+        idx = tvm.select((x[1] >= y[1]), x[0], y[0])
+        val = tvm.select((x[1] >= y[1]), x[1], y[1])
+        return idx, val
+    def argmax_init(idx_typ, val_typ):
+        return tvm.const(-1, idx_typ), tvm.min_value(val_typ)
+
+    argmax = tvm.comm_reducer(argmax_comp, argmax_init, name='argmax')
+    m = tvm.var('m')
+    n = tvm.var('n')
+    val = tvm.placeholder((m, n), name='val', dtype='float32')
+    val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val[i, j]), name='val2')
+    k = tvm.reduce_axis((0, n), 'k')
+    T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')
+    s = tvm.create_schedule(T_idx.op)
+    s[val2].compute_inline()
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
 def test_auto_inline():
     m = tvm.var('m')
     n = tvm.var('n')
@@ -207,6 +229,7 @@ def test_schedule_cache_relayout3():
 
 if __name__ == "__main__":
+    test_inline_multi_reduce()
     test_schedule_cache_relayout3()
     test_schedule_cache_relayout2()
     test_schedule_cache_relayout1()
......
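
A note on the regression test: val2 has a single output, so it passes the new CHECK, while its consumer is a two-output argmax reduction, so inlining it exercises the rebuild path. To see the effect interactively (a suggestion, not part of the commit), print the scheduled statement at the end of test_inline_multi_reduce(); the argmax source should read tvm.exp(val[i, k]) directly, with no separate stage producing val2:

    print(stmt)  # appended after ScheduleOps; val2 is gone from the loop nest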