Commit 1e78d41c by Wuwei Lin Committed by Tianqi Chen

Mutate free variables in CommReducer in cache_write (#2354)

parent 8dd928c7
......@@ -35,6 +35,40 @@ class VarReplacer : public ir::IRMutator {
return e;
}
ir::CommReducer MutateCommReducer(ir::CommReducer combiner) {
// Replace free variables in combiner
auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const Expr& e) {
return this->Mutate(e);
});
auto new_result = ir::UpdateArray(combiner->result, [this] (const Expr& e) {
return this->Mutate(e);
});
if (combiner->identity_element.same_as(new_identity) &&
combiner->identity_element.same_as(new_result)) {
return combiner;
} else {
return ir::CommReducerNode::make(
combiner->lhs, combiner->rhs, new_result, new_identity);
}
}
Expr Mutate_(const ir::Reduce* op, const Expr& e) {
Expr new_e = IRMutator::Mutate_(op, e);
const ir::Reduce* new_reduce = new_e.as<ir::Reduce>();
ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
if (op->combiner.same_as(new_combiner)) {
return new_e;
} else {
return ir::Reduce::make(
new_combiner,
new_reduce->source,
new_reduce->axis,
new_reduce->condition,
new_reduce->value_index);
}
}
private:
const std::unordered_map<const Variable*, Expr>& vsub_;
};
......
......@@ -420,8 +420,22 @@ def test_loop_dep_reduce():
f = tvm.build(s, [X, Y])
def test_loop_dep_reduce_cache_write():
X = tvm.placeholder(shape=(10,), name="x")
def f(n):
rv = tvm.reduce_axis((0, n))
init = lambda dtype: tvm.select(n > 1, tvm.const(0, dtype), n.astype(dtype))
sum = tvm.comm_reducer(lambda x, y: tvm.max(x + y, n.astype('float32')), init, name='sum')
return sum(X[rv], axis=rv)
Y = tvm.compute(X.shape, f, name="y")
s = tvm.create_schedule([Y.op])
s.cache_write(Y, 'local')
f = tvm.build(s, [X, Y])
if __name__ == "__main__":
test_loop_dep_reduce()
test_loop_dep_reduce_cache_write()
test_schedule_middle_cache()
test_inline_multi_reduce()
test_schedule_cache_relayout4()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment