Commit 547a0913 by Sergei Grechanik Committed by Tianqi Chen

[TVM] Reduction simplification improvements (#2284)

parent 9d20fa1b
...@@ -781,12 +781,138 @@ T Simplify_(T a, Map<Var, Range> vrange) { ...@@ -781,12 +781,138 @@ T Simplify_(T a, Map<Var, Range> vrange) {
} }
/*!
* \brief Simplify just the combiner of the given reduce node.
*
* This function applies Simplify to the components of the top reduction's
* combiner, but not to the source or condition of the reduction.
* It also removes all components which are not used to
* compute the resulting value (the value_index-th value).
*
* If \p expr is not a reduction node, it is left unchanged.
*
* \param expr The expression to be simplifed.
* \return Simplified expression.
*/
Expr SimplifyCombiner(const Expr& expr, const Map<Var, Range>& vrange = Map<Var, Range>()) {
const Reduce* op = expr.as<Reduce>();
if (!op) {
return expr;
}
// First simplify the results
Array<Expr> simplified_result;
for (const auto& res : op->combiner->result) {
simplified_result.push_back(Simplify(res, vrange));
}
// Which components to keep
std::vector<int> used(op->combiner->result.size(), false);
// This function recursively marks the used components starting from
// the index idx
std::function<void(int)> mark_used;
mark_used = [&used, &simplified_result, op, &mark_used](size_t idx) {
// if the idx-th component was marked as used before, do nothing
if (used[idx]) return;
used[idx] = true;
// check if the idx-th result expr uses some lhs or rhs variables
// and recursively mark the corresponding components
for (size_t i = 0; i < simplified_result.size(); ++i)
if (!used[i]) {
if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) ||
ExprUseVar(simplified_result[idx], op->combiner->rhs[i]))
mark_used(i);
}
};
// mark all used components starting from the value_index
mark_used(op->value_index);
// components which have side effects should also be preserved
for (size_t i = 0; i < used.size(); ++i) {
if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) ||
HasSideEffect(op->combiner->result[i])) {
mark_used(i);
}
}
int new_value_index = op->value_index;
Array<Expr> new_result;
Array<Expr> new_identity;
Array<Var> new_lhs;
Array<Var> new_rhs;
Array<Expr> new_source;
// new stuff is old stuff which is used
for (size_t i = 0; i < used.size(); ++i) {
if (used[i]) {
// We simplify the result and identity, but not the source
new_result.push_back(simplified_result[i]);
new_identity.push_back(Simplify(op->combiner->identity_element[i], vrange));
new_lhs.push_back(op->combiner->lhs[i]);
new_rhs.push_back(op->combiner->rhs[i]);
new_source.push_back(op->source[i]);
} else if (static_cast<int>(i) < op->value_index) {
// value_index should also be adjusted
new_value_index--;
}
}
CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
return Reduce::make(new_combiner, new_source, op->axis, op->condition, new_value_index);
}
/*!
* \brief Remove a single reduction over empty axis.
*
* If \p e is a reduction node and its axis is empty, replace it with its source,
* otherwise return \p e unchanged.
*
* \param e The expression to be transformed.
* \return The transformed expression.
*/
Expr RemoveEmptyReduction(const Expr& e) {
const Reduce* r = e.as<Reduce>();
if (r && r->axis.empty()) {
// Note that here we assume that the identity element is indeed identity. Without this
// assumption we would have to perform a single iteration of the loop, i.e. use
// `(*r->combiner.get())(r->combiner->identity_element, r->source)[r->value_index]`
// instead of `r->source[r->value_index]`. The former may be more difficult to simplify.
return Select::make(r->condition,
r->source[r->value_index],
r->combiner->identity_element[r->value_index]);
}
return e;
}
Expr Simplify(Expr a, Map<Var, Range> vrange) { Expr Simplify(Expr a, Map<Var, Range> vrange) {
// We should not pass an expression having a non-HalideIR op to // We should not pass an expression having a non-HalideIR op to
// Halide::Internal::simplify. Reduce op is the only such op at this time // Halide::Internal::simplify. Reduce op is the only such op at this time
// and it only appears as the top op in an expression. So we strip it // and it only appears as the top op in an expression. So we strip it
// first and send the sub-expressions to the simplifier. // first and send the sub-expressions to the simplifier.
if (const Reduce* r = a.as<Reduce>()) { if (const Reduce* r = a.as<Reduce>()) {
// If axis is empty, we can remove the reduce op completely.
if (r->axis.empty())
return Simplify_(RemoveEmptyReduction(a), vrange);
// Simplify the combiner of the reduction
a = SimplifyCombiner(a, vrange);
r = a.as<Reduce>();
// If axis is not empty then we add the information about ranges to vrange
for (const IterVar& iv : r->axis) {
if (vrange.count(iv->var)) {
Range existing_range = vrange[iv->var];
CHECK(Equal(existing_range->min, iv->dom->min) &&
Equal(existing_range->extent, iv->dom->extent))
<< "Simplify was given vrange stating that the range of the reduction var "
<< iv << " is " << existing_range << ". This is probably a mistake.";
}
vrange.Set(iv->var, iv->dom);
}
Array<Expr> new_source; Array<Expr> new_source;
for (auto& e : r->source) { for (auto& e : r->source) {
new_source.push_back(Simplify_(e, vrange)); new_source.push_back(Simplify_(e, vrange));
......
import tvm import tvm
import numpy import numpy
from tvm import comm_reducer
from tvm.ir_pass import Simplify, CanonicalSimplify, Equal
def test_simplify(): def test_simplify():
"""Not yet working, mock design""" """Not yet working, mock design"""
...@@ -52,8 +54,90 @@ def test_canonical(): ...@@ -52,8 +54,90 @@ def test_canonical():
ret2 = tvm.ir_pass.CanonicalSimplify(x % 3 + x % 4) ret2 = tvm.ir_pass.CanonicalSimplify(x % 3 + x % 4)
assert (tvm.ir_pass.Equal(ret1, ret2)) assert (tvm.ir_pass.Equal(ret1, ret2))
def test_simplify_combiner():
dummy = tvm.var('dummy')
prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0))
sum_or_prod = comm_reducer(lambda x, y: tvm.expr.Select(dummy < 0,
x + y, x*y),
lambda t0: tvm.expr.Select(dummy < 0,
tvm.const(0, t0), tvm.const(1, t0)))
sum_and_prod = comm_reducer(lambda x, y: (x[0] + y[0],
x[1]*y[1]),
lambda t0, t1: (tvm.const(0, t0),
tvm.const(5, t0) - tvm.const(4, t0)))
sum_and_prod2 = comm_reducer(lambda x, y: (x[0] + y[0],
x[1]*y[1] + 0*x[0] + y[0] - y[0]),
lambda t0, t1: (tvm.const(5, t0) - tvm.const(5, t0),
tvm.const(1, t1)))
some_reducer1 = comm_reducer(lambda x, y: (x[0] + y[0],
x[0] + y[0] + x[1] + y[1],
x[0]*y[2] + y[0]*x[2],
x[1] + y[2],
4.0),
lambda t0, t1, t2, t3, t4: (tvm.const(0, t0),
tvm.const(1, t1),
tvm.const(2, t2),
tvm.const(3, t3),
tvm.const(4, t4)))
k = tvm.reduce_axis((0, 10), name="k")
A = tvm.placeholder((10,), name='A')
# Test that SimplifyCombiner makes use of vranges
vrange = {dummy: tvm.Range(-10, -5)}
assert Equal(Simplify(sum_or_prod(A[k], k), vrange), tvm.sum(A[k], k))
vrange = {dummy: tvm.Range(5, 10)}
assert Equal(Simplify(sum_or_prod(A[k], k), vrange), prod(A[k], k))
assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k))
assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[1]), prod(A[10-k], k))
assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k))
assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[1]), prod(A[10-k], k))
reference_simplified_sources = [[A[0]],
[A[0], A[1]],
[A[0], A[2]],
[A[0], A[1], A[2], A[3]],
[A[4]]]
for j in range(5):
# Here we use the j-th component of the result, so only it and the components it
# depends on are left.
simplified = Simplify(some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j])
# Check that the remaining components are the expected ones.
for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]):
assert Equal(lhs, rhs)
# Test that components with side effects are not removed
side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0)
assert Equal(Simplify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0]),
sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
assert Equal(Simplify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0]),
tvm.sum(side_effect(A[k]), k))
def test_simplify_reduce():
k = tvm.reduce_axis((0, 10), name="k")
j = tvm.reduce_axis((-5, 3), name="j")
A = tvm.placeholder((10,), name='A')
assert Equal(Simplify(tvm.sum(k/10, k)), tvm.sum(tvm.const(0, "int32"), k))
assert Equal(Simplify(tvm.sum(A[3], [])), A[3])
assert Equal(Simplify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j])),
tvm.sum(k + j, [k, j]))
if __name__ == "__main__": if __name__ == "__main__":
test_bound() test_bound()
test_basic() test_basic()
test_simplify() test_simplify()
test_canonical() test_canonical()
test_simplify_combiner()
test_simplify_reduce()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment