Unverified Commit ec95675c by Tianqi Chen Committed by GitHub

[ARITH] Analyzer RewriteSimplifier: add/sub/mul/div/mod (#2722)

parent 0bf64ee0
......@@ -193,6 +193,39 @@ class ModularSetAnalyzer {
};
/*!
* \brief Rewrite-rule based simplifier.
*
* The constructor and destructor are private; only the friend classes
* Analyzer and ConstraintContext can create or destroy instances.
* The implementation is hidden behind the forward-declared Impl class.
*/
class RewriteSimplifier {
public:
/*!
* \brief Simplify the expression with the rewrite rules.
* \param expr The expression of interest.
* \return The simplified expression.
*/
Expr operator()(const Expr& expr);
/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_expr The expression to bind to var.
* \param override Whether an existing binding of var may be overridden.
*/
void Update(const Var& var,
const Expr& new_expr,
bool override = false);
private:
friend class Analyzer;
friend class ConstraintContext;
explicit RewriteSimplifier(Analyzer* parent);
~RewriteSimplifier();
class Impl;
/*! \brief Internal implementation, released with delete in the destructor. */
Impl* impl_;
};
/*!
* \brief A RAII constraint context.
*
* \code
......@@ -242,6 +275,8 @@ class Analyzer {
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer: rewrite simplify */
RewriteSimplifier rewrite_simplify;
/*! \brief constructor */
Analyzer();
/*!
......
......@@ -96,6 +96,7 @@ class Analyzer:
self._const_int_bound_update = _mod("const_int_bound_update")
self._bind = _mod("bind")
self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify")
self._enter_constraint_context = _mod("enter_constraint_context")
def const_int_bound(self, expr):
......@@ -128,6 +129,21 @@ class Analyzer:
"""
return self._modular_set(expr)
def rewrite_simplify(self, expr):
    """Run the rewrite-rule based simplifier on an expression.

    Parameters
    ----------
    expr : tvm.Expr
        The expression to simplify.

    Returns
    -------
    result : Expr
        The simplified expression.
    """
    simplified = self._rewrite_simplify(expr)
    return simplified
def bind(self, var, expr):
"""Bind a variable to the expression.
......
......@@ -98,6 +98,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
self->const_int_bound.Update(args[0], args[1], args[2]);
});
} else if (name == "rewrite_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->rewrite_simplify(args[0]);
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr();
......
......@@ -2,6 +2,7 @@
* Copyright (c) 2019 by Contributors
* \file tvm/arithmetic/analyzer.cc
*/
#include <tvm/ir.h>
#include <tvm/arithmetic.h>
namespace tvm {
......@@ -9,19 +10,22 @@ namespace arith {
// Construct the analyzer; every sub-analyzer receives a pointer back to
// this Analyzer instance.
Analyzer::Analyzer()
    : const_int_bound(this),
      modular_set(this),
      rewrite_simplify(this) {
}
void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
Var var(v.node_);
this->const_int_bound.Update(var, this->const_int_bound(expr));
this->modular_set.Update(var, this->modular_set(expr));
this->rewrite_simplify.Update(var, this->rewrite_simplify(expr));
}
void Analyzer::Bind(const VarExpr& v, const Range& range) {
Var var(v.node_);
this->const_int_bound.Bind(var, range);
// skip modular_set
// skip rewrite simplify
}
ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) {
......@@ -36,7 +40,10 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
}
// Try to prove that expr >= lower_bound.
//
// \param expr The expression of interest.
// \param lower_bound The bound to compare against.
// \return true when the relation can be proven, false when it cannot
//         (which does not mean the relation is false).
bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
  // Fast path: an integer immediate is compared directly.
  // The comparison is inclusive, matching the bound check below.
  if (const auto* ptr = expr.as<ir::IntImm>()) {
    return ptr->value >= lower_bound;
  }
  // Simplify first, then ask the constant integer bound analyzer.
  auto bd = this->const_int_bound(this->rewrite_simplify(expr));
  return bd->min_value >= lower_bound;
}
......
......@@ -23,7 +23,9 @@ namespace arith {
* \return nullptr if constant fold fails, otherwise return folded result.
*/
template<typename Op>
inline Expr TryConstFold(Expr a, Expr b) {
  // Generic fallback: no folding rule is known for this Op, so report
  // failure by returning an undefined Expr.
  return Expr();
}
/*!
* \brief Try to run unary compute with constant folding.
......
......@@ -49,6 +49,7 @@
#include <tvm/ir_pass.h>
#include <tuple>
#include "const_fold.h"
namespace tvm {
namespace arith {
......@@ -242,7 +243,11 @@ class PBinaryExpr :
}
// Build the expression described by this pattern from the matched operands.
// Constant folding is attempted first; a new NodeType node is only created
// when folding does not produce a defined result.
Expr Eval() const {
  Expr lhs = a_.Eval();
  Expr rhs = b_.Eval();
  Expr ret = TryConstFold<NodeType>(lhs, rhs);
  if (ret.defined()) return ret;
  return NodeType::make(lhs, rhs);
}
private:
......@@ -250,12 +255,48 @@ class PBinaryExpr :
typename TB::Nested b_;
};
/*!
 * \brief Pattern for an integer constant whose type is taken from
 *        another pattern.
 *
 * Matching succeeds only on an ir::IntImm holding exactly the stored value.
 * Eval creates the constant with the type of the evaluated reference pattern.
 *
 * \tparam TA The type of the reference pattern.
 */
template<typename TA>
class PConstWithTypeLike :
      public Pattern<PConstWithTypeLike<TA> > {
 public:
  PConstWithTypeLike(const TA& ref, int64_t value)
      : ref_(ref), value_(value) {}

  // Nothing to reset before a match.
  void InitMatch_() const {}

  bool Match_(const NodeRef& node) const {
    const ir::IntImm* imm = node.as<ir::IntImm>();
    return imm != nullptr && imm->value == value_;
  }

  Expr Eval() const {
    const auto like_type = ref_.Eval().type();
    return make_const(like_type, value_);
  }

 private:
  // pattern that supplies the type of the constant
  typename TA::Nested ref_;
  // value of the constant
  int64_t value_;
};
// Define the pattern-building function FuncName for binary node NodeName.
// Three overloads are generated:
//   - pattern op pattern
//   - pattern op int64_t constant (the constant takes the type of the pattern)
//   - int64_t constant op pattern (same, with the operands swapped)
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName)                     \
  template<typename TA, typename TB>                                  \
  inline PBinaryExpr<NodeName, TA, TB>                                \
  FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {              \
    return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived());   \
  }                                                                   \
  template<typename TA>                                               \
  inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> >           \
  FuncName(const Pattern<TA>& a, int64_t b) {                         \
    return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b));       \
  }                                                                   \
  template<typename TA>                                               \
  inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA>            \
  FuncName(int64_t b, const Pattern<TA>& a) {                         \
    return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a);       \
  }
// arithmetic expressions
......
/*!
* Copyright (c) 2019 by Contributors
* \file rewrite_simplify.cc
* \brief Rewrite-rule based simplification.
*/
// Acknowledgement: Most rewrite-rules are from Halide.
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include "const_fold.h"
#include "pattern_match.h"
namespace tvm {
namespace arith {
using namespace ir;
// The macros below expect a variable named `ret` in the enclosing scope that
// holds the expression to be rewritten. When the pattern matches they
// `return` from the enclosing function, so they are statements, not values.
//
// macro for doing simple rewrite:
// if SrcExpr matches ret, return the expression built from ResExpr.
#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \
if ((SrcExpr).Match(ret)) { \
return (ResExpr).Eval(); \
}
// macro for rewrite + recursively rewrite ResExpr
// (the result is passed through RecursiveRewrite before being returned).
#define TVM_TRY_RECURSIVE_REWRITE(SrcExpr, ResExpr) \
if ((SrcExpr).Match(ret)) { \
return RecursiveRewrite((ResExpr).Eval()); \
}
// macro rewrite only if CondExpr is true after match.
// CondExpr is evaluated only when the match succeeded.
#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
if ((SrcExpr).Match(ret) && (CondExpr)) { \
return (ResExpr).Eval(); \
}
// macro rewrite + recursive_rewrite only if CondExpr is true after match.
#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
if ((SrcExpr).Match(ret) && (CondExpr)) { \
return RecursiveRewrite((ResExpr).Eval()); \
}
// NOTE for developers:
//
// We mainly focus on index expression simplification.
// Besides the RewriteSimplifier, some cases can be better
// handled by CanonicalSimplifier.
//
// Implementation of RewriteSimplifier as an IRMutator.
// The arithmetic nodes Add, Sub, Mul, Div and Mod have dedicated
// Mutate_ overrides that apply the rewrite rules.
class RewriteSimplifier::Impl : public IRMutator {
 public:
  explicit Impl(Analyzer* parent)
      : parent_(parent) {}

  // Record info as the binding of var.
  // Unless override is set, var must not have a binding yet.
  void Update(const Var& var,
              const Expr& info,
              bool override) {
    if (!override) {
      CHECK(!var_map_.count(var));
    }
    var_map_[var] = info;
  }

  // Run the mutator repeatedly, at most max_iter times, and stop early
  // as soon as a pass returns the very same expression.
  Expr PostOrderSimplify(Expr expr, int max_iter = 2) {
    for (int remaining = max_iter; remaining > 0; --remaining) {
      Expr next = this->Mutate(expr);
      if (next.same_as(expr)) break;
      expr = next;
    }
    return expr;
  }

  Expr Mutate_(const Add* op, const Expr& self) final;
  Expr Mutate_(const Sub* op, const Expr& self) final;
  Expr Mutate_(const Mul* op, const Expr& self) final;
  Expr Mutate_(const Div* op, const Expr& self) final;
  Expr Mutate_(const Mod* op, const Expr& self) final;

 private:
  // reference to the main analyzer
  Analyzer* parent_;
  // counter to record recursive rewrite depth.
  int recur_depth_{0};
  // internal variable map
  std::unordered_map<Var, Expr, ExprHash, ExprEqual> var_map_;
  // maximum number of recursion allowed during a single pass.
  static constexpr int kMaxRecurDepth = 5;

  // Whether x >= val, answered by the parent analyzer.
  bool CanProveGreaterEqual(const Expr& x, int64_t val) {
    return parent_->CanProveGreaterEqual(x, val);
  }

  // Whether x == val: true only when x simplifies to the integer
  // immediate val.
  bool CanProveEqual(const Expr& x, int64_t val) {
    // TODO(tqchen) refer back to super-analyzer.
    const Expr simplified = Mutate(x);
    const auto* imm = simplified.as<ir::IntImm>();
    return imm != nullptr && imm->value == val;
  }

  // Recursively rewrite x.
  // The depth of nested recursive rewrites is capped by kMaxRecurDepth to
  // avoid an infinite loop; beyond the cap x is returned untouched.
  Expr RecursiveRewrite(const Expr& x) {
    if (recur_depth_ < kMaxRecurDepth) {
      ++recur_depth_;
      Expr res = Mutate(x);
      --recur_depth_;
      return res;
    }
    return x;
  }

  // Pattern of the constant 0 typed like the given pattern.
  template<typename TA>
  PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) {
    return PConstWithTypeLike<TA>(pattern.derived(), 0);
  }
};
// Simplify an Add node.
//
// The node is first handed to the base IRMutator (presumably rewriting the
// operands - confirm against IRMutator), then constant folding is tried.
// After that the rewrite rules below are tested in order against `ret`;
// the first rule whose pattern (and side condition, if any) holds returns.
// When no rule applies, the mutated expression is returned unchanged.
Expr RewriteSimplifier::Impl::
Mutate_(const Add* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
// look at the mutated node from here on
op = ret.as<Add>();
Expr const_res = TryConstFold<Add>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
PVar<Integer> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes),
ramp(b1 + b2, s1 + s2, lanes));
TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes),
ramp(b1 + x, s1, lanes));
TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes),
ramp(x + b1, s1, lanes));
TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes),
broadcast(x + y, lanes));
}
if (IsIndexType(op->type)) {
// Index rules
// cancellation rules
TVM_TRY_REWRITE((x - y) + y, x);
TVM_TRY_REWRITE(x + (y - x), y);
TVM_TRY_REWRITE((x - y) + (y - z), x - z);
TVM_TRY_REWRITE((x - y) + (z - x), z - y);
TVM_TRY_REWRITE(min(x, y - z) + z, min(x + z, y));
TVM_TRY_REWRITE(min(x - z, y) + z, min(x, y + z));
TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y));
TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z));
TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y);
TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y);
TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y);
TVM_TRY_REWRITE(min(x, y) + max(y, x), x + y);
// the constants cancel only when c1 == -c2
TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(x + c1, y) + c2, min(x, y + c2),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x, y + c1) + c2, max(x + c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x + c1, y) + c2, max(x, y + c2),
c1.Eval()->value == -c2.Eval()->value);
// constant folding
// NOTE: canonicalization might be better at this.
TVM_TRY_REWRITE((x + c1) + c2, x + (c1 + c2));
// mul co-efficient folding
TVM_TRY_REWRITE(x + x, x * 2);
TVM_TRY_REWRITE(x * y + x, x * (y + 1));
TVM_TRY_REWRITE(y * x + x, x * (y + 1));
TVM_TRY_REWRITE(x + y * x, x * (1 + y));
TVM_TRY_REWRITE(x + x * y, x * (1 + y));
TVM_TRY_REWRITE(x * y + x * z, x * (y + z));
TVM_TRY_REWRITE(y * x + x * z, x * (y + z));
TVM_TRY_REWRITE(x * y + z * x, x * (y + z));
TVM_TRY_REWRITE(y * x + z * x, x * (y + z));
// modular-div simplification
// Always pre-condition on positive integer domain
TVM_TRY_REWRITE_IF(
(x / c1) * c1 + x % c1, x,
CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0);
// canonicalization rule
// will try rewrite again after canonicalization.
TVM_TRY_RECURSIVE_REWRITE(x + (c1 - y), (x - y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE((y % c1) + x * c1, x * c1 + (y % c1));
}
// condition rules.
TVM_TRY_REWRITE(select(x, b1, b2) + select(x, s1, s2),
select(x, b1 + s1, b2 + s2));
// default value
return ret;
}
// Simplify a Sub node.
//
// The node is first handed to the base IRMutator, then constant folding is
// tried. After that the rewrite rules below are tested in order against
// `ret`; the first rule whose pattern (and side condition, if any) holds
// returns. When no rule applies, the mutated expression is returned unchanged.
Expr RewriteSimplifier::Impl::
Mutate_(const Sub* op, const Expr& self) {
  Expr ret = IRMutator::Mutate_(op, self);
  // look at the mutated node from here on
  op = ret.as<Sub>();
  Expr const_res = TryConstFold<Sub>(op->a, op->b);
  if (const_res.defined()) return const_res;
  // Pattern var to match any expression
  PVar<Expr> x, y, z, b1, b2, s1, s2;
  // Pattern var match IntImm
  PVar<Integer> c1, c2, c3;
  // Pattern var for lanes in broadcast and ramp
  PVar<int> lanes;
  // Vector rules
  if (op->type.lanes() != 1) {
    TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes),
                    ramp(b1 - b2, s1 - s2, lanes));
    TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes),
                    ramp(b1 - x, s1, lanes));
    TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes),
                    ramp(x - b1, 0 - s1, lanes));
    TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes),
                    broadcast(x - y, lanes));
  }
  if (IsIndexType(op->type)) {
    // Index rules
    // cancellation rules
    TVM_TRY_REWRITE((x + y) - y, x);
    TVM_TRY_REWRITE((x + y) - x, y);
    TVM_TRY_REWRITE(x - (y + x), 0 - y);
    TVM_TRY_REWRITE(x - (x + y), 0 - y);
    TVM_TRY_REWRITE(min(x, y) - x, min(0, y - x));
    TVM_TRY_REWRITE(min(x, y) - y, min(x - y, 0));
    TVM_TRY_REWRITE(max(x, y) - x, max(0, y - x));
    TVM_TRY_REWRITE(max(x, y) - y, max(x - y, 0));
    TVM_TRY_REWRITE(x - max(x, y), min(0, x - y));
    TVM_TRY_REWRITE(y - max(x, y), min(y - x, 0));
    TVM_TRY_REWRITE(x - min(x, y), max(0, x - y));
    TVM_TRY_REWRITE(y - min(x, y), max(y - x, 0));
    // mul co-efficient folding
    TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x));
    TVM_TRY_REWRITE(x * y - x, x * (y - 1));
    TVM_TRY_REWRITE(y * x - x, x * (y - 1));
    TVM_TRY_REWRITE(x - y * x, x * (1 - y));
    TVM_TRY_REWRITE(x - x * y, x * (1 - y));
    TVM_TRY_REWRITE(x * y - x * z, x * (y - z));
    TVM_TRY_REWRITE(y * x - x * z, x * (y - z));
    TVM_TRY_REWRITE(x * y - z * x, x * (y - z));
    TVM_TRY_REWRITE(y * x - z * x, x * (y - z));
    // constant cancellation
    TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2));
    TVM_TRY_REWRITE((c1 - x) - (c2 - y), (y - x) + (c1 - c2));
    // cancellation rules involving 4 operands
    TVM_TRY_REWRITE((x + y) - (x + z), y - z);
    TVM_TRY_REWRITE((x + y) - (z + x), y - z);
    TVM_TRY_REWRITE((y + x) - (z + x), y - z);
    TVM_TRY_REWRITE((y + x) - (x + z), y - z);
    TVM_TRY_REWRITE(min(x + y, z) - x, min(y, z - x));
    TVM_TRY_REWRITE(min(y + x, z) - x, min(y, z - x));
    TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y));
    TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y));
    TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z));
    TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z));
    TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y));
    TVM_TRY_REWRITE(x - min(z, y + x), max(x - z, 0 - y));
    TVM_TRY_REWRITE(min(x, y) - min(y, x), ZeroWithTypeLike(x));
    TVM_TRY_REWRITE(max(x, y) - max(y, x), ZeroWithTypeLike(x));
    // both operands of the min/max differ by the same amount
    TVM_TRY_REWRITE_IF(min(b1, b2) - min(s1, s2), b1 - s1,
                       CanProveEqual(((b1 - s1) - (b2 - s2)).Eval(), 0));
    TVM_TRY_REWRITE_IF(min(b1, b2) - min(s1, s2), b1 - s2,
                       CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0));
    TVM_TRY_REWRITE_IF(max(b1, b2) - max(s1, s2), b1 - s1,
                       CanProveEqual(((b1 - s1) - (b2 - s2)).Eval(), 0));
    TVM_TRY_REWRITE_IF(max(b1, b2) - max(s1, s2), b1 - s2,
                       CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0));
    // modular-div simplification
    // Always pre-condition on positive integer domain
    TVM_TRY_REWRITE_IF(x - (x / c1) * c1, x % c1,
                       CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0);
    TVM_TRY_REWRITE_IF((x / c1) * c1 - x, 0 - (x % c1),
                       CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0);
    // Difference of two quotients with the same divisor.
    // With u = x + c2 >= 0 and d = c1 - c2 >= 0:
    //   (u + d) / c3 - u / c3 == (u % c3 + d) / c3
    // so the remainder has to be taken of (x + c2), the subtracted operand.
    TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3,
                       ((x + c2) % c3 + (c1 - c2)) / c3,
                       CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
                       c1.Eval()->value >= c2.Eval()->value &&
                       c3.Eval()->value > 0);
    // Special case of the rule above with c2 == 0:
    //   (x + c1) / c3 - x / c3 == (x % c3 + c1) / c3
    TVM_TRY_REWRITE_IF((x + c1) / c3 - x / c3,
                       (x % c3 + c1) / c3,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       c1.Eval()->value >= 0 &&
                       c3.Eval()->value > 0);
    // canonicalization rule
    // will try rewrite again after canonicalization.
    TVM_TRY_REWRITE(x - c1, x + (0 - c1));
    TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1);
    TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
    TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
  }
  // condition rules.
  TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2),
                  select(x, b1 - s1, b2 - s2));
  TVM_TRY_REWRITE(select(x, y, z) - z,
                  select(x, y - z, ZeroWithTypeLike(z)));
  TVM_TRY_REWRITE(select(x, y, z) - y,
                  select(x, ZeroWithTypeLike(y), z - y));
  return ret;
}
// Simplify a Mul node.
//
// The node is first handed to the base IRMutator, then constant folding is
// tried. After that the rewrite rules below are tested in order against
// `ret`; the first matching rule returns. When no rule applies, the mutated
// expression is returned unchanged.
Expr RewriteSimplifier::Impl::
Mutate_(const Mul* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
// look at the mutated node from here on
op = ret.as<Mul>();
Expr const_res = TryConstFold<Mul>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
PVar<Integer> c1, c2;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes),
broadcast(x * y, lanes));
TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes),
ramp(b1 * x, s1 * x, lanes));
TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes),
ramp(b1 * x, s1 * x, lanes));
}
if (IsIndexType(op->type)) {
// constant simplification rule
TVM_TRY_REWRITE((x + c1) * c2, x * c2 + c1 * c2);
TVM_TRY_REWRITE((x * c1) * c2, x * (c1 * c2));
TVM_TRY_REWRITE(min(x, y) * max(x, y), x * y);
TVM_TRY_REWRITE(max(x, y) * min(x, y), x * y);
// canonicalization
// move the constant factor to the outside, then try to rewrite again
TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1);
// flip the difference so that the constant factor becomes positive
TVM_TRY_RECURSIVE_REWRITE_IF(
(x - y) * c1, (y - x) * (0 - c1),
c1.Eval()->value < 0);
}
return ret;
}
// Simplify a Div node.
//
// The node is first handed to the base IRMutator, then constant folding is
// tried. After that the rewrite rules below are tested in order against
// `ret`; the first rule whose pattern and side condition hold returns.
// When no rule applies, the mutated expression is returned unchanged.
Expr RewriteSimplifier::Impl::
Mutate_(const Div* op, const Expr& self) {
  Expr ret = IRMutator::Mutate_(op, self);
  // look at the mutated node from here on
  op = ret.as<Div>();
  Expr const_res = TryConstFold<Div>(op->a, op->b);
  if (const_res.defined()) return const_res;
  // Pattern var to match any expression
  PVar<Expr> x, y, z, b1;
  // Pattern var match IntImm
  PVar<Integer> c1, c2, c3;
  // Pattern var for lanes in broadcast and ramp
  PVar<int> lanes;
  // Vector rules
  if (op->type.lanes() != 1) {
    TVM_TRY_REWRITE(broadcast(x, lanes) / broadcast(y, lanes),
                    broadcast(x / y, lanes));
    // ramp / bcast
    if ((ramp(b1, c1, lanes) / broadcast(c2, lanes)).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      // A zero divisor must not reach the host-side `%` and `/` below;
      // the expression is left untouched in that case.
      if (c2val != 0) {
        // NOTE(review): with truncating division this split looks unsound
        // for a negative base b1 - confirm.
        if (c1val % c2val == 0) {
          return ramp(b1 / c2, c1 / c2, lanes).Eval();
        }
        // If all possible indices in ramp are the same.
        if (CanProveGreaterEqual(b1.Eval(), 0)) {
          ModularSet bmod = parent_->modular_set(b1.Eval());
          int64_t ramp_min = bmod->base / c2val;
          int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
          if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
            return broadcast(b1 / c2, lanes).Eval();
          }
        }
      }
    }
  }
  if (IsIndexType(op->type)) {
    // Be aware of the division rules:
    // Division follows the C convention and truncates (it is not floordiv).
    // This means most rules need to check non-negativeness of the operands.
    //
    // The rules are restricted to the common case of positive divisors.
    TVM_TRY_REWRITE_IF((x / c1) / c2, x / (c1 * c2),
                       c1.Eval()->value > 0 && c2.Eval()->value > 0);
    TVM_TRY_REWRITE_IF((x / c1 + c2) / c3, (x + c1 * c2) / (c1 * c3),
                       c1.Eval()->value > 0 &&
                       c2.Eval()->value >= 0 &&
                       c3.Eval()->value > 0 &&
                       CanProveGreaterEqual(x.Eval(), 0));
    if (((x * c1) / c2).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      if (c1val > 0 && c2val > 0) {
        if (c1val % c2val == 0) return (x * (c1 / c2)).Eval();
        if (c2val % c1val == 0) return (x / (c2 / c1)).Eval();
      }
    }
    // Rules involving 2-operands.
    // In every condition `c2 > 0` is tested before `c1 % c2`, so the
    // remainder is never computed with a zero divisor.
    TVM_TRY_REWRITE_IF((x * c1 + y) / c2, x * (c1 / c2) + y / c2,
                       c1.Eval()->value >= 0 &&
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(min(x * c1, y) / c2, min(x * (c1 / c2), y / c2),
                       c1.Eval()->value >= 0 &&
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(max(x * c1, y) / c2, max(x * (c1 / c2), y / c2),
                       c1.Eval()->value >= 0 &&
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF((y + x * c1) / c2, y / c2 + x * (c1 / c2),
                       c1.Eval()->value >= 0 &&
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(min(y, x * c1) / c2, min(y / c2, x * (c1 / c2)),
                       c1.Eval()->value >= 0 &&
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(max(y, x * c1) / c2, max(y / c2, x * (c1 / c2)),
                       c1.Eval()->value >= 0 &&
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0));
    // Rules involving 3-operands.
    TVM_TRY_REWRITE_IF((x * c1 + y + z) / c2, x * (c1 / c2) + (y + z)/ c2,
                       c1.Eval()->value >= 0 &&
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual((y + z).Eval(), 0));
    TVM_TRY_REWRITE_IF((x * c1 - y + z) / c2, x * (c1 / c2) + (z - y)/ c2,
                       c1.Eval()->value >= 0 &&
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual((z - y).Eval(), 0));
    TVM_TRY_REWRITE_IF((x * c1 + y - z) / c2, x * (c1 / c2) + (y - z)/ c2,
                       c1.Eval()->value >= 0 &&
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual((y - z).Eval(), 0));
    TVM_TRY_REWRITE_IF((y + x * c1 + z) / c2, x * (c1 / c2) + (y + z) / c2,
                       c1.Eval()->value > 0 &&
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual((y + z).Eval(), 0));
    TVM_TRY_REWRITE_IF((x + c1) / c2, x / c2 + c1 / c2,
                       c1.Eval()->value > 0 &&
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0));
    // Rules whose divisor is itself an expression.
    TVM_TRY_REWRITE_IF((x + y) / x, y / x + 1,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF((y + x) / x, y / x + 1,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(((x + y) + z) / x, (y + z) / x + 1,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual((y + z).Eval(), 0));
    TVM_TRY_REWRITE_IF(((y + x) + z) / x, (y + z) / x + 1,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual((y + z).Eval(), 0));
    TVM_TRY_REWRITE_IF((y + (z + x)) / x, (y + z) / x + 1,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual((y + z).Eval(), 0));
    TVM_TRY_REWRITE_IF((y + (x + z)) / x, (y + z) / x + 1,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual((y + z).Eval(), 0));
    TVM_TRY_REWRITE_IF((x * y) / y, x,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF((y * x) / y, x,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF((x * z + y) / z, x + y / z,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0) &&
                       CanProveGreaterEqual(z.Eval(), 0));
    TVM_TRY_REWRITE_IF((z * x + y) / z, x + y / z,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0) &&
                       CanProveGreaterEqual(z.Eval(), 0));
    TVM_TRY_REWRITE_IF((y + x * z) / z, y / z + x,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0) &&
                       CanProveGreaterEqual(z.Eval(), 0));
    TVM_TRY_REWRITE_IF((y + z * x) / z, y / z + x,
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0) &&
                       CanProveGreaterEqual(z.Eval(), 0));
  }
  return ret;
}
// Simplify a Mod node.
//
// The node is first handed to the base IRMutator, then constant folding is
// tried. After that the rewrite rules below are tested in order against
// `ret`; the first rule whose pattern and side condition hold returns.
// When no rule applies, the mutated expression is returned unchanged.
Expr RewriteSimplifier::Impl::
Mutate_(const Mod* op, const Expr& self) {
  Expr ret = IRMutator::Mutate_(op, self);
  // look at the mutated node from here on
  op = ret.as<Mod>();
  Expr const_res = TryConstFold<Mod>(op->a, op->b);
  if (const_res.defined()) return const_res;
  // Pattern var to match any expression
  PVar<Expr> x, y, b1;
  // Pattern var match IntImm
  PVar<Integer> c1, c2;
  // Pattern var for lanes in broadcast and ramp
  PVar<int> lanes;
  // Vector rules
  if (op->type.lanes() != 1) {
    TVM_TRY_REWRITE(broadcast(x, lanes) % broadcast(y, lanes),
                    broadcast(x % y, lanes));
    // ramp % bcast
    if ((ramp(b1, c1, lanes) % broadcast(c2, lanes)).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      // A zero divisor must not reach the host-side `%` and `/` below;
      // the expression is left untouched in that case.
      if (c2val != 0) {
        // NOTE(review): with truncating modulo this looks unsound for a
        // negative base b1 - confirm.
        if (c1val % c2val == 0) {
          return broadcast(b1 % c2, lanes).Eval();
        }
        // If all possible indices in ramp are the same.
        if (CanProveGreaterEqual(b1.Eval(), 0)) {
          ModularSet bmod = parent_->modular_set(b1.Eval());
          int64_t ramp_min = bmod->base / c2val;
          int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
          if (bmod->coeff % c2val == 0) {
            if (ramp_min == ramp_max) {
              return ramp(bmod->base % c2, c1, lanes).Eval();
            } else {
              return (ramp(bmod->base % c2, c1, lanes) % broadcast(c2, lanes)).Eval();
            }
          }
        }
      }
    }
  }
  if (IsIndexType(op->type)) {
    // Be aware of the division rules:
    // Division follows the C convention and truncates (it is not floordiv).
    // This means most rules need to check non-negativeness of the operands.
    TVM_TRY_REWRITE_IF((x * c1) % c2, ZeroWithTypeLike(x),
                       c2.Eval()->value != 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0);
    TVM_TRY_REWRITE_IF((x * c1 + y) % c2, y % c2,
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF((x + c1) % c2, x % c2,
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0));
    TVM_TRY_REWRITE_IF((x + y * c1) % c2, x % c2,
                       c2.Eval()->value > 0 &&
                       c1.Eval()->value % c2.Eval()->value == 0 &&
                       CanProveGreaterEqual(x.Eval(), 0) &&
                       CanProveGreaterEqual(y.Eval(), 0));
    // try modular analysis
    if ((x % c1).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      // skip the rule for a zero divisor instead of computing `% 0`
      if (c1val != 0) {
        ModularSet mod = parent_->modular_set(x.Eval());
        if (mod->coeff % c1val == 0 &&
            CanProveGreaterEqual(x.Eval(), 0)) {
          return (mod->base % c1).Eval();
        }
      }
    }
  }
  return ret;
}
// Simplify expr by running the post-order simplification of the
// implementation object.
Expr RewriteSimplifier::operator()(const Expr& expr) {
  Expr simplified = impl_->PostOrderSimplify(expr);
  return simplified;
}
void RewriteSimplifier::Update(const Var& var,
const Expr& info,
bool override) {
impl_->Update(var, info, override);
}
// Allocate the implementation object; it is released in the destructor.
RewriteSimplifier::RewriteSimplifier(Analyzer* parent) {
  impl_ = new Impl(parent);
}
// Release the implementation object allocated in the constructor.
RewriteSimplifier::~RewriteSimplifier() {
delete impl_;
}
} // namespace arith
} // namespace tvm
......@@ -117,6 +117,7 @@ TEST(Pattern, Integer) {
// special case container of Expr
CHECK((v * c).Match(tx * 3));
CHECK_EQ(c.Eval()->value, 3);
CHECK((v * 3).Match(tx * 3));
}
// cannot match c to ty
CHECK(!(v * c).Match(tx * ty));
......
import tvm
class RewriteChecker:
    """Helper that checks the results of ``Analyzer.rewrite_simplify``."""

    def __init__(self):
        self.analyzer = tvm.arith.Analyzer()

    def verify(self, data, expected):
        """Assert that simplifying ``data`` gives an expression equal to ``expected``."""
        res = self.analyzer.rewrite_simplify(data)
        is_equal = tvm.ir_pass.Equal(res, expected)
        assert is_equal, "data={}, res={}, expected={}".format(
            data, res, expected)
def test_vector_simplify():
    """Check the Ramp / Broadcast rewrite rules of every arithmetic operator."""
    checker = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    # Add rules
    checker.verify(tvm.expr.Ramp(x, 1, 4) + tvm.expr.Ramp(y, 2, 4),
                   tvm.expr.Ramp(x + y, 3, 4))
    checker.verify(tvm.expr.Ramp(x, 1, 2) + y,
                   tvm.expr.Ramp(x + y, 1, 2))
    checker.verify(y + tvm.expr.Ramp(x, 1, 2),
                   tvm.expr.Ramp(y + x, 1, 2))
    checker.verify(y.astype("int32x2") + x.astype("int32x2"),
                   (y + x).astype("int32x2"))
    # Sub rules
    checker.verify(tvm.expr.Ramp(x, 4, 4) - tvm.expr.Ramp(y, 2, 4),
                   tvm.expr.Ramp(x - y, 2, 4))
    checker.verify(tvm.expr.Ramp(x, 1, 2) - y,
                   tvm.expr.Ramp(x - y, 1, 2))
    checker.verify(y - tvm.expr.Ramp(x, 1, 2),
                   tvm.expr.Ramp(y - x, -1, 2))
    checker.verify(y.astype("int32x2") - x.astype("int32x2"),
                   (y - x).astype("int32x2"))
    # Mul rules
    checker.verify(y.astype("int32x2") * x.astype("int32x2"),
                   (y * x).astype("int32x2"))
    checker.verify(tvm.expr.Ramp(x, 4, 4) * 2,
                   tvm.expr.Ramp(x * 2, 8, 4))
    checker.verify(2 * tvm.expr.Ramp(x, 4, 4),
                   tvm.expr.Ramp(x * 2, 8, 4))
    # Div rules
    checker.verify(y.astype("int32x2") / x.astype("int32x2"),
                   (y / x).astype("int32x2"))
    checker.verify(tvm.expr.Ramp(x, 4, 4) / 2,
                   tvm.expr.Ramp(x / 2, 2, 4))
    # the remaining div checks rely on x being non-negative
    checker.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    checker.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) / 8,
                   (x).astype("int32x4"))
    checker.verify(tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8,
                   tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8)
    # Mod rules
    checker.verify(y.astype("int32x2") % x.astype("int32x2"),
                   (y % x).astype("int32x2"))
    checker.verify(tvm.expr.Ramp(x, 4, 4) % 2,
                   tvm.expr.Broadcast(x % 2, 4))
    checker.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    checker.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) % 8,
                   tvm.expr.Ramp(1, 1, 4))
    checker.verify(tvm.expr.Ramp(x * 8 + 1, 15, 4) % 8,
                   tvm.expr.Ramp(1, 15, 4) % 8)
def test_select_simplify():
    """Check the rewrite rules that combine Select expressions."""
    checker = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    # Add rules
    checker.verify(tvm.expr.Select(x > 0, y, 0) + tvm.expr.Select(x > 0, 1, z),
                   tvm.expr.Select(x > 0, y + 1, z))
    # Sub rules
    checker.verify(tvm.expr.Select(x > 0, y, 1) - tvm.expr.Select(x > 0, 1, z),
                   tvm.expr.Select(x > 0, y + (-1), 1 - z))
    checker.verify(tvm.expr.Select(x > 0, y, z) - y,
                   tvm.expr.Select(x > 0, 0, z - y))
    checker.verify(tvm.expr.Select(x > 0, y, z) - z,
                   tvm.expr.Select(x > 0, y - z, 0))
def test_add_index_simplify():
    """Check the index rewrite rules of Add."""
    checker = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")

    # cancellation
    checker.verify(x + (y - x), y)
    checker.verify(x - (y + 1) + (y + 1), x)
    checker.verify((x - 10) + (10 - z), x - z)
    checker.verify((x - y) + (z - x), z - y)

    # min / max
    checker.verify(tvm.min(x, y - z) + z, tvm.min(x + z, y))
    checker.verify(tvm.min(x - z, y) + z, tvm.min(x, y + z))
    checker.verify(tvm.max(x, y - 10) + 10, tvm.max(x + 10, y))
    checker.verify(tvm.max(x - 11, y) + 11, tvm.max(x, y + 11))

    checker.verify(tvm.max(x, y * 2) + tvm.min(x, y * 2), x + y * 2)
    checker.verify(tvm.min(x, y * 2) + tvm.max(x, y * 2), x + y * 2)

    checker.verify(tvm.max(x, y + 2) + (-2), tvm.max(x + (-2), y))
    checker.verify(tvm.min(x, y + 2) + (-2), tvm.min(x + (-2), y))
    checker.verify(tvm.min(x + 2, y + 3) + (-2), tvm.min(x, y + 1))

    # coefficient folding
    checker.verify(x * y + x * 10, x * (y + 10))
    checker.verify(y * x + x * 10, x * (y + 10))
    checker.verify(y * x + 10 * x, x * (y + 10))
    checker.verify(x * y + 10 * x, x * (y + 10))

    checker.verify(y * (x % 8) + 10 * (x % 8), (x % 8) * (y + 10))
    checker.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    checker.verify((x / 8) * 8 + x % 8, x)

    # canonicalization
    checker.verify(x + 2 + 3 + 4 + x, x * 2 + 9)
    checker.verify(x + 2 + 3 + 4 + x * 3, x * 4 + 9)

    # conservative bound: once x may be negative the div/mod rule must not
    # fire, so verify is expected to fail with an AssertionError.
    try:
        checker.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True)
        checker.verify((x / 8) * 8 + x % 8, x)
        raise RuntimeError("bad")
    except AssertionError:
        pass
def test_sub_index_simplify():
    """Check the index rewrite rules of Sub."""
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")

    ck.verify(x + y - y, x)
    ck.verify(x + y - x, y)
    ck.verify(x - (y + x), 0 - y)
    ck.verify(x - (x + y), 0 - y)

    ck.verify(tvm.min(x, y) - x, tvm.min(0, y - x))
    ck.verify(tvm.min(x, y) - y, tvm.min(x - y, 0))
    ck.verify(tvm.max(x, y) - x, tvm.max(0, y - x))
    ck.verify(tvm.max(x, y) - y, tvm.max(x - y, 0))

    ck.verify(x - tvm.min(x, y), tvm.max(0, x - y))
    ck.verify(y - tvm.min(x, y), tvm.max(y - x, 0))
    ck.verify(x - tvm.max(x, y), tvm.min(0, x - y))
    ck.verify(y - tvm.max(x, y), tvm.min(y - x, 0))

    # mul co-efficient folding
    ck.verify(x - x, 0)
    ck.verify(x * y - x, x * (y + (-1)))
    ck.verify(x * y - 10 * x, x * (y + (-10)))
    ck.verify(y * x - x * z, x * (y - z))
    ck.verify(y * x - z * x, x * (y - z))

    ck.verify(x + 10 - 20, x + (-10))

    # 4-operands pattern
    ck.verify((x + y) - (x + z), y - z)
    ck.verify((y + x) - (x + z), y - z)
    ck.verify((x + y) - (z + x), y - z)
    ck.verify((y + x) - (z + x), y - z)

    ck.verify(tvm.min(x + y, z) - x, tvm.min(y, z - x))
    ck.verify(tvm.min(y + x, z) - x, tvm.min(y, z - x))
    ck.verify(tvm.min(z, x + y) - x, tvm.min(z - x, y))
    ck.verify(tvm.min(z, y + x) - x, tvm.min(z - x, y))

    ck.verify(x - tvm.min(x + y, z), tvm.max(0 - y, x - z))
    ck.verify(x - tvm.min(y + x, z), tvm.max(0 - y, x - z))
    ck.verify(x - tvm.min(z, x + y), tvm.max(x - z, 0 - y))
    ck.verify(x - tvm.min(z, y + x), tvm.max(x - z, 0 - y))

    ck.verify(tvm.min(x, y) - tvm.min(y, x), 0)
    ck.verify(tvm.max(x, y) - tvm.max(y, x), 0)
    ck.verify(tvm.min(x, y) - tvm.min(x + 10, y + 10), -10)
    ck.verify(tvm.min(x + 10, y + 1) - tvm.min(x, y - 9), 10)

    # div pattern
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(x - (x / 3) * 3, x % 3)
    # For x >= 0: (x + 5) / 3 - x / 3 == (x % 3 + 5) / 3.
    # The form ((x + 2) % 3 + 5) / 3 is not equivalent: at x = 0 the left-hand
    # side is 1 while that form evaluates to 2.
    ck.verify((x + 5) / 3 - x / 3, (x % 3 + 5) / 3)
def test_mul_index_simplify():
    """Check the index rewrite rules of Mul."""
    checker = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    # constant distribution and folding
    checker.verify((x + 2) * 3, x * 3 + 6)
    checker.verify((x * 2) * 3, x * 6)
    # min * max
    checker.verify(tvm.min(x, y) * tvm.max(x, y), x * y)
    checker.verify(tvm.max(x, y) * tvm.min(x, y), x * y)
    # negative constant factor
    checker.verify((x - y) * (-2), (y - x) * 2)
def test_div_index_simplify():
    """Check the index rewrite rules of Div for operands bounded to [0, 1000]."""
    checker = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    checker.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    checker.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
    checker.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True)

    # constant divisors
    checker.verify(x / 2 / 3, x / 6)
    checker.verify((x / 2 + 1) / 3, (x + 2) / 6)
    checker.verify(x * 2 / 4, x / 2)
    checker.verify(x * 4 / 2, x * 2)

    # 2-operands
    checker.verify((x * 4 + y) / 2, x * 2 + y / 2)
    checker.verify(tvm.min(x * 6, y) / 2, tvm.min(x * 3, y / 2))
    checker.verify(tvm.max(x * 6, y) / 2, tvm.max(x * 3, y / 2))

    checker.verify((y + x * 4) / 2, y / 2 + x * 2)
    checker.verify(tvm.min(y, x * 6) / 2, tvm.min(y / 2, x * 3))
    checker.verify(tvm.max(y, x * 6) / 2, tvm.max(y / 2, x * 3))

    # 3-operands
    checker.verify((x * 6 + y + z) / 2, x * 3 + (y + z) / 2)
    checker.verify((x * 6 - y + (y + 3)) / 2, x * 3 + 1)
    checker.verify((x * 6 + (y + 3) - y) / 2, x * 3 + 1)
    checker.verify((y + x * 6 + z) / 2, x * 3 + (y + z) / 2)
    checker.verify((x + 4) / 2, x / 2 + 2)

    # divisor is an expression
    checker.verify((x + y) / x, y / x + 1)
    checker.verify((y + x) / x, y / x + 1)
    checker.verify(((x + y) + z) / x, (y + z) / x + 1)
    checker.verify(((y + x) + z) / x, (y + z) / x + 1)
    checker.verify((y + (x + z)) / x, (y + z) / x + 1)
    checker.verify((y + (z + x)) / x, (y + z) / x + 1)

    checker.verify((x * y) / y, x)
    checker.verify((y * x) / y, x)

    checker.verify((x * z + y) / z, x + y / z)
    checker.verify((z * x + y) / z, x + y / z)
    checker.verify((y + x * z) / z, y / z + x)
    checker.verify((y + z * x) / z, y / z + x)
def test_mod_index_simplify():
    """Check the index rewrite rules of Mod for operands bounded to [0, 1000]."""
    checker = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    checker.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    checker.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)

    checker.verify(x * 10 % 2, 0)
    checker.verify((x * 10 + y) % 2, y % 2)
    checker.verify((x + 10) % 2, x % 2)
    checker.verify((x + y * 10) % 2, x % 2)
    checker.verify((x * 10 + 1 + y * 2 + 2) % 2, 1)
if __name__ == "__main__":
    # Run every test when the file is executed as a script.
    for run_test in (
            test_mod_index_simplify,
            test_vector_simplify,
            test_add_index_simplify,
            test_sub_index_simplify,
            test_mul_index_simplify,
            test_div_index_simplify,
            test_select_simplify,
    ):
        run_test()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment