Commit 2f4a5ad9 by Tianqi Chen, committed by GitHub

[SCHEDULE] Further fix of reduce inline with multiple outputs (#508)

parent f631fb43
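For context, the pattern this commit fixes is a multi-output reduction (such as an argmax built with tvm.comm_reducer) whose producer stage is compute_inline'd; the updated regression test at the bottom of this diff exercises exactly that. Below is a minimal sketch of the scenario, assuming an argmax reducer roughly like the one the test file defines elsewhere; the helper names _argmax_combine and _argmax_identity are illustrative, not part of this commit.

    import tvm

    # Illustrative two-valued (index, value) argmax reducer; helper names are hypothetical.
    def _argmax_combine(x, y):
        return (tvm.make.Select(x[1] >= y[1], x[0], y[0]),
                tvm.make.Select(x[1] >= y[1], x[1], y[1]))

    def _argmax_identity(idx_t, val_t):
        return tvm.const(-1, idx_t), tvm.min_value(val_t)

    argmax = tvm.comm_reducer(_argmax_combine, _argmax_identity, name='argmax')

    m, n = tvm.var('m'), tvm.var('n')
    val = tvm.placeholder((m, n), name='val', dtype='float32')
    val1 = tvm.compute((m, n), lambda i, j: val[i, j] + 1, name='val1')
    k = tvm.reduce_axis((0, n), 'k')
    # A single ComputeOp with two outputs: the argmax index and the max value.
    T_idx, T_val = tvm.compute((m,), lambda i: argmax((k.var, val1[i, k]), axis=k), name='T')

    s = tvm.create_schedule(T_idx.op)
    s[val1].compute_inline()   # inlining a producer into the multi-output reduce
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)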
@@ -24,6 +24,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(ComputeOpNode);
 
+inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+  return (a->combiner.same_as(b->combiner)) &&
+         (a->source.same_as(b->source)) &&
+         (a->axis.same_as(b->axis)) &&
+         (a->condition.same_as(b->condition));
+}
+
 int ComputeOpNode::num_outputs() const {
   return body.size();
 }
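ReduceEqual compares only combiner, source, axis and condition because every output of a multi-output reduce ComputeOp carries its own ir::Reduce expression, and those expressions must agree on everything except value_index, which selects the output. A rough way to observe this from Python, assuming a two-output argmax op T as in the sketch above (attribute access relies on node reflection of ComputeOp.body and Reduce.value_index):

    op = T_idx.op                      # the shared ComputeOp behind T_idx and T_val
    r0, r1 = op.body[0], op.body[1]    # each body is an ir.Reduce expression
    assert r0.value_index == 0 and r1.value_index == 1
    # combiner, source, axis and condition are shared between r0 and r1,
    # which is the invariant the new ReduceEqual check enforces on the C++ side.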
@@ -98,13 +105,6 @@ Array<Tensor> compute(Array<Expr> shape,
   return outputs;
 }
 
-inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
-  return (a->combiner.same_as(b->combiner)) &&
-         (a->source.same_as(b->source)) &&
-         (a->axis.same_as(b->axis)) &&
-         (a->condition.same_as(b->condition));
-}
-
 Operation ComputeOpNode::make(std::string name,
                               std::string tag,
                               Array<IterVar> axis,
@@ -151,9 +151,35 @@ Operation ComputeOpNode::ReplaceInputs(
     const Operation& self,
     const std::unordered_map<Tensor, Tensor>& rmap) const {
   CHECK_EQ(self.operator->(), this);
-  Array<Expr> arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
-      return op::ReplaceTensor(e, rmap);
-    });
+  Array<Expr> arr;
+  if (this->body[0]->is_type<ir::Reduce>()) {
+    // Specially handle reduce so the replaced op
+    // still shares all the components
+    const ir::Reduce* reduce = this->body[0].as<ir::Reduce>();
+    for (size_t i = 1; i < this->body.size(); ++i) {
+      const ir::Reduce* reduce_ = this->body[i].as<ir::Reduce>();
+      CHECK(reduce_);
+      CHECK(ReduceEqual(reduce_, reduce))
+          << "The Reduce inputs of ComputeOp should "
+          << "have the same attribute except value_index";
+    }
+    Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
+    if (!new_reduce.same_as(this->body[0])) {
+      const ir::Reduce* r = new_reduce.as<ir::Reduce>();
+      for (size_t k = 0; k < this->body.size(); ++k) {
+        std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r);
+        n->value_index = static_cast<int>(k);
+        n->type = r->source[k].type();
+        arr.push_back(Expr(n));
+      }
+    } else {
+      arr = this->body;
+    }
+  } else {
+    arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
+        return op::ReplaceTensor(e, rmap);
+      });
+  }
   if (!arr.same_as(this->body)) {
     return ComputeOpNode::make(name, tag, axis, arr);
   } else {
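In short, the new ReplaceInputs path rewrites the shared reduction once (through body[0]) and then clones the rewritten Reduce for each output, patching only value_index and the result type, so all outputs keep referring to the same underlying reduction after tensor replacement. A loose Python analogue of that control flow follows; with_value_index is a hypothetical stand-in for copying the Reduce node and resetting value_index/type, not a TVM API.

    def replace_inputs_reduce(bodies, replace_tensor):
        # Rewrite tensors in the shared Reduce once, via the first body.
        new_first = replace_tensor(bodies[0])
        if new_first is bodies[0]:
            return bodies                      # nothing referenced was replaced
        # Re-derive every output from the single rewritten Reduce.
        return [new_first.with_value_index(k) for k in range(len(bodies))]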
@@ -162,6 +162,7 @@ class TensorReplacer : public ir::IRMutator {
  public:
   explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
       : vmap_(vmap) {}
+
   Expr Mutate_(const ir::Call* op, const Expr& e) {
     if (op->call_type == ir::Call::Halide) {
       Tensor t = Operation(op->func.node_).output(op->value_index);
@@ -68,16 +68,18 @@ def test_inline_multi_reduce():
     m = tvm.var('m')
     n = tvm.var('n')
     val = tvm.placeholder((m, n), name='val', dtype='float32')
-    val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val[i, j]), name='val2')
+    val1 = tvm.compute((m, n), lambda i, j: val[i, j]+1, name='val1')
+    val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val1[i, j]), name='val2')
     k = tvm.reduce_axis((0, n), 'k')
     T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')
     s = tvm.create_schedule(T_idx.op)
-    s[val2].compute_inline()
+    s[val1].compute_inline()
     s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
 
 def test_auto_inline():
     m = tvm.var('m')
     n = tvm.var('n')