Unverified Commit 7afbca56 by Tianqi Chen Committed by GitHub

[ARITH] Analyzer CanonicalSimplifier (#2891)

parent eb1ed116
......@@ -218,6 +218,7 @@ class RewriteSimplifier {
private:
friend class Analyzer;
friend class ConstraintContext;
friend class CanonicalSimplifier;
explicit RewriteSimplifier(Analyzer* parent);
~RewriteSimplifier();
class Impl;
......@@ -226,6 +227,39 @@ class RewriteSimplifier {
};
/*!
* \brief Canonical-form based simplifier.
*/
class CanonicalSimplifier {
public:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr operator()(const Expr& expr);
/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const Expr& new_expr,
bool override = false);
private:
friend class Analyzer;
friend class ConstraintContext;
explicit CanonicalSimplifier(Analyzer* parent);
~CanonicalSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};
/*!
* \brief A RAII constraint context.
*
* \code
......@@ -277,6 +311,8 @@ class Analyzer {
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplfy */
RewriteSimplifier rewrite_simplify;
/*! \brief sub-analyzer rewrite simplfy */
CanonicalSimplifier canonical_simplify;
/*! \brief constructor */
Analyzer();
/*!
......
......@@ -12,6 +12,7 @@
#include <string>
#include <memory>
#include <functional>
#include <utility>
#include "runtime/registry.h"
namespace tvm {
......@@ -32,6 +33,44 @@ using ::tvm::AttrVisitor;
using ContainerType = NodeName; \
}; \
/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
inline NodeName* CopyOnWrite() { \
CHECK(node_ != nullptr); \
if (!node_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<Node>(std::move(n)).swap(node_); \
} \
return static_cast<NodeName*>(node_.get()); \
} \
using ContainerType = NodeName; \
};
/*!
* \brief save the node as well as all the node it depends on as json.
......
......@@ -97,6 +97,7 @@ class Analyzer:
self._bind = _mod("bind")
self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify")
self._canonical_simplify = _mod("canonical_simplify")
self._enter_constraint_context = _mod("enter_constraint_context")
def const_int_bound(self, expr):
......@@ -144,6 +145,21 @@ class Analyzer:
"""
return self._rewrite_simplify(expr)
def canonical_simplify(self, expr):
"""Simplify expression via canonicalization.
Parameters
----------
expr : tvm.Expr
The expression.
Returns
-------
result : Expr
The result.
"""
return self._canonical_simplify(expr)
def bind(self, var, expr):
"""Bind a variable to the expression.
......
......@@ -102,6 +102,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->rewrite_simplify(args[0]);
});
} else if (name == "canonical_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->canonical_simplify(args[0]);
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr();
......
......@@ -11,14 +11,21 @@ namespace arith {
Analyzer::Analyzer()
: const_int_bound(this),
modular_set(this),
rewrite_simplify(this) {
rewrite_simplify(this),
canonical_simplify(this) {
}
void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
Var var(v.node_);
this->const_int_bound.Update(var, this->const_int_bound(expr));
this->modular_set.Update(var, this->modular_set(expr));
this->rewrite_simplify.Update(var, this->rewrite_simplify(expr));
Expr new_expr = expr;
new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);
this->const_int_bound.Update(var, this->const_int_bound(new_expr));
this->modular_set.Update(var, this->modular_set(new_expr));
this->rewrite_simplify.Update(var, new_expr);
this->canonical_simplify.Update(var, new_expr);
}
void Analyzer::Bind(const VarExpr& v, const Range& range) {
......@@ -47,5 +54,6 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (bd->min_value >= lower_bound) return true;
return false;
}
} // namespace arith
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file canonical.cc
* \brief Canonicalize simplification.
*/
#include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_pass.h>
#include <algorithm>
#include <map>
#include <limits>
#include <vector>
#include <memory>
#include <unordered_map>
#include "canonical.h"
#include "compute_expr.h"
#include "arithmetic/Simplify.h"
namespace tvm {
namespace arith {
using namespace ir;
// Canonical entry for communicative ops.
struct ComExprEntry {
// the value of the expression.
Expr value;
// the level of the expression.
int level{0};
// The integer scale on value
int64_t scale{1};
ComExprEntry() {}
ComExprEntry(Expr value, int level)
: value(value), level(level) {}
inline bool operator<(const ComExprEntry& other) const {
if (level < other.level) return true;
if (level > other.level) return false;
// compare top operator of entries and sort on that if possible (fast check)
if (value.type_index() < other.value.type_index()) return true;
if (value.type_index() > other.value.type_index()) return false;
// if none of the above distinguishes the terms, compare the expression tree of the entries.
// This is a slower check.
int compare_result = Compare(value, other.value);
if (compare_result < 0) return true;
if (compare_result > 0) return false;
// it's a problem if we see identical entries at this point. They should've been merged earlier.
LOG(WARNING) << "we should not have identical entries at this point";
return false;
}
};
// canonical expression for communicative expression.
struct ComExprNode : public NodeBase {
// base constant value.
int64_t base{0};
// The values to be sumed.
std::vector<ComExprEntry> elem;
};
// canonical communicative expression
struct ComExpr {
public:
// constructor
ComExpr() {}
explicit ComExpr(NodePtr<ComExprNode> ptr) : ptr_(ptr) {}
// get member
ComExprNode* operator->() const {
return ptr_.get();
}
void reset() {
ptr_.reset();
}
bool defined() const {
return ptr_.get() != nullptr;
}
// comparator
bool operator<(const ComExpr& b) const {
const ComExpr& a = *this;
if (a->base < b->base) return true;
if (a->base > b->base) return false;
if (a->elem.size() < b->elem.size()) return true;
if (a->elem.size() > b->elem.size()) return false;
for (size_t i = 0; i < a->elem.size(); ++i) {
const ComExprEntry& ea = a->elem[i];
const ComExprEntry& eb = b->elem[i];
if (ea.level < eb.level) return true;
if (ea.level > eb.level) return false;
if (ea.value.get() < eb.value.get()) return true;
if (ea.value.get() > eb.value.get()) return false;
if (ea.scale < eb.scale) return true;
if (ea.scale > eb.scale) return false;
}
return false;
}
// equality
bool operator==(const ComExpr& b) const {
const ComExpr& a = *this;
if (a->base != b->base) return false;
if (a->elem.size() != b->elem.size()) return false;
for (size_t i = 0; i < a->elem.size(); ++i) {
const ComExprEntry& ea = a->elem[i];
const ComExprEntry& eb = b->elem[i];
if (ea.level != eb.level) return false;
if (ea.value.get() != eb.value.get()) return false;
if (ea.scale != eb.scale) return false;
}
return true;
}
private:
NodePtr<ComExprNode> ptr_;
};
// binary comparison op.
struct BinaryExpr {
int kind;
Expr lhs, rhs;
// comparator
bool operator<(const BinaryExpr& b) const {
if (kind < b.kind) return true;
if (kind > b.kind) return false;
if (lhs.get() < b.lhs.get()) return true;
if (lhs.get() > b.lhs.get()) return false;
return rhs.get() < b.rhs.get();
}
// equality
bool operator==(const BinaryExpr& b) const {
return kind == b.kind &&
lhs.same_as(b.lhs) &&
rhs.same_as(b.rhs);
}
};
template<typename T>
inline Expr Binary_(const T* op,
const Expr& e,
Expr a, Expr b) {
if (a.same_as(op->a) && b.same_as(op->b)) {
return e;
} else {
return T::make(a, b);
}
}
// internal of canonical engine.
class Canonical::Internal : public IRMutator {
public:
explicit Internal(Map<Var, Range> vrange) {
for (auto kv : vrange) {
SetRange(kv.first, kv.second, 0);
}
}
// stack entry.
struct StackEntry {
int max_level{0};
bool has_side_effect{false};
};
// aggressively canonicalized expression
struct CacheEntry {
// The canonical value of the expression.
Expr value;
// The level of the expression.
int max_level{0};
// whether the expression might have side effect.
bool has_side_effect{false};
// if not null, corresponds to to sum
ComExpr sum;
// reset the return entry.
void reset() {
sum.reset();
}
// as sum expr
ComExpr AsSum() const {
if (sum.defined()) return sum;
const int64_t *v1 = as_const_int(value);
const uint64_t *v2 = as_const_uint(value);
auto n = make_node<ComExprNode>();
if (v1) {
n->base = *v1;
} else if (v2) {
CHECK_LE(*v2,
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
n->base = static_cast<int64_t>(*v2);
} else {
n->elem.push_back(ComExprEntry(value, max_level));
}
return ComExpr(n);
}
};
// Set range and level of var.
void SetRange(Var v, Range r, int level) {
var_range_[v.get()] = IntSet::range(r);
var_level_[v.get()] = level;
var_rec_.push_back(v);
}
// functions
Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
return stmt;
}
Expr MutateExpr_(Expr expr) {
stack_.push_back(StackEntry());
expr = IRMutator::Mutate(expr);
// update result of parent automatically during pop
if (stack_.size() > 1) {
StackEntry& back = stack_[stack_.size() - 1];
StackEntry& prev = stack_[stack_.size() - 2];
prev.max_level = std::max(prev.max_level, back.max_level);
if (back.has_side_effect) prev.has_side_effect = true;
}
// copy result from stack
ret_entry_.has_side_effect = stack_.back().has_side_effect;
ret_entry_.max_level = stack_.back().max_level;
stack_.pop_back();
CHECK(expr.defined());
if (const IntImm* op = expr.as<IntImm>()) {
return Mutate_(op, expr);
}
return expr;
}
// call produce to get a cache entry.
CacheEntry Produce(Expr expr) {
ret_entry_.reset();
ret_entry_.value = MutateExpr_(expr);
CacheEntry ret = ret_entry_;
ret_entry_.reset();
return ret;
}
Expr Mutate(Expr expr) final {
ret_entry_.reset();
expr = MutateExpr_(expr);
ret_entry_.reset();
return expr;
}
// Check whether do special canonicalization.
bool EnableOpt(Type t) const {
return (t.lanes() == 1 && (t.is_int() || t.is_uint()));
}
// Max
Expr Mutate_(const Max* op, const Expr& e) final {
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
return Binary(op, e);
}
// Min
Expr Mutate_(const Min* op, const Expr& e) final {
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
return Binary(op, e);
}
// Add
Expr Mutate_(const Add* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
return Binary(op, e);
}
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
return SumAdd(a, b, +1);
}
// Sub
Expr Mutate_(const Sub* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
return Binary(op, e);
}
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
return SumAdd(a, b, -1);
}
// Mul
Expr Mutate_(const Mul* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
return Binary(op, e);
}
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
if (is_const(a.value) && is_const(b.value)) {
return ComputeExpr<Mul>(a.value, b.value);
} else if (is_const(a.value)) {
return SumMulConst(b.AsSum(), a.value);
} else if (is_const(b.value)) {
return SumMulConst(a.AsSum(), b.value);
} else {
return Binary(op, e);
}
}
// Variable
Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = var_level_.find(op);
if (it != var_level_.end()) {
stack_.back().max_level = it->second;
}
return IRMutator::Mutate_(op, e);
}
// comparison
Expr Mutate_(const LT* op, const Expr& e) {
if (!EnableOpt(op->a.type())) {
return Binary(op, e);
}
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
Expr b_sub_a = SumAdd(b, a, -1);
if (EvalSet(b_sub_a, var_range_).can_prove_positive()) {
return make_const(op->type, true);
} else {
return Binary_(op, e, a.value, b.value);
}
}
// IntImm
Expr Mutate_(const IntImm* op, const Expr& e) final {
if (op->type != Int(32)) return e;
auto it = cache_intimm_.find(op->value);
if (it != cache_intimm_.end()) {
return it->second;
} else {
cache_intimm_[op->value] = e;
return e;
}
}
// Div operator
Expr Mutate_(const Div* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
return Binary(op, e);
}
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
if (is_const(a.value) && is_const(b.value)) {
return ComputeExpr<Div>(a.value, b.value);
} else if (is_const(b.value)) {
return SumDivConst(a.AsSum(), b.value);
} else {
return Binary(op, e);
}
}
// Mod operator
Expr Mutate_(const Mod* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
return Binary(op, e);
}
CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b);
if (a.has_side_effect || b.has_side_effect) {
return Binary_(op, e, a.value, b.value);
}
if (is_const(a.value) && is_const(b.value)) {
return ComputeExpr<Mod>(a.value, b.value);
} else if (is_const(b.value)) {
return SumModConst(a.AsSum(), b.value);
} else {
return Binary(op, e);
}
}
Expr Mutate_(const And* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<And>();
if (is_one(op->a)) return op->b;
if (is_one(op->b)) return op->a;
return expr;
}
// Call
Expr Mutate_(const Call* op, const Expr& e) final {
if (!op->is_pure()) {
stack_.back().has_side_effect = true;
}
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
return op->args[0];
} else {
return expr;
}
}
// For
Stmt Mutate_(const For* op, const Stmt& s) {
++level_counter_;
Var loop_var(op->loop_var.node_);
this->SetRange(loop_var,
Range::make_by_min_extent(op->min, op->extent),
level_counter_);
Stmt stmt = IRMutator::Mutate_(op, s);
--level_counter_;
return stmt;
}
// IfThenElse
Stmt Mutate_(const IfThenElse* op, const Stmt& s) {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<IfThenElse>();
if (is_one(op->condition)) return op->then_case;
return stmt;
}
// AttrStmt
Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
++level_counter_;
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
if (!var_level_.count(iv->var.get())) {
this->SetRange(iv->var,
Range::make_by_min_extent(0, op->value),
level_counter_);
}
Stmt stmt = IRMutator::Mutate_(op, s);
--level_counter_;
return stmt;
} else {
return IRMutator::Mutate_(op, s);
}
}
// The simplify statement.
static FMutateExpr& vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
private:
template<typename T>
Expr Binary(const T* op, Expr e) {
Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b);
BinaryExpr key{static_cast<int>(T::_type_info), a, b};
auto it = cache_binary_.find(key);
if (it != cache_binary_.end()) {
return it->second;
} else {
Expr ret = Binary_(op, e, a, b);
cache_binary_[key] = ret;
return ret;
}
}
// return entry
CacheEntry ret_entry_;
// internal information stack
std::vector<StackEntry> stack_;
// cache sum
std::map<ComExpr, CacheEntry> cache_sum_;
// cache of normal binary op
std::map<BinaryExpr, Expr> cache_binary_;
// cache of int constant
std::unordered_map<int64_t, Expr> cache_intimm_;
// range of each var
std::unordered_map<const Variable*, IntSet> var_range_;
// level of each var
std::unordered_map<const Variable*, int> var_level_;
// record history vars, to avoid false positive.
std::vector<Var> var_rec_;
// level counter
int level_counter_{0};
// get constant int value
int64_t GetConstIntValue(const Expr& v) {
int64_t value = 0;
const int64_t *v1 = as_const_int(v);
const uint64_t *v2 = as_const_uint(v);
CHECK(v1 || v2);
if (v1) {
value = *v1;
} else if (v2) {
CHECK_LE(*v2,
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
value = static_cast<int64_t>(*v2);
}
return value;
}
// Detect if a = q * coeff + r, where r \in [0, coeff), coeff > 0
// (in Euclidean division)
// returns pair (q, r) if such detection is successful
// returns empty vector otherwise.
// Assumes that coeff is a constant integer
std::vector<ComExpr> TryLinearEquation(const ComExpr& a,
const Expr& coeff) {
Type type = coeff.type();
int64_t value = GetConstIntValue(coeff);
CHECK_NE(value, 0);
if (value < 0) return {};
// Given that denominator (value variable) is positive, truncated division
// (i.e., TVM's division semantics) is equivalent to Euclidean division if and only if
// numerator is non-negative or numerator is divisible by denominator (i.e., value)
IntSet numerator_int_set = EvalSet(Sum2Expr(a, type), var_range_);
bool numerator_is_non_neg = numerator_int_set.can_prove_non_negative();
// Try to separate terms of a into ones that can be proven to be
// divisible by coeff and ones that are not
// We will build q and r from divisible and non_divisible respectively
auto divisible = make_node<ComExprNode>();
auto non_divisible = make_node<ComExprNode>();
if (a->base % value == 0) {
divisible->base = a->base;
} else {
non_divisible->base = a->base;
}
for (const auto& e : a->elem) {
if (e.scale % value == 0) {
divisible->elem.push_back(e);
} else {
non_divisible->elem.push_back(e);
}
}
bool non_divisible_is_simplified = false;
int64_t div_result;
Expr non_divisible_res = Sum2Expr(ComExpr(non_divisible), type);
// if non_divisible part consists of only an integer and numerator is non-negative,
// we can simply divide it by coeff
if (is_const(non_divisible_res)) {
int64_t non_divisible_const = GetConstIntValue(non_divisible_res);
if (numerator_is_non_neg || non_divisible_const == 0) {
non_divisible_is_simplified = true;
// We need to do an Euclidean division here because (a*b + c)/b == a + c/b
// holds true only if division is Euclidean
div_result = HalideIR::Internal::div_imp(non_divisible_const , value);
}
} else {
// If we can prove that non_divisible part lies within [0, coeff), then
// non_divisible itself will be our r
IntSet non_divisible_set = EvalSet(non_divisible_res, var_range_);
if (non_divisible_set.min().type() == type &&
non_divisible_set.max().type() == type) {
if ( (non_divisible_set.is_single_point() &&
can_prove(non_divisible_set.point_value() == 0)) ||
(numerator_is_non_neg &&
can_prove(non_divisible_set.min() >= make_zero(type)) &&
can_prove(non_divisible_set.max() < coeff)) ) {
non_divisible_is_simplified = true;
div_result = 0;
}
}
}
if (non_divisible_is_simplified) {
non_divisible->base -= div_result * value;
divisible->base /= value;
divisible->base += div_result;
for (auto& e : divisible->elem) {
e.scale /= value;
}
return {ComExpr(divisible), ComExpr(non_divisible)};
} else {
return {};
}
}
// subroutine to do produce a % v
Expr SumModConst(ComExpr a, Expr v) {
std::vector<ComExpr> pair = TryLinearEquation(a, v);
if (pair.size() == 0) {
int64_t value = GetConstIntValue(v);
auto n = make_node<ComExprNode>();
// FIXME(derisavi) : The following can be done only for Euclidean division/mod.
// Therefore, it's only valid when truncated division/mod is equivalent to Euclidean one,
// that is, if and only if a and v are
// both negative or both positive or a is divisible by v.
// Extend the code to handle cases where the above condition is not satisfied, i.e.,
// a and v are of different signs and a is not divisible by v.
n->base = a->base % value;
for (auto e : a->elem) {
if (e.scale % value == 0) continue;
e.scale = e.scale % value;
n->elem.push_back(e);
}
Expr ret = Sum2Expr(ComExpr(n), v.type()) % v;
if (const Mod* mod = ret.as<Mod>()) {
return Binary(mod, ret);
} else {
// Sometimes the result is a constant, this may happen when value is -1
CHECK(is_const(ret)) << "CanonicalSimplify: "
<< Sum2Expr(ComExpr(n), v.type()) << " % " << v << " is " << ret
<< " which is neither Mod, nor a constant";
return ret;
}
}
ret_entry_.sum = pair[1];
ret_entry_.max_level = stack_.back().max_level;
ret_entry_.has_side_effect = stack_.back().has_side_effect;
auto it = cache_sum_.find(ret_entry_.sum);
if (it != cache_sum_.end()) {
ret_entry_ = it->second;
} else {
ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
cache_sum_[ret_entry_.sum] = ret_entry_;
}
return ret_entry_.value;
}
// subroutine to do produce a % v
Expr SumDivConst(ComExpr a, Expr v) {
std::vector<ComExpr> pair = TryLinearEquation(a, v);
if (pair.size() == 0) {
Expr ret = Sum2Expr(a, v.type()) / v;
return Binary(ret.as<Div>(), ret);
}
ret_entry_.sum = pair[0];
ret_entry_.max_level = stack_.back().max_level;
ret_entry_.has_side_effect = stack_.back().has_side_effect;
auto it = cache_sum_.find(ret_entry_.sum);
if (it != cache_sum_.end()) {
ret_entry_ = it->second;
} else {
ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
cache_sum_[ret_entry_.sum] = ret_entry_;
}
return ret_entry_.value;
}
// subroutine to do produce
Expr SumMulConst(ComExpr a, Expr v) {
int64_t value = GetConstIntValue(v);
if (value == 0) {
return make_zero(v.type());
}
auto vsum = make_node<ComExprNode>(*a.operator->());
vsum->base *= value;
for (auto& e : vsum->elem) {
e.scale *= value;
}
ret_entry_.sum = ComExpr(vsum);
ret_entry_.max_level = stack_.back().max_level;
ret_entry_.has_side_effect = stack_.back().has_side_effect;
auto it = cache_sum_.find(ret_entry_.sum);
if (it != cache_sum_.end()) {
ret_entry_ = it->second;
} else {
ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
cache_sum_[ret_entry_.sum] = ret_entry_;
}
return ret_entry_.value;
}
// add two ComExpr together
ComExpr SumAdd_(const ComExpr& suma,
const ComExpr& sumb,
int bscale) {
auto n = make_node<ComExprNode>();
n->base = suma->base + sumb->base * bscale;
// merge of suma and sumb;
size_t i = 0, j = 0;
while (i < suma->elem.size() && j < sumb->elem.size()) {
const auto& a = suma->elem[i];
const auto& b = sumb->elem[j];
if (a.value.same_as(b.value) && a.level == b.level) {
ComExprEntry e = a;
e.scale = a.scale + b.scale * bscale;
if (e.scale != 0) {
n->elem.push_back(e);
}
++i; ++j;
} else if (a < b) {
n->elem.push_back(a);
++i;
} else {
ComExprEntry e = b;
e.scale *= bscale;
n->elem.push_back(e);
++j;
}
}
for (; i < suma->elem.size(); ++i) {
n->elem.push_back(suma->elem[i]);
}
for (; j < sumb->elem.size(); ++j) {
ComExprEntry e = sumb->elem[j];
e.scale *= bscale;
n->elem.push_back(e);
}
return ComExpr(n);
}
// subroutine to do produce
Expr SumAdd(CacheEntry a, CacheEntry b, int bscale) {
ret_entry_.sum = SumAdd_(a.AsSum(), b.AsSum(), bscale);
CHECK_NE(stack_.size(), 0U);
ret_entry_.max_level = stack_.back().max_level;
ret_entry_.has_side_effect = stack_.back().has_side_effect;
auto it = cache_sum_.find(ret_entry_.sum);
if (it != cache_sum_.end()) {
ret_entry_ = it->second;
} else {
ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
cache_sum_[ret_entry_.sum] = ret_entry_;
}
return ret_entry_.value;
}
// convert sum to expr
Expr Sum2Expr(const ComExpr& com, Type t) {
Expr vsum;
if (com->base > 0) {
vsum = make_const(t, com->base);
}
for (const ComExprEntry& e : com->elem) {
if (e.scale > 0) {
Expr v = e.value;
if (e.scale != 1) {
v = Mul::make(v, make_const(t, e.scale));
}
if (vsum.defined()) {
vsum = Add::make(vsum, v);
} else {
vsum = v;
}
}
}
if (com->base < 0) {
if (vsum.defined()) {
vsum = Sub::make(vsum, make_const(t, -com->base));
} else {
vsum = make_const(t, com->base);
}
}
for (const ComExprEntry& e : com->elem) {
if (e.scale < 0) {
Expr v = e.value;
if (e.scale != -1) {
v = Mul::make(v, make_const(t, -e.scale));
}
if (vsum.defined()) {
vsum = Sub::make(vsum, v);
} else {
vsum = Sub::make(make_zero(t), v);
}
}
}
if (vsum.defined()) {
return vsum;
} else {
return make_zero(t);
}
}
};
using CInternal = Canonical::Internal;
Canonical::Canonical(Map<Var, Range> vrange)
: ptr_(std::make_shared<Internal>(vrange)) {}
Expr Canonical::Simplify(Expr expr) {
return ptr_->Mutate(expr);
}
Stmt Canonical::Simplify(Stmt stmt) {
return ptr_->Mutate(stmt);
}
void Canonical::SetRange(Var v, Range r, int level) {
ptr_->SetRange(v, r, level);
}
} // namespace arith
namespace ir {
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::Canonical(vrange).Simplify(stmt);
}
Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
return arith::Canonical(vrange).Simplify(expr);
}
template<typename T>
T Simplify_(T a, Map<Var, Range> vrange) {
using namespace HalideIR::Internal;
Scope<Interval> rscope;
for (auto kv : vrange) {
Range r = kv.second;
rscope.push(
kv.first.get(),
Interval(r->min,
simplify(r->min + r->extent - make_const(r->min.type(), 1))));
}
return HalideIR::Internal::simplify(a, true, rscope);
}
/*!
* \brief Simplify just the combiner of the given reduce node.
*
* This function applies Simplify to the components of the top reduction's
* combiner, but not to the source or condition of the reduction.
* It also removes all components which are not used to
* compute the resulting value (the value_index-th value).
*
* If \p expr is not a reduction node, it is left unchanged.
*
* \param expr The expression to be simplifed.
* \return Simplified expression.
*/
Expr SimplifyCombiner(const Expr& expr, const Map<Var, Range>& vrange = Map<Var, Range>()) {
const Reduce* op = expr.as<Reduce>();
if (!op) {
return expr;
}
// First simplify the results
Array<Expr> simplified_result;
for (const auto& res : op->combiner->result) {
simplified_result.push_back(Simplify(res, vrange));
}
// Which components to keep
std::vector<int> used(op->combiner->result.size(), false);
// This function recursively marks the used components starting from
// the index idx
std::function<void(int)> mark_used;
mark_used = [&used, &simplified_result, op, &mark_used](size_t idx) {
// if the idx-th component was marked as used before, do nothing
if (used[idx]) return;
used[idx] = true;
// check if the idx-th result expr uses some lhs or rhs variables
// and recursively mark the corresponding components
for (size_t i = 0; i < simplified_result.size(); ++i)
if (!used[i]) {
if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) ||
ExprUseVar(simplified_result[idx], op->combiner->rhs[i]))
mark_used(i);
}
};
// mark all used components starting from the value_index
mark_used(op->value_index);
// components which have side effects should also be preserved
for (size_t i = 0; i < used.size(); ++i) {
if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) ||
HasSideEffect(op->combiner->result[i])) {
mark_used(i);
}
}
int new_value_index = op->value_index;
Array<Expr> new_result;
Array<Expr> new_identity;
Array<Var> new_lhs;
Array<Var> new_rhs;
Array<Expr> new_source;
// new stuff is old stuff which is used
for (size_t i = 0; i < used.size(); ++i) {
if (used[i]) {
// We simplify the result and identity, but not the source
new_result.push_back(simplified_result[i]);
new_identity.push_back(Simplify(op->combiner->identity_element[i], vrange));
new_lhs.push_back(op->combiner->lhs[i]);
new_rhs.push_back(op->combiner->rhs[i]);
new_source.push_back(op->source[i]);
} else if (static_cast<int>(i) < op->value_index) {
// value_index should also be adjusted
new_value_index--;
}
}
CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
return Reduce::make(new_combiner, new_source, op->axis, op->condition, new_value_index);
}
/*!
* \brief Remove a single reduction over empty axis.
*
* If \p e is a reduction node and its axis is empty, replace it with its source,
* otherwise return \p e unchanged.
*
* \param e The expression to be transformed.
* \return The transformed expression.
*/
Expr RemoveEmptyReduction(const Expr& e) {
const Reduce* r = e.as<Reduce>();
if (r && r->axis.empty()) {
// Note that here we assume that the identity element is indeed identity. Without this
// assumption we would have to perform a single iteration of the loop, i.e. use
// `(*r->combiner.get())(r->combiner->identity_element, r->source)[r->value_index]`
// instead of `r->source[r->value_index]`. The former may be more difficult to simplify.
return Select::make(r->condition,
r->source[r->value_index],
r->combiner->identity_element[r->value_index]);
}
return e;
}
Expr Simplify(Expr a, Map<Var, Range> vrange) {
// We should not pass an expression having a non-HalideIR op to
// Halide::Internal::simplify. Reduce op is the only such op at this time
// and it only appears as the top op in an expression. So we strip it
// first and send the sub-expressions to the simplifier.
if (const Reduce* r = a.as<Reduce>()) {
// If axis is empty, we can remove the reduce op completely.
if (r->axis.empty())
return Simplify_(RemoveEmptyReduction(a), vrange);
// Simplify the combiner of the reduction
a = SimplifyCombiner(a, vrange);
r = a.as<Reduce>();
// If axis is not empty then we add the information about ranges to vrange
for (const IterVar& iv : r->axis) {
if (vrange.count(iv->var)) {
Range existing_range = vrange[iv->var];
CHECK(Equal(existing_range->min, iv->dom->min) &&
Equal(existing_range->extent, iv->dom->extent))
<< "Simplify was given vrange stating that the range of the reduction var "
<< iv << " is " << existing_range << ". This is probably a mistake.";
}
vrange.Set(iv->var, iv->dom);
}
Array<Expr> new_source;
for (auto& e : r->source) {
new_source.push_back(Simplify_(e, vrange));
}
Expr new_condition = Simplify_(r->condition, vrange);
if (r->source.same_as(new_source) &&
r->condition.same_as(new_condition)) {
return a;
} else {
return Reduce::make(
r->combiner, new_source, r->axis, new_condition, r->value_index);
}
}
return Simplify_(a, vrange);
}
Stmt Simplify(Stmt a, Map<Var, Range> vrange) {
return Simplify_(a, vrange);
}
} // namespace ir
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file canonical.h
* \brief Internal canonicalized expression simplification engine.
*/
#ifndef TVM_ARITHMETIC_CANONICAL_H_
#define TVM_ARITHMETIC_CANONICAL_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <memory>
namespace tvm {
namespace arith {
/*!
* \brief A stateful CanonicalEngine over SSA.
*
* Simplify and CSE with canonicalization expressions.
* Each call's result will get cached, so next call will
* simply return the cached result.
*/
class Canonical {
public:
/*! \brief constructor */
explicit Canonical(Map<Var, Range> var_range);
/*!
* \brief simplify expression e.
* \param expr The expression to be simplified.
*/
Expr Simplify(Expr expr);
/*!
* \brief simplify stmt.
* \param stmt The stmt to be simplified.
*/
Stmt Simplify(Stmt expr);
/*!
* \brief Set range and level variable
* \param v The variable
* \param r The range of the variable, can be undefined.
* \param level The scope level of the variable,
* affect the order of formula in communicative ops.
*/
void SetRange(Var v, Range r, int level);
class Internal;
private:
// Internal pointer
std::shared_ptr<Internal> ptr_;
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_CANONICAL_H_
/*!
* Copyright (c) 2019 by Contributors
* \file canonical_simplify.cc
* \brief Canonical form based simplification.
*/
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include "const_fold.h"
#include "pattern_match.h"
#include "rewrite_simplify.h"
namespace tvm {
namespace arith {
using namespace ir;
class SumExpr;
class SplitExpr;
/*!
* \brief Base class of all temporary expression introduced
* for canonicalization.
*/
class CanonicalExprNode : public BaseExprNode {
public:
/*!
* \brief Return the normal Expr that is equivalent to self.
* \note Can mutate the internal data structure.
* \return The normal expression.
*/
virtual Expr Normalize() const = 0;
// overrides
void VisitAttrs(tvm::AttrVisitor* v) final {
}
void accept(HalideIR::Internal::IRVisitor* v, const Expr& e) const final {
LOG(FATAL) << "not supported";
}
IRNodeType type_info() const final {
return IRNodeType::ExtensionExpr;
}
static constexpr const char* _type_key = "arith.CanonicalExpr";
TVM_DECLARE_BASE_NODE_INFO(CanonicalExprNode, BaseExprNode);
};
/*!
* \brief Internal "Split normal form" of expression.
*
* This is a special expression that represents
* a scaled value derived from a split of an index.
*
* result = ((index % upper_factor) / lower_factor) * scale
*/
class SplitExprNode : public CanonicalExprNode {
public:
/*! \brief The base index expression. */
Expr index;
/*! \brief The division factor ratio. */
int64_t lower_factor{1};
/*!
* \brief The upper factor.
* invariance: (upper_factor == kPosInf || upper_factor % lower_factor == 0)
*/
int64_t upper_factor{kPosInf};
/*! \brief scale to the expression. */
int64_t scale{1};
/*! \brief verify that this is a valid entry. */
void Verify() const {
CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0);
}
Expr NormalizeWithScale(int64_t sscale) const {
Expr res = this->index;
Type dtype = this->type;
if (this->scale == 0) {
return make_const(dtype, 0);
}
if (this->upper_factor != SplitExprNode::kPosInf) {
res = res % make_const(dtype, this->upper_factor);
}
if (this->lower_factor != 1) {
res = res / make_const(dtype, this->lower_factor);
}
sscale *= this->scale;
if (sscale != 1) {
CHECK(!dtype.is_uint() || sscale > 0);
res = res * make_const(dtype, sscale);
}
return res;
}
Expr Normalize() const final {
return NormalizeWithScale(1);
}
void MulToSelf(int64_t scale) {
this->scale *= scale;
}
inline bool IndexEqual(const SplitExpr& other) const;
/*! \brief positive infty */
static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
static constexpr const char* _type_key = "arith.SplitExpr";
TVM_DECLARE_NODE_TYPE_INFO(SplitExprNode, CanonicalExprNode);
};
TVM_DEFINE_COW_NODE_REF(SplitExpr, Expr, SplitExprNode);
inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const {
if (index.same_as(other->index)) return true;
return ir::Equal(index, other->index);
}
/*!
* \brief Normal form that represents sum of expressions.
*
* result = sum(args) + base.
*/
class SumExprNode : public CanonicalExprNode {
public:
/*!
* \brief arguments to be summed up.
*
* args are divided into segments with the same index.
* within each segment, the SplitExpr is ordered in descending order of lower_factor.
*/
std::vector<SplitExpr> args;
/*! \brief Base value in the summation. */
int64_t base{0};
/*!
* \brief Return the normal Expr that is equivalent to self.
* \return The normal expression.
*/
Expr Normalize() const final {
// quick path 1.
if (this->args.size() == 0) {
return make_const(this->type, this->base);
}
return Normalize_(this->type,
SimplifySplitExprs(args),
base);
}
/*!
* \brief Whether self is divisible by scale.
* \param scale The scale to be applied.
*/
bool DivisibleBy(int64_t scale) {
if (base % scale != 0) return false;
for (size_t i = 0; i < this->args.size(); ++i) {
if (args[i]->scale % scale != 0) return false;
}
return true;
}
/*!
* \brief mul scale to self.
* \param scale The scale to be applied.
*/
void MulToSelf(int64_t scale) {
this->base *= scale;
for (size_t i = 0; i < this->args.size(); ++i) {
args[i].CopyOnWrite()->scale *= scale;
}
}
/*!
* \brief divide by scale.
* \param scale The scale to be applied.
*/
void DivideBy(int64_t scale) {
CHECK_EQ(this->base % scale, 0);
this->base /= scale;
for (size_t i = 0; i < this->args.size(); ++i) {
CHECK_EQ(args[i]->scale % scale, 0);
args[i].CopyOnWrite()->scale /= scale;
}
}
/*!
* \brief add constant value to self.
* \param value to be added.
*/
void AddToSelf(int64_t value) {
this->base += value;
}
/*!
* \brief self += other * scale;
* \param other The expression to be added.
* \param scale The additional scale on value.
*/
void AddToSelf(SplitExpr other, int64_t scale) {
if (other->scale == 0) return;
// We need to maintain the segment invariance:
// Same index are stored close to each other.
// sorted from big lower_factor to small one.
size_t start = 0;
for (; start < args.size(); ++start) {
if (args[start]->IndexEqual(other)) break;
}
for (size_t j = start; j < args.size(); ++j) {
if (!args[j]->IndexEqual(other) ||
other->lower_factor > args[j]->lower_factor) {
other.CopyOnWrite()->scale *= scale;
this->args.insert(this->args.begin() + j, other);
return;
}
if (other->lower_factor == args[j]->lower_factor &&
other->upper_factor == args[j]->upper_factor) {
args[j].CopyOnWrite()->scale += other->scale * scale;
return;
}
}
// Insert other in the end.
other.CopyOnWrite()->scale *= scale;
this->args.emplace_back(std::move(other));
}
void AddToSelf(const SumExpr& other, int64_t scale);
static constexpr const char* _type_key = "arith.SumExpr";
TVM_DECLARE_NODE_TYPE_INFO(SumExprNode, CanonicalExprNode);
private:
/*!
* \brief Simplify the args by merging SplitExprs
* \param args The original list of arguments.
* \return simplified version.
*/
static std::vector<SplitExpr>
SimplifySplitExprs(std::vector<SplitExpr> args) {
// NOTE: This algorithm relies on the factor that args are divided into segments
// and each segment is sorted in descending order of lower_factor.
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->scale == 0) continue;
for (size_t j = i + 1; j < args.size(); ++j) {
SplitExpr& lhs = args[i];
SplitExpr& rhs = args[j];
if (!lhs->IndexEqual(rhs)) break;
if (lhs->upper_factor < rhs->lower_factor) break;
if (lhs->lower_factor == rhs->upper_factor &&
lhs->scale % rhs->scale == 0 &&
lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor) {
// Rules used in the proof:
//
// Rule 1: (x % (c * s)) / c = (x / c) % s
// Proof:
// x can always be decomposed into p * c * s + q * c + r
// where 0 <= q * c + r < c * s and 0 <= r < c.
// Then, lhs = ((p * c * s + q * c + r) % (c * s)) / c = (q * c + r) / c = q
// rhs = ((p * c * s + q * c + r) / c) % s = (p * s + q) % s = q
// Thus, lhs = rhs
//
// The above proof is for the floordiv.
// The same rule also holds for trucdiv(division rule in C).
// Because both sides only involve mul, div and mod,
// we can take abs of x, c and s, apply the floordiv proof,
// and finally add the sign back.
//
// Rule 2: (x / s) * s + x % s = x (true for both truc and floor div)
//
// General merge condition and proof:
// - x = lhs->index % lhs->upper_factor
// - s = lhs->scale / rhs->scale
// - c = rhs->lower_factor
//
// (x / (c * s)) * s + (x % (c * s)) / c
// => ((x / c) / s) * s + ((x / c) % s)
// => (x / c)
//
// Examples:
//
// (z / 6) * 6 + ((z % 6) / 3) * 3
// => ((z / 6) * 2 + (z % 6) / 3) * 3
// => (z / 3) * 3
// note: x = z, c = 3, s = 2
//
// ((z % 12) / 6) * 6 + ((z % 6) / 3) * 3
// => (((z % 12) / 6) * 2 + ((z % 12) % 6) / 3) * 3
// => ((z % 12) / 3) * 3
// note: x = z % 12, c = 3, s = 2
// note also the invariance lhs->upper_factor % lhs->lower_factor == 0
//
SplitExprNode* merged = rhs.CopyOnWrite();
merged->upper_factor = lhs->upper_factor;
// reset args[i] to be zero.
lhs.CopyOnWrite()->scale = 0;
break;
}
}
}
// sort by the entry
// Here we simply sort by descending order of scales.
// For now, we do not compare by index because that comparison
// can be runtime dependent and create inderminism.
// we do not sort by index for now because it can be costly
// to deep compare Exprs, and address of Vars can be runtime dependent.
//
auto fcompare = [](const SplitExpr& lhs, const SplitExpr& rhs) {
// order by scale first
if (lhs->scale > rhs->scale) return true;
if (lhs->scale < rhs->scale) return false;
// then order by factor
if (lhs->lower_factor > rhs->lower_factor) return true;
if (lhs->lower_factor < rhs->lower_factor) return false;
// then order by upper factor
if (lhs->upper_factor > rhs->upper_factor) return true;
if (lhs->upper_factor < rhs->upper_factor) return false;
// tie.
// TODO(tvm-team) We might consider index as the last comparison point,
// after we make deep comparator more derministic.
// Specifically, we can consider comparing names of vars and break ties with address.
return false;
};
std::stable_sort(args.begin(), args.end(), fcompare);
return args;
}
static Expr Normalize_(Type dtype,
const std::vector<SplitExpr>& args,
int64_t base) {
// Positive scales first
Expr res = make_const(dtype, 0);
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->scale > 0) {
res = res + args[i]->Normalize();
}
}
if (base > 0) {
res = res + make_const(dtype, base);
}
// negative scales follows using sub.
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->scale < 0) {
res = res - args[i]->NormalizeWithScale(-1);
}
}
if (base < 0) {
res = res - make_const(dtype, -base);
}
return res;
}
};
TVM_DEFINE_COW_NODE_REF(SumExpr, Expr, SumExprNode);
void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) {
// NOTE: it is rare to have a balanced long expression,
// linear scan is fine for our case.
for (size_t i = 0; i < other->args.size(); ++i) {
this->AddToSelf(other->args[i], scale);
}
this->AddToSelf(other->base * scale);
}
// Sub-class RewriteSimplifier::Impl to take benefit of
// rewriter for condition simplification etc.
class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
public:
using Rewriter = RewriteSimplifier::Impl;
explicit Impl(Analyzer* parent)
: Rewriter(parent) {}
Expr CanonicalSimplify(Expr expr) {
expr = Mutate(expr);
return expr;
}
// override the original mutate function.
Expr Mutate(Expr expr) final {
expr = IRMutator::Mutate(expr);
return Normalize(expr);
}
// Normal mutation without normalization.
Expr CanonicalMutate(Expr expr) {
return IRMutator::Mutate(expr);
}
using Rewriter::Mutate_;
Expr Mutate_(const Add* op, const Expr& self) final;
Expr Mutate_(const Sub* op, const Expr& self) final;
Expr Mutate_(const Mul* op, const Expr& self) final;
Expr Mutate_(const Div* op, const Expr& self) final;
Expr Mutate_(const Mod* op, const Expr& self) final;
Expr Mutate_(const Reduce* op, const Expr& self) final;
private:
/*!
* \brief compute lhs / cval
* \param lhs The left operand.
* \param cval The constant value.
* \return The result expression;
*/
SplitExpr SplitDivConst(SplitExpr lhs, int64_t cval);
/*!
* \brief compute lhs % cval
* \param lhs The left operand.
* \param cval The constant value.
* \return The result expression;
*/
SplitExpr SplitModConst(SplitExpr lhs, int64_t cval);
/*!
* \brief Detect if psum = q * coeff + r such that (q >= 0 && r >= 0)
* \param psum The sum expression.
* \param coeff The co-efficient.
* \param out_divisible The result divisible component.
* \param out_non_divisible The non-divisible component.
* \return Whether detection is successful.
*/
bool TryLinearEquation(const SumExprNode* psum,
int64_t coeff,
SumExpr* out_divisible,
SumExpr* out_non_divisible);
/*!
* \brief Normalize expr to normal expr.
* \param expr The input expression.
* \return Normalized expr.
*/
Expr Normalize(Expr expr) {
if (const auto* op = expr.as_derived<CanonicalExprNode>()) {
return op->Normalize();
} else {
return expr;
}
}
/*!
* \brief Create a SplitExpr from expr.
* \param expr The input expr.
* \return The transformed SplitExpr.
*/
SplitExpr ToSplitExpr(Expr expr) {
if (const auto* op = expr.as<SplitExprNode>()) {
return GetRef<SplitExpr>(op);
}
if (const auto* op = expr.as_derived<CanonicalExprNode>()) {
expr = op->Normalize();
}
NodePtr<SplitExprNode> n = make_node<SplitExprNode>();
n->type = expr.type();
n->index = std::move(expr);
return SplitExpr(n);
}
/*!
* \brief Create a SumExpr from expr.
* \param expr The input expr.
* \return The transformed SumExpr.
*/
SumExpr ToSumExpr(Expr expr) {
if (const auto* op = expr.as<SumExprNode>()) {
return GetRef<SumExpr>(op);
}
NodePtr<SumExprNode> n = make_node<SumExprNode>();
n->type = expr.type();
if (const auto* op = expr.as<IntImm>()) {
n->base = op->value;
return SumExpr(n);
} else {
n->args.emplace_back(ToSplitExpr(expr));
return SumExpr(n);
}
}
// Simplify the combiner used in reduce.
Expr SimplifyReduceCombiner(const Reduce* op);
};
Expr CanonicalSimplifier::Impl::
Mutate_(const Add* op, const Expr& self) {
if (!IsIndexType(op->type)) {
return Rewriter::Mutate_(op, self);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<Add>(a, b);
if (const_res.defined()) return const_res;
// canonical form simplification.
SumExpr ret = ToSumExpr(std::move(a));
if (const auto* op = b.as<IntImm>()) {
ret.CopyOnWrite()->AddToSelf(op->value);
} else if (const auto* op = b.as<SumExprNode>()) {
ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), 1);
} else {
ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), 1);
}
return ret;
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Sub* op, const Expr& self) {
if (!IsIndexType(op->type)) {
return Rewriter::Mutate_(op, self);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<Sub>(a, b);
if (const_res.defined()) return const_res;
// canonical form simplification.
SumExpr ret = ToSumExpr(std::move(a));
if (const auto* op = b.as<IntImm>()) {
ret.CopyOnWrite()->AddToSelf(-op->value);
} else if (const auto* op = b.as<SumExprNode>()) {
ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), -1);
} else {
ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), -1);
}
return ret;
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Mul* op, const Expr& self) {
if (!IsIndexType(op->type)) {
return Rewriter::Mutate_(op, self);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<Mul>(a, b);
if (const_res.defined()) return const_res;
// x * c
if (a.as<IntImm>()) {
std::swap(a, b);
}
if (const auto* bconst = b.as<IntImm>()) {
if (a.as<SumExprNode>()) {
SumExpr ret(std::move(a.node_));
ret.CopyOnWrite()->MulToSelf(bconst->value);
return ret;
} else {
SplitExpr ret = ToSplitExpr(std::move(a));
ret.CopyOnWrite()->MulToSelf(bconst->value);
return ret;
}
}
// normal path.
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
} else {
return Mul::make(a, b);
}
}
bool CanonicalSimplifier::Impl::
TryLinearEquation(const SumExprNode* psum,
int64_t coeff,
SumExpr* out_divisible,
SumExpr* out_non_divisible) {
auto divisible = make_node<SumExprNode>();
auto non_divisible = make_node<SumExprNode>();
divisible->type = psum->type;
non_divisible->type = psum->type;
if (psum->base % coeff == 0) {
divisible->base = psum->base;
} else {
non_divisible->base = psum->base;
}
for (const auto& e : psum->args) {
if (e->scale % coeff == 0) {
divisible->args.push_back(e);
} else {
non_divisible->args.push_back(e);
}
}
*out_divisible = SumExpr(divisible);
*out_non_divisible = SumExpr(non_divisible);
if (non_divisible->base == 0 && non_divisible->args.size() == 0) {
return true;
}
if (parent_->CanProveGreaterEqual(divisible->Normalize(), 0) &&
parent_->CanProveGreaterEqual(non_divisible->Normalize(), 0)) {
return true;
} else {
return false;
}
}
SplitExpr CanonicalSimplifier::Impl::
SplitDivConst(SplitExpr lhs, int64_t cval) {
if (lhs->scale % cval == 0) {
lhs.CopyOnWrite()->scale /= cval;
return lhs;
}
if (cval % lhs->scale == 0) {
int64_t scaled_cval = cval / lhs->scale;
if (lhs->upper_factor == SplitExprNode::kPosInf ||
lhs->upper_factor % (lhs->lower_factor * scaled_cval) == 0) {
// directly fold division.
lhs.CopyOnWrite()->scale = 1;
lhs.CopyOnWrite()->lower_factor *= scaled_cval;
lhs->Verify();
return lhs;
} else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) {
// (x % c1) / c2 => 0 when c2 >= c1
return ToSplitExpr(make_zero(lhs.type()));
} else {
// move the upper_factor modular into index.
lhs.CopyOnWrite()->index =
lhs->index % make_const(lhs.type(), lhs->upper_factor);
lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf;
lhs.CopyOnWrite()->scale = 1;
lhs.CopyOnWrite()->lower_factor *= scaled_cval;
lhs->Verify();
return lhs;
}
}
// directly return the split with cval == 1
lhs = ToSplitExpr(Normalize(lhs));
CHECK_EQ(lhs->scale, 1);
lhs.CopyOnWrite()->lower_factor *= cval;
return lhs;
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Div* op, const Expr& self) {
if (!IsIndexType(op->type)) {
return Rewriter::Mutate_(op, self);
}
Expr a = this->CanonicalMutate(op->a);
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<Div>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
// x / c1
if (c1.Match(b) && c1.Eval()->value > 0) {
int64_t cval = c1.Eval()->value;
if (cval == 1) return a;
if (const auto* psum = a.as<SumExprNode>()) {
SumExpr lhs, extra;
if (TryLinearEquation(psum, cval, &lhs, &extra)) {
lhs.CopyOnWrite()->DivideBy(cval);
Expr temp = Normalize(extra);
if (const auto* pconst = temp.as<IntImm>()) {
lhs.CopyOnWrite()->AddToSelf(pconst->value / cval);
} else {
// if extra <= cval, it means the extra can be eliminated.
if (TryCompare(temp, cval) != kLT) {
lhs.CopyOnWrite()->AddToSelf(
SplitDivConst(ToSplitExpr(temp), cval), 1);
}
}
return lhs;
}
} else {
// if a >= 0 && a < cval, then result == 0
auto cbound = parent_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) {
return make_zero(a.type());
}
}
return SplitDivConst(ToSplitExpr(std::move(a)), cval);
}
// normal path
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
} else {
return Div::make(a, b);
}
}
SplitExpr CanonicalSimplifier::Impl::
SplitModConst(SplitExpr lhs, int64_t cval) {
if (lhs->scale % cval == 0) {
lhs.CopyOnWrite()->scale = 0;
return lhs;
}
if (cval % lhs->scale == 0) {
// (x * c1) % (c2 * c1) => (x % c2) * c1
int64_t scaled_cval = cval / lhs->scale;
// (x / c1) % c2 => (x % (c1 * c2)) / c2
int64_t new_upper_factor = lhs->lower_factor * scaled_cval;
// try to see if we can reduce the existing upper modular.
if (lhs->upper_factor == SplitExprNode::kPosInf ||
lhs->upper_factor % new_upper_factor == 0) {
lhs.CopyOnWrite()->upper_factor = new_upper_factor;
lhs->Verify();
return lhs;
} else if (new_upper_factor % lhs->upper_factor == 0) {
// (x % 2) % 4 => x % 2
return lhs;
}
}
// Normalize the value.
lhs = ToSplitExpr(Normalize(lhs));
CHECK_EQ(lhs->scale, 1);
CHECK_EQ(lhs->lower_factor, 1);
lhs.CopyOnWrite()->upper_factor = cval;
return lhs;
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Mod* op, const Expr& self) {
if (!IsIndexType(op->type)) {
return Rewriter::Mutate_(op, self);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<Mod>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
// x % c1
if (c1.Match(b) && c1.Eval()->value > 0) {
int64_t cval = c1.Eval()->value;
if (const auto* psum = a.as<SumExprNode>()) {
SumExpr lhs, extra;
if (TryLinearEquation(psum, cval, &lhs, &extra)) {
Expr temp = Normalize(extra);
if (temp.as<IntImm>()) {
return temp % c1.Eval();
} else {
// If temp < cval && temp >=0 then can remove the mod.
if (TryCompare(temp, cval) == kLT) {
return temp;
} else {
return SplitModConst(ToSplitExpr(temp), cval);
}
}
}
} else {
// if a >= 0 && a < cval, then result == 0
auto cbound = parent_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) {
return a;
}
}
return SplitModConst(ToSplitExpr(std::move(a)), cval);
}
// normal path
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
} else {
return Mod::make(a, b);
}
}
// Simplify reduce expression.
Expr CanonicalSimplifier::Impl::
SimplifyReduceCombiner(const Reduce* op) {
// First simplify the results
Array<Expr> simplified_result;
for (const auto& res : op->combiner->result) {
Expr new_res = Mutate(res);
simplified_result.push_back(new_res);
}
// Which components to keep
std::vector<int> used(op->combiner->result.size(), false);
// This function recursively marks the used components starting from
// the index idx
std::function<void(int)> mark_used;
mark_used = [&used, &simplified_result, op, &mark_used](size_t idx) {
// if the idx-th component was marked as used before, do nothing
if (used[idx]) return;
used[idx] = true;
// check if the idx-th result expr uses some lhs or rhs variables
// and recursively mark the corresponding components
for (size_t i = 0; i < simplified_result.size(); ++i)
if (!used[i]) {
if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) ||
ExprUseVar(simplified_result[idx], op->combiner->rhs[i]))
mark_used(i);
}
};
// mark all used components starting from the value_index
mark_used(op->value_index);
// components which have side effects should also be preserved
for (size_t i = 0; i < used.size(); ++i) {
if (HasSideEffect(op->source[i]) ||
HasSideEffect(op->combiner->identity_element[i]) ||
HasSideEffect(op->combiner->result[i])) {
mark_used(i);
}
}
int new_value_index = op->value_index;
Array<Expr> new_result;
Array<Expr> new_identity;
Array<Var> new_lhs;
Array<Var> new_rhs;
Array<Expr> new_source;
// new stuff is old stuff which is used
for (size_t i = 0; i < used.size(); ++i) {
if (used[i]) {
// We simplify the result and identity, but not the source
new_result.push_back(simplified_result[i]);
new_identity.push_back(Mutate(op->combiner->identity_element[i]));
new_lhs.push_back(op->combiner->lhs[i]);
new_rhs.push_back(op->combiner->rhs[i]);
new_source.push_back(op->source[i]);
} else if (static_cast<int>(i) < op->value_index) {
// value_index should also be adjusted
new_value_index--;
}
}
CommReducer new_combiner =
CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
return Reduce::make(
new_combiner, new_source, op->axis, op->condition, new_value_index);
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Reduce* op, const Expr& self) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
parent_->Bind(iv->var, iv->dom);
}
// Recursively call simplification when necessary.
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Reduce>();
// already been simplified by const reduction axis removal
if (op == nullptr) return ret;
if (op->axis.empty()) {
// Note that here we assume that the identity element is indeed identity. Without this
// assumption we would have to perform a single iteration of the loop, i.e. use
// `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]`
// instead of `op->source[op->value_index]`. The former may be more difficult to simplify.
return Mutate(
Select::make(op->condition,
op->source[op->value_index],
op->combiner->identity_element[op->value_index]));
}
// combiner simplification.
ret = SimplifyReduceCombiner(op);
return ret;
}
Expr CanonicalSimplifier::operator()(const Expr& expr) {
return impl_->CanonicalSimplify(expr);
}
void CanonicalSimplifier::Update(const Var& var,
const Expr& info,
bool override) {
impl_->Update(var, info, override);
}
CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent)
: impl_(new Impl(parent)) {
}
CanonicalSimplifier::~CanonicalSimplifier() {
delete impl_;
}
} // namespace arith
} // namespace tvm
......@@ -7,6 +7,8 @@
#define TVM_ARITHMETIC_CONST_FOLD_H_
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <algorithm>
namespace tvm {
......
......@@ -37,6 +37,10 @@ struct ConstIntBoundAnalyzer::Entry {
bool is_const(int64_t value) const {
return min_value == max_value && min_value == value;
}
bool operator==(const Entry& other) const {
return min_value == other.min_value && max_value == other.max_value;
}
};
class ConstIntBoundAnalyzer::Impl :
......@@ -55,7 +59,11 @@ class ConstIntBoundAnalyzer::Impl :
const Entry& info,
bool override) {
if (!override) {
CHECK(!var_map_.count(var));
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(it->second == info)
<< "var \'" << var << "\' already updated.";
}
}
var_map_[var] = info;
}
......
......@@ -7,8 +7,10 @@
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include <algorithm>
#include "const_fold.h"
#include "pattern_match.h"
#include "rewrite_simplify.h"
namespace tvm {
namespace arith {
......@@ -39,134 +41,55 @@ using namespace ir;
return RecursiveRewrite((ResExpr).Eval()); \
}
// NOTE for developers:
//
// We mainly focus on index expression simplification.
// Besides the RewriteSimplifier, some cases can be better
// handled by CanonicalSimplifier.
//
class RewriteSimplifier::Impl : public IRMutator {
public:
explicit Impl(Analyzer* parent)
: parent_(parent) {}
void Update(const Var& var,
const Expr& info,
bool override) {
if (!override) {
CHECK(!var_map_.count(var));
// try to prove x equals val
RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::
TryCompare(const Expr& x, int64_t val) {
Expr diff = Mutate(x);
if (const auto* ptr = diff.as<IntImm>()) {
if (ptr->value == val) {
return kEQ;
} else if (ptr->value > val) {
return kGT;
} else if (ptr->value < val) {
return kLT;
}
var_map_[var] = info;
}
// Run simplification in post order
Expr PostOrderSimplify(Expr expr, int max_iter = 2) {
for (int i = 0; i < max_iter; ++i) {
Expr new_expr = this->Mutate(expr);
if (new_expr.same_as(expr)) return expr;
expr = new_expr;
if (val == 0) {
ModularSet dmod = parent_->modular_set(diff);
if (dmod->base != 0) {
return kNE;
}
return expr;
}
Expr Mutate_(const Add* op, const Expr& self) final;
Expr Mutate_(const Sub* op, const Expr& self) final;
Expr Mutate_(const Mul* op, const Expr& self) final;
Expr Mutate_(const Div* op, const Expr& self) final;
Expr Mutate_(const Mod* op, const Expr& self) final;
Expr Mutate_(const Min* op, const Expr& self) final;
Expr Mutate_(const Max* op, const Expr& self) final;
Expr Mutate_(const EQ* op, const Expr& self) final;
Expr Mutate_(const NE* op, const Expr& self) final;
Expr Mutate_(const LT* op, const Expr& self) final;
Expr Mutate_(const LE* op, const Expr& self) final;
Expr Mutate_(const GT* op, const Expr& self) final;
Expr Mutate_(const GE* op, const Expr& self) final;
Expr Mutate_(const And* op, const Expr& self) final;
Expr Mutate_(const Or* op, const Expr& self) final;
Expr Mutate_(const Not* op, const Expr& self) final;
Expr Mutate_(const Select* op, const Expr& self) final;
Expr Mutate_(const Ramp* op, const Expr& self) final;
private:
/*! \brief internal structure for comparison. */
enum CompareResult {
kUnknown,
kEQ,
kGT,
kLT,
kGE,
kLE,
kNE
};
// reference to the main analyzer
Analyzer* parent_;
// counter to record recursive rewrite depth.
int recur_depth_{0};
// internal variable map
std::unordered_map<Var, Expr, ExprHash, ExprEqual> var_map_;
// maximum number of recursion allowed during a single pass.
static const constexpr int kMaxRecurDepth = 5;
// Whether x >= val
bool CanProveGreaterEqual(const Expr& x, int64_t val) {
return parent_->CanProveGreaterEqual(x, val);
ConstIntBound dbound = parent_->const_int_bound(diff);
if (dbound->min_value > val) {
return kGT;
}
// Whether x == val
bool CanProveEqual(const Expr& x, int64_t val) {
// TODO(tqchen) refer back to super-analyzer.
return TryCompare(x, val) == kEQ;
if (dbound->max_value < val) {
return kLT;
}
// try to prove x equals val
CompareResult TryCompare(const Expr& x, int64_t val) {
Expr diff = Mutate(x);
if (const auto* ptr = diff.as<IntImm>()) {
if (ptr->value == val) {
return kEQ;
} else if (ptr->value > val) {
return kGT;
} else if (ptr->value < val) {
return kLT;
}
}
if (val == 0) {
ModularSet dmod = parent_->modular_set(diff);
if (dmod->base != 0) {
return kNE;
}
}
ConstIntBound dbound = parent_->const_int_bound(diff);
if (dbound->min_value > val) {
return kGT;
}
if (dbound->max_value < val) {
return kLT;
}
if (dbound->min_value >= val) {
return kGE;
}
if (dbound->max_value <= val) {
return kLE;
}
return kUnknown;
if (dbound->min_value >= val) {
return kGE;
}
// Recursive rewrite x
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
Expr RecursiveRewrite(const Expr& x) {
if (recur_depth_ >= kMaxRecurDepth) return x;
++recur_depth_;
Expr res = Mutate(x);
--recur_depth_;
return res;
if (dbound->max_value <= val) {
return kLE;
}
return kUnknown;
}
template<typename TA>
PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 0);
void RewriteSimplifier::Impl::
Update(const Var& var, const Expr& info, bool override) {
if (!override) {
CHECK(!var_map_.count(var));
}
};
var_map_[var] = info;
}
Expr RewriteSimplifier::Impl::
Mutate_(const Add* op, const Expr& self) {
......@@ -1254,16 +1177,6 @@ Mutate_(const Or* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Ramp* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Ramp>();
if (is_zero(op->stride)) {
return Broadcast::make(op->base, op->lanes);
}
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const Select* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Select>();
......@@ -1275,13 +1188,30 @@ Mutate_(const Select* op, const Expr& self) {
}
// Pattern var to match any expression
PVar<Expr> x, y;
TVM_TRY_REWRITE(select(x, y, y), y);
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const Call* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Call>();
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
return op->args[0];
}
return ret;
}
Expr RewriteSimplifier::operator()(const Expr& expr) {
return impl_->PostOrderSimplify(expr);
// Run simplification in post order
Expr res = expr;
int max_iter = 2;
for (int i = 0; i < max_iter; ++i) {
Expr new_expr = impl_->Mutate(res);
if (new_expr.same_as(res)) return res;
res = new_expr;
}
return res;
}
void RewriteSimplifier::Update(const Var& var,
......@@ -1290,7 +1220,6 @@ void RewriteSimplifier::Update(const Var& var,
impl_->Update(var, info, override);
}
RewriteSimplifier::RewriteSimplifier(Analyzer* parent)
: impl_(new Impl(parent)) {
}
......
/*!
* Copyright (c) 2019 by Contributors
* \file rewrite_simplify.h
* \brief Rewrite-rule based simplification.
*/
#ifndef TVM_ARITHMETIC_REWRITE_SIMPLIFY_H_
#define TVM_ARITHMETIC_REWRITE_SIMPLIFY_H_
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include <unordered_map>
#include "const_fold.h"
#include "pattern_match.h"
namespace tvm {
namespace arith {
using namespace ir;
/*!
* \brief Rewrite-based simplifier.
*
* This class can be inheritated for other simplifiers.
*/
class RewriteSimplifier::Impl : public IRMutator {
public:
explicit Impl(Analyzer* parent)
: parent_(parent) {}
void Update(const Var& var, const Expr& info, bool override);
Expr Mutate_(const Add* op, const Expr& self) override;
Expr Mutate_(const Sub* op, const Expr& self) override;
Expr Mutate_(const Mul* op, const Expr& self) override;
Expr Mutate_(const Div* op, const Expr& self) override;
Expr Mutate_(const Mod* op, const Expr& self) override;
Expr Mutate_(const Min* op, const Expr& self) override;
Expr Mutate_(const Max* op, const Expr& self) override;
Expr Mutate_(const EQ* op, const Expr& self) override;
Expr Mutate_(const NE* op, const Expr& self) override;
Expr Mutate_(const LT* op, const Expr& self) override;
Expr Mutate_(const LE* op, const Expr& self) override;
Expr Mutate_(const GT* op, const Expr& self) override;
Expr Mutate_(const GE* op, const Expr& self) override;
Expr Mutate_(const And* op, const Expr& self) override;
Expr Mutate_(const Or* op, const Expr& self) override;
Expr Mutate_(const Not* op, const Expr& self) override;
Expr Mutate_(const Select* op, const Expr& self) override;
Expr Mutate_(const Call* op, const Expr& self) override;
protected:
/*! \brief internal structure for comparison. */
enum CompareResult {
kUnknown,
kEQ,
kGT,
kGE,
kLT,
kLE,
kNE
};
// reference to the main analyzer
Analyzer* parent_;
// counter to record recursive rewrite depth.
int recur_depth_{0};
// internal variable map
std::unordered_map<Var, Expr, ExprHash, ExprEqual> var_map_;
// maximum number of recursion allowed during a single pass.
static const constexpr int kMaxRecurDepth = 5;
/*!
* \brief try to compare x against val.
* \param x The expression to be evaluated.
* \param val The constant value.
* \return comparison result.
*/
CompareResult TryCompare(const Expr& x, int64_t val);
private:
// Whether x >= val
bool CanProveGreaterEqual(const Expr& x, int64_t val) {
return parent_->CanProveGreaterEqual(x, val);
}
// Whether x == val
bool CanProveEqual(const Expr& x, int64_t val) {
// TODO(tqchen) refer back to super-analyzer.
return TryCompare(x, val) == kEQ;
}
// Recursive rewrite x
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
Expr RecursiveRewrite(const Expr& x) {
if (recur_depth_ >= kMaxRecurDepth) return x;
++recur_depth_;
Expr res = Mutate(x);
--recur_depth_;
return res;
}
template<typename TA>
PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 0);
}
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_REWRITE_SIMPLIFY_H_
/*!
* Copyright (c) 2019 by Contributors
* \file stmt_simplify.cc
* \brief Statement simplifier based on analyzer
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <tvm/arithmetic.h>
#include "arithmetic/Simplify.h"
namespace tvm {
namespace arith {
// statement simplifier
using namespace ir;
class StmtSimplifier : public IRMutator {
public:
Stmt Mutate_(const For* op, const Stmt& s) final {
Var loop_var(op->loop_var.node_);
analyzer_.Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
return IRMutator::Mutate_(op, s);
}
// IfThenElse
Stmt Mutate_(const IfThenElse* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Stmt then_case, else_case;
{
ConstraintContext ctx(&analyzer_, condition);
then_case = this->Mutate(op->then_case);
}
if (op->else_case.defined()) {
ConstraintContext ctx(&analyzer_, Mutate(Not::make(condition)));
else_case = this->Mutate(op->else_case);
}
if (is_one(condition)) return then_case;
if (is_zero(condition)) {
if (else_case.defined()) {
return else_case;
}
return Evaluate::make(0);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
// AttrStmt
Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
if (!var_dom_.count(iv->var.get())) {
Range dom = Range::make_by_min_extent(0, op->value);
var_dom_[iv->var.get()] = dom;
analyzer_.Bind(iv->var, dom);
}
Stmt stmt = IRMutator::Mutate_(op, s);
return stmt;
} else {
return IRMutator::Mutate_(op, s);
}
}
// AssertStmt
Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message);
ConstraintContext ctx(&analyzer_, condition);
Stmt body = this->Mutate(op->body);
if (condition.same_as(op->condition) &&
message.same_as(op->message) &&
body.same_as(op->body)) {
return s;
} else {
return AssertStmt::make(condition, message, body);
}
}
protected:
Analyzer analyzer_;
// variable domain
std::unordered_map<const Variable*, Range> var_dom_;
};
class CanonicalStmtSimplifier : public StmtSimplifier {
public:
using StmtSimplifier::Mutate;
Expr Mutate(Expr expr) final {
return analyzer_.canonical_simplify(expr);
}
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
for (auto kv : vrange) {
analyzer_.Bind(kv.first, kv.second);
}
return Mutate(stmt);
}
};
} // namespace arith
namespace ir {
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::CanonicalStmtSimplifier().CanonicalSimplify(
stmt, vrange);
}
Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
arith::Analyzer analyzer;
for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second);
}
return analyzer.canonical_simplify(expr);
}
template<typename T>
T Simplify_(T a, Map<Var, Range> vrange) {
using namespace HalideIR::Internal;
Scope<Interval> rscope;
for (auto kv : vrange) {
Range r = kv.second;
rscope.push(
kv.first.get(),
Interval(r->min,
simplify(r->min + r->extent - make_const(r->min.type(), 1))));
}
return HalideIR::Internal::simplify(a, true, rscope);
}
Expr Simplify(Expr a, Map<Var, Range> vrange) {
// Simplify top level reduce.
if (const Reduce* r = a.as<Reduce>()) {
Array<Expr> new_source;
for (auto& e : r->source) {
new_source.push_back(Simplify_(e, vrange));
}
Expr new_condition = Simplify_(r->condition, vrange);
if (r->source.same_as(new_source) &&
r->condition.same_as(new_condition)) {
return a;
} else {
return Reduce::make(
r->combiner, new_source, r->axis, new_condition, r->value_index);
}
}
return Simplify_(a, vrange);
}
Stmt Simplify(Stmt a, Map<Var, Range> vrange) {
return Simplify_(a, vrange);
}
} // namespace ir
} // namespace tvm
......@@ -78,7 +78,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
Expr Mutate_(const Load* op, const Expr& e) final {
auto it = load_remap_.find(op->buffer_var.get());
if (it != load_remap_.end()) {
CHECK(is_zero(op->index));
CHECK(is_zero(op->index)) << e;
return it->second;
} else {
return IRMutator::Mutate_(op, e);
......
import tvm
class CanonicalChecker:
def __init__(self):
self.analyzer = tvm.arith.Analyzer()
def verify(self, data, expected):
res = self.analyzer.canonical_simplify(data)
assert tvm.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected)
def test_mul_sum_simplify():
ck = CanonicalChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify(2 + (3 * x + z + y + 1) * 4 + x,
x * 13 + z * 4 + y * 4 +6)
ck.verify((x + y + x + y * 3) / 2, y * 2 + x)
ck.verify((x + y + x + y * 3) % 2, 0)
ck.verify(x * 3 - 4 * x + 1, 1 - x)
ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2)
def test_split_index_simplify():
ck = CanonicalChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify((x/3) *3 + x % 3, x)
ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x)
# split div const
ck.verify(((x % 16) / 2) * 2 / 4, (x % 16) / 4)
ck.verify((x % 2) / 8, 0)
ck.verify((x % 2) / 7, 0)
ck.verify(((x % 16) / 2) * 2 / 6, (x % 16) / 6)
# split mod const
ck.verify((x * 8) % 16, (x % 2) * 8)
ck.verify((x * 8) % 2, 0)
# simplify then fold
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000))
ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y)
# complex fold
ck.verify((z * 9 + y) / 2 * 2 + (z * 9 + y) % 2, z * 9 + y)
def test_div_simplify():
ck = CanonicalChecker()
x = tvm.var("x")
ck.verify((16+48*x)/16, x*3 + 1)
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x
ck.verify((17+48*x)/16, (x * 48 + 17) / 16)
# However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10))
ck.verify((17+48*x)/16, x * 3 + 1)
# Trying expressions that are not simplifiable for any values of the variables
ck.verify((17+47*x)/16, (x * 47 + 17) / 16)
def test_canonical_mixed():
ck = CanonicalChecker()
x = tvm.var("x")
z = tvm.const(3, "int32")
ck.verify(x / (z*z) - x / (z*z), 0)
ck.verify(x / (z+z) - x / (z+z), 0)
def test_reduce_combiner_simplify():
ck = CanonicalChecker()
dummy = tvm.var('dummy')
comm_reducer = tvm.comm_reducer
prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0))
sum_or_prod = comm_reducer(
lambda x, y: tvm.expr.Select(dummy < 0,
x + y, x*y),
lambda t0: tvm.expr.Select(dummy < 0,
tvm.const(0, t0), tvm.const(1, t0)))
sum_and_prod = comm_reducer(
lambda x, y: (x[0] + y[0],
x[1]*y[1]),
lambda t0, t1: (tvm.const(0, t0),
tvm.const(5, t0) - tvm.const(4, t0)))
some_reducer1 = comm_reducer(
lambda x, y: (x[0] + y[0],
x[0] + y[0] + x[1] + y[1],
x[0]*y[2] + y[0]*x[2],
x[1] + y[2],
4.0),
lambda t0, t1, t2, t3, t4: (tvm.const(0, t0),
tvm.const(1, t1),
tvm.const(2, t2),
tvm.const(3, t3),
tvm.const(4, t4)))
k = tvm.reduce_axis((0, 10), name="k")
A = tvm.placeholder((10,), name='A')
# Test that SimplifyCombiner makes use of vranges
ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, -4))
ck.verify(sum_or_prod(A[k], k), tvm.sum(A[k], k))
ck.analyzer.update(dummy, tvm.arith.ConstIntBound(5, 9), True)
ck.verify(sum_or_prod(A[k], k), prod(A[k], k))
ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, 100), True)
ck.verify(sum_and_prod((A[k], A[10-k]), k)[0], tvm.sum(A[k], k))
ck.verify(sum_and_prod((A[k], A[10-k]), k)[1], prod(A[10-k], k))
reference_simplified_sources = [[A[0]],
[A[0], A[1]],
[A[0], A[2]],
[A[0], A[1], A[2], A[3]],
[A[4]]]
for j in range(5):
# Here we use the j-th component of the result, so only it and the components it
# depends on are left.
simplified = ck.analyzer.canonical_simplify(
some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j])
# Check that the remaining components are the expected ones.
for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]):
assert tvm.ir_pass.Equal(lhs, rhs)
# Test that components with side effects are not removed
side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0)
ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0],
sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0],
tvm.sum(side_effect(A[k]), k))
def test_reduce_simplify():
ck = CanonicalChecker()
k = tvm.reduce_axis((0, 10), name="k")
j = tvm.reduce_axis((-5, 3), name="j")
A = tvm.placeholder((10,), name='A')
ck.verify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j]),
tvm.sum(k + j, [k, j]))
ck.verify(tvm.sum(A[3], []), A[3])
# The rule below is not typical, removed for now
ck.verify(tvm.sum(k / 10, k), tvm.sum(tvm.const(0, "int32"), k))
if __name__ == "__main__":
test_div_simplify()
test_reduce_simplify()
test_reduce_combiner_simplify()
test_mul_sum_simplify()
test_split_index_simplify()
test_canonical_mixed()
......@@ -21,7 +21,6 @@ def test_simplify():
assert zz.a == x and zz.b.value == 4
n = tvm.var('n')
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % (-1)), tvm.const(0, "int32"))
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % 1), tvm.const(0, "int32"))
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / 1), n)
tvm.ir_pass.CanonicalSimplify(n / (-1))
......@@ -29,36 +28,16 @@ def test_simplify():
# assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)),
# tvm.ir_pass.CanonicalSimplify(-n))
def test_simplify_div():
x = tvm.var('x')
assert tvm.ir_pass.CanonicalSimplify((16+48*x)/16 - (1 + (x*3))).value == 0
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x
r = tvm.ir_pass.CanonicalSimplify((17+48*x)/16)
assert r.b.value == 16
assert tvm.ir_pass.CanonicalSimplify(r.a - (17 + 48*x)).value == 0
# However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
assert tvm.ir_pass.CanonicalSimplify((17+48*x)/16 - (1 + (x*3)), {x: tvm.Range(0,10)}).value == 0
# Trying expressions that are not simplifiable for any values of the variables
r = tvm.ir_pass.CanonicalSimplify((17+47*x)/16, {x: tvm.Range(0,10)})
assert r.b.value == 16
assert tvm.ir_pass.CanonicalSimplify(r.a - (17+47*x)).value == 0
r = tvm.ir_pass.CanonicalSimplify((8*x - 17)/8, {x : tvm.Range(4,10)})
assert tvm.ir_pass.CanonicalSimplify(r - (x-3)).value == 0
def test_simplify_mod():
"""Not yet working, mock design"""
ib = tvm.ir_builder.create()
n = tvm.var('n')
j = tvm.var('j')
A = ib.pointer("float32", name="A")
with ib.for_range(0, 16, name="i") as i:
A[i] = A[((n * 4 + j * 2) * 8 + i+1) % 16]
with ib.for_range(0, 10, name="j") as j:
with ib.for_range(0, 16, name="i") as i:
A[i] = A[(j * 32 + i+1) % 16]
body = ib.get()
stmt = tvm.ir_pass.CanonicalSimplify(body)
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16)
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16)
assert diff.value == 0
# if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify(
......@@ -95,8 +74,8 @@ def test_modular():
y: tvm.Range(i32_const(0), i32_const(2)),
x: tvm.Range(i32_const(0), i32_const(14))}
idx = ry * 16 + rx + y * 16 + x
z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap)
z2 = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap)
z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap)
assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0
assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0
......@@ -117,10 +96,9 @@ def test_const_propagation():
if __name__ == "__main__":
test_simplify_div()
test_simplify_mod()
test_modular()
test_simplify()
test_mul()
test_simplify_minmax()
test_const_propagation()
test_simplify_mod()
......@@ -35,109 +35,8 @@ def test_bound():
ret = tvm.ir_pass.Simplify(m % 10, vrange)
assert ret == m
def test_canonical():
x = tvm.var("x")
z = tvm.const(3, "int32")
ret = tvm.ir_pass.CanonicalSimplify(x / (z*z) - x / (z*z))
assert(tvm.ir_pass.Equal(ret, 0))
ret = tvm.ir_pass.CanonicalSimplify(x / (z+z) - x / (z+z))
assert(tvm.ir_pass.Equal(ret, 0))
#make sure terms are ordered based on their top operators (e.g., / always precedes %)
ret1 = tvm.ir_pass.CanonicalSimplify(x % 3 + x / 3)
ret2 = tvm.ir_pass.CanonicalSimplify(x / 3 + x % 3)
assert(tvm.ir_pass.Equal(ret1, ret2))
#when top operators match, compare string representation of terms
ret1 = tvm.ir_pass.CanonicalSimplify(x % 4 + x % 3)
ret2 = tvm.ir_pass.CanonicalSimplify(x % 3 + x % 4)
assert (tvm.ir_pass.Equal(ret1, ret2))
def test_simplify_combiner():
dummy = tvm.var('dummy')
prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0))
sum_or_prod = comm_reducer(lambda x, y: tvm.expr.Select(dummy < 0,
x + y, x*y),
lambda t0: tvm.expr.Select(dummy < 0,
tvm.const(0, t0), tvm.const(1, t0)))
sum_and_prod = comm_reducer(lambda x, y: (x[0] + y[0],
x[1]*y[1]),
lambda t0, t1: (tvm.const(0, t0),
tvm.const(5, t0) - tvm.const(4, t0)))
sum_and_prod2 = comm_reducer(lambda x, y: (x[0] + y[0],
x[1]*y[1] + 0*x[0] + y[0] - y[0]),
lambda t0, t1: (tvm.const(5, t0) - tvm.const(5, t0),
tvm.const(1, t1)))
some_reducer1 = comm_reducer(lambda x, y: (x[0] + y[0],
x[0] + y[0] + x[1] + y[1],
x[0]*y[2] + y[0]*x[2],
x[1] + y[2],
4.0),
lambda t0, t1, t2, t3, t4: (tvm.const(0, t0),
tvm.const(1, t1),
tvm.const(2, t2),
tvm.const(3, t3),
tvm.const(4, t4)))
k = tvm.reduce_axis((0, 10), name="k")
A = tvm.placeholder((10,), name='A')
# Test that SimplifyCombiner makes use of vranges
vrange = {dummy: tvm.Range(-10, -5)}
assert Equal(Simplify(sum_or_prod(A[k], k), vrange), tvm.sum(A[k], k))
vrange = {dummy: tvm.Range(5, 10)}
assert Equal(Simplify(sum_or_prod(A[k], k), vrange), prod(A[k], k))
assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k))
assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[1]), prod(A[10-k], k))
assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k))
assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[1]), prod(A[10-k], k))
reference_simplified_sources = [[A[0]],
[A[0], A[1]],
[A[0], A[2]],
[A[0], A[1], A[2], A[3]],
[A[4]]]
for j in range(5):
# Here we use the j-th component of the result, so only it and the components it
# depends on are left.
simplified = Simplify(some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j])
# Check that the remaining components are the expected ones.
for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]):
assert Equal(lhs, rhs)
# Test that components with side effects are not removed
side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0)
assert Equal(Simplify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0]),
sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
assert Equal(Simplify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0]),
tvm.sum(side_effect(A[k]), k))
def test_simplify_reduce():
k = tvm.reduce_axis((0, 10), name="k")
j = tvm.reduce_axis((-5, 3), name="j")
A = tvm.placeholder((10,), name='A')
assert Equal(Simplify(tvm.sum(k/10, k)), tvm.sum(tvm.const(0, "int32"), k))
assert Equal(Simplify(tvm.sum(A[3], [])), A[3])
assert Equal(Simplify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j])),
tvm.sum(k + j, [k, j]))
if __name__ == "__main__":
test_bound()
test_basic()
test_simplify()
test_canonical()
test_simplify_combiner()
test_simplify_reduce()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment