Commit f2abd9f6 by Andrew Tulloch Committed by Tianqi Chen

[TVM] Rewrite simplification rule to eliminate unnecessary conditionals. (#4076)

The current bounds checking infrastructure inserts checks like:

```
for (i, 0, bounds[n]) {
  if (likely(i < bounds[n]) {
     ...
  }
}
```

into the TVM IR which is currently not removed by simplification infrastructure.
This is a little unclean, as these are trivially true since for a loop var `i`
with a given min and extent, we are guaranteed that `i >= min` and `i < min +
extent`. Thus, we can insert these checks into the IR and use them to eliminate
trivial bounds checks early on.
parent c12275ee
...@@ -245,6 +245,8 @@ class RewriteSimplifier { ...@@ -245,6 +245,8 @@ class RewriteSimplifier {
const Expr& new_expr, const Expr& new_expr,
bool override = false); bool override = false);
std::function<void()> EnterConstraint(const Expr& constraint);
private: private:
friend class Analyzer; friend class Analyzer;
friend class ConstraintContext; friend class ConstraintContext;
......
...@@ -67,8 +67,10 @@ void ConstraintContext::EnterWithScope() { ...@@ -67,8 +67,10 @@ void ConstraintContext::EnterWithScope() {
// entering the scope. // entering the scope.
auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_); auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
auto f1 = analyzer_->modular_set.EnterConstraint(constraint_); auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_);
// recovery function. // recovery function.
exit_ = [f0, f1]() { exit_ = [f0, f1, f2]() {
if (f2 != nullptr) f2();
if (f1 != nullptr) f1(); if (f1 != nullptr) f1();
if (f0 != nullptr) f0(); if (f0 != nullptr) f0();
}; };
......
...@@ -220,6 +220,17 @@ Mutate_(const Add* op, const Expr& self) { ...@@ -220,6 +220,17 @@ Mutate_(const Add* op, const Expr& self) {
return ret; return ret;
} }
std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const Expr& constraint) {
size_t old_literal_size = literal_constraints_.size();
literal_constraints_.push_back(constraint);
size_t new_literal_size = literal_constraints_.size();
auto frecover = [old_literal_size, new_literal_size, this]() {
CHECK_EQ(literal_constraints_.size(), new_literal_size);
literal_constraints_.resize(old_literal_size);
};
return frecover;
}
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Sub* op, const Expr& self) { Mutate_(const Sub* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutator::Mutate_(op, self);
...@@ -1705,6 +1716,14 @@ Mutate_(const Call* op, const Expr& self) { ...@@ -1705,6 +1716,14 @@ Mutate_(const Call* op, const Expr& self) {
return op->args[0] & op->args[1]; return op->args[0] & op->args[1];
} }
} }
if (op->is_intrinsic(Call::likely)) {
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (Equal(constraint, op->args[0])) {
return make_const(op->type, true);
}
}
}
return ret; return ret;
} }
...@@ -1761,6 +1780,10 @@ void RewriteSimplifier::Update(const Var& var, ...@@ -1761,6 +1780,10 @@ void RewriteSimplifier::Update(const Var& var,
impl_->Update(var, info, override); impl_->Update(var, info, override);
} }
std::function<void()> RewriteSimplifier::EnterConstraint(const Expr& constraint) {
return impl_->EnterConstraint(constraint);
}
RewriteSimplifier::RewriteSimplifier(Analyzer* parent) RewriteSimplifier::RewriteSimplifier(Analyzer* parent)
: impl_(new Impl(parent)) { : impl_(new Impl(parent)) {
} }
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "const_fold.h" #include "const_fold.h"
#include "pattern_match.h" #include "pattern_match.h"
#include "ir_mutator_with_analyzer.h" #include "ir_mutator_with_analyzer.h"
...@@ -74,6 +75,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { ...@@ -74,6 +75,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
Expr Mutate_(const Cast* op, const Expr& self) override; Expr Mutate_(const Cast* op, const Expr& self) override;
Expr Mutate_(const Let* op, const Expr& self) override; Expr Mutate_(const Let* op, const Expr& self) override;
std::function<void()> EnterConstraint(const Expr& constraint);
protected: protected:
/*! \brief internal structure for comparison. */ /*! \brief internal structure for comparison. */
enum CompareResult { enum CompareResult {
...@@ -89,6 +92,9 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { ...@@ -89,6 +92,9 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
int recur_depth_{0}; int recur_depth_{0};
// internal variable map // internal variable map
std::unordered_map<Var, Expr, ExprHash, ExprEqual> var_map_; std::unordered_map<Var, Expr, ExprHash, ExprEqual> var_map_;
std::vector<Expr> literal_constraints_;
// maximum number of recursion allowed during a single pass. // maximum number of recursion allowed during a single pass.
static const constexpr int kMaxRecurDepth = 5; static const constexpr int kMaxRecurDepth = 5;
......
...@@ -51,6 +51,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { ...@@ -51,6 +51,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return Mutate(stmt); return Mutate(stmt);
} }
Stmt Mutate_(const For* op, const Stmt& s) final {
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const LetStmt* op, const Stmt& s) { Stmt Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) { if (!ir::HasSideEffect(value)) {
......
...@@ -47,6 +47,62 @@ def test_thread_extent_simplify(): ...@@ -47,6 +47,62 @@ def test_thread_extent_simplify():
assert isinstance(body.body.body.body, tvm.stmt.Store) assert isinstance(body.body.body.body, tvm.stmt.Store)
def test_basic_likely_elimination():
n = tvm.var('n')
X = tvm.placeholder(shape=(n,), name="x")
W = tvm.placeholder(shape=(n + 1,), dtype="int32", name="w")
def f(i):
start = W[i]
extent = W[i+1] - W[i]
rv = tvm.reduce_axis((0, extent))
return tvm.sum(X[rv + start], axis=rv)
Y = tvm.compute(X.shape, f, name="y")
s = tvm.create_schedule([Y.op])
stmt = tvm.lower(s, [X, W, Y], simple_mode=True)
assert('if' not in str(stmt))
def test_complex_likely_elimination():
def cumsum(X):
"""
Y[i] = sum(X[:i])
"""
(m, ) = X.shape
s_state = tvm.placeholder((m + 1, ), dtype="int32", name="state")
s_init = tvm.compute((1, ), lambda _: tvm.const(0, "int32"))
s_update = tvm.compute((m + 1, ), lambda l: s_state[l - 1] + X[l - 1])
return tvm.scan(s_init, s_update, s_state, inputs=[X], name="cumsum")
def sparse_lengths_sum(data, indices, lengths):
oshape = list(data.shape)
oshape[0] = lengths.shape[0]
length_offsets = cumsum(lengths)
def sls(n, d):
gg = tvm.reduce_axis((0, lengths[n]))
indices_idx = length_offsets[n] + gg
data_idx = indices[indices_idx]
data_val = data[data_idx, d]
return tvm.sum(data_val, axis=gg)
return tvm.compute(oshape, sls)
m, n, d, i, l = tvm.var('m'), tvm.var('n'), tvm.var('d'), tvm.var('i'), tvm.var('l')
data_ph = tvm.placeholder((m, d * 32), name="data")
indices_ph = tvm.placeholder((i,), name="indices", dtype="int32")
lengths_ph = tvm.placeholder((n,), name="lengths", dtype="int32")
Y = sparse_lengths_sum(data_ph, indices_ph, lengths_ph)
s = tvm.create_schedule([Y.op])
(n, d) = s[Y].op.axis
(do, di) = s[Y].split(d, factor=32)
(gg,) = s[Y].op.reduce_axis
s[Y].reorder(n, do, gg, di)
s[Y].vectorize(di)
stmt = tvm.lower(s, [data_ph, indices_ph, lengths_ph, Y], simple_mode=True)
assert('if' not in str(stmt))
if __name__ == "__main__": if __name__ == "__main__":
test_stmt_simplify() test_stmt_simplify()
test_thread_extent_simplify() test_thread_extent_simplify()
test_basic_likely_elimination()
test_complex_likely_elimination()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment