Unverified Commit 7afbca56 by Tianqi Chen Committed by GitHub

[ARITH] Analyzer CanonicalSimplifier (#2891)

parent eb1ed116
...@@ -218,6 +218,7 @@ class RewriteSimplifier { ...@@ -218,6 +218,7 @@ class RewriteSimplifier {
private: private:
friend class Analyzer; friend class Analyzer;
friend class ConstraintContext; friend class ConstraintContext;
friend class CanonicalSimplifier;
explicit RewriteSimplifier(Analyzer* parent); explicit RewriteSimplifier(Analyzer* parent);
~RewriteSimplifier(); ~RewriteSimplifier();
class Impl; class Impl;
...@@ -226,6 +227,39 @@ class RewriteSimplifier { ...@@ -226,6 +227,39 @@ class RewriteSimplifier {
}; };
/*! /*!
* \brief Canonical-form based simplifier.
*/
class CanonicalSimplifier {
public:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr operator()(const Expr& expr);
/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const Expr& new_expr,
bool override = false);
private:
friend class Analyzer;
friend class ConstraintContext;
explicit CanonicalSimplifier(Analyzer* parent);
~CanonicalSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};
/*!
* \brief A RAII constraint context. * \brief A RAII constraint context.
* *
* \code * \code
...@@ -277,6 +311,8 @@ class Analyzer { ...@@ -277,6 +311,8 @@ class Analyzer {
ModularSetAnalyzer modular_set; ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplfy */ /*! \brief sub-analyzer rewrite simplfy */
RewriteSimplifier rewrite_simplify; RewriteSimplifier rewrite_simplify;
/*! \brief sub-analyzer rewrite simplfy */
CanonicalSimplifier canonical_simplify;
/*! \brief constructor */ /*! \brief constructor */
Analyzer(); Analyzer();
/*! /*!
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <functional> #include <functional>
#include <utility>
#include "runtime/registry.h" #include "runtime/registry.h"
namespace tvm { namespace tvm {
...@@ -32,6 +33,44 @@ using ::tvm::AttrVisitor; ...@@ -32,6 +33,44 @@ using ::tvm::AttrVisitor;
using ContainerType = NodeName; \ using ContainerType = NodeName; \
}; \ }; \
/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
inline NodeName* CopyOnWrite() { \
CHECK(node_ != nullptr); \
if (!node_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<Node>(std::move(n)).swap(node_); \
} \
return static_cast<NodeName*>(node_.get()); \
} \
using ContainerType = NodeName; \
};
/*! /*!
* \brief save the node as well as all the node it depends on as json. * \brief save the node as well as all the node it depends on as json.
......
...@@ -97,6 +97,7 @@ class Analyzer: ...@@ -97,6 +97,7 @@ class Analyzer:
self._bind = _mod("bind") self._bind = _mod("bind")
self._modular_set = _mod("modular_set") self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify") self._rewrite_simplify = _mod("rewrite_simplify")
self._canonical_simplify = _mod("canonical_simplify")
self._enter_constraint_context = _mod("enter_constraint_context") self._enter_constraint_context = _mod("enter_constraint_context")
def const_int_bound(self, expr): def const_int_bound(self, expr):
...@@ -144,6 +145,21 @@ class Analyzer: ...@@ -144,6 +145,21 @@ class Analyzer:
""" """
return self._rewrite_simplify(expr) return self._rewrite_simplify(expr)
def canonical_simplify(self, expr):
"""Simplify expression via canonicalization.
Parameters
----------
expr : tvm.Expr
The expression.
Returns
-------
result : Expr
The result.
"""
return self._canonical_simplify(expr)
def bind(self, var, expr): def bind(self, var, expr):
"""Bind a variable to the expression. """Bind a variable to the expression.
......
...@@ -102,6 +102,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer") ...@@ -102,6 +102,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->rewrite_simplify(args[0]); *ret = self->rewrite_simplify(args[0]);
}); });
} else if (name == "canonical_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->canonical_simplify(args[0]);
});
} else if (name == "bind") { } else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr(); auto& sptr = args[1].node_sptr();
......
...@@ -11,14 +11,21 @@ namespace arith { ...@@ -11,14 +11,21 @@ namespace arith {
Analyzer::Analyzer() Analyzer::Analyzer()
: const_int_bound(this), : const_int_bound(this),
modular_set(this), modular_set(this),
rewrite_simplify(this) { rewrite_simplify(this),
canonical_simplify(this) {
} }
void Analyzer::Bind(const VarExpr& v, const Expr& expr) { void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
Var var(v.node_); Var var(v.node_);
this->const_int_bound.Update(var, this->const_int_bound(expr));
this->modular_set.Update(var, this->modular_set(expr)); Expr new_expr = expr;
this->rewrite_simplify.Update(var, this->rewrite_simplify(expr)); new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);
this->const_int_bound.Update(var, this->const_int_bound(new_expr));
this->modular_set.Update(var, this->modular_set(new_expr));
this->rewrite_simplify.Update(var, new_expr);
this->canonical_simplify.Update(var, new_expr);
} }
void Analyzer::Bind(const VarExpr& v, const Range& range) { void Analyzer::Bind(const VarExpr& v, const Range& range) {
...@@ -47,5 +54,6 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { ...@@ -47,5 +54,6 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (bd->min_value >= lower_bound) return true; if (bd->min_value >= lower_bound) return true;
return false; return false;
} }
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file canonical.h
* \brief Internal canonicalized expression simplification engine.
*/
#ifndef TVM_ARITHMETIC_CANONICAL_H_
#define TVM_ARITHMETIC_CANONICAL_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <memory>
namespace tvm {
namespace arith {
/*!
* \brief A stateful CanonicalEngine over SSA.
*
* Simplify and CSE with canonicalization expressions.
* Each call's result will get cached, so next call will
* simply return the cached result.
*/
class Canonical {
public:
/*! \brief constructor */
explicit Canonical(Map<Var, Range> var_range);
/*!
* \brief simplify expression e.
* \param expr The expression to be simplified.
*/
Expr Simplify(Expr expr);
/*!
* \brief simplify stmt.
* \param stmt The stmt to be simplified.
*/
Stmt Simplify(Stmt expr);
/*!
* \brief Set range and level variable
* \param v The variable
* \param r The range of the variable, can be undefined.
* \param level The scope level of the variable,
* affect the order of formula in communicative ops.
*/
void SetRange(Var v, Range r, int level);
class Internal;
private:
// Internal pointer
std::shared_ptr<Internal> ptr_;
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_CANONICAL_H_
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#define TVM_ARITHMETIC_CONST_FOLD_H_ #define TVM_ARITHMETIC_CONST_FOLD_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <algorithm> #include <algorithm>
namespace tvm { namespace tvm {
......
...@@ -37,6 +37,10 @@ struct ConstIntBoundAnalyzer::Entry { ...@@ -37,6 +37,10 @@ struct ConstIntBoundAnalyzer::Entry {
bool is_const(int64_t value) const { bool is_const(int64_t value) const {
return min_value == max_value && min_value == value; return min_value == max_value && min_value == value;
} }
bool operator==(const Entry& other) const {
return min_value == other.min_value && max_value == other.max_value;
}
}; };
class ConstIntBoundAnalyzer::Impl : class ConstIntBoundAnalyzer::Impl :
...@@ -55,7 +59,11 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -55,7 +59,11 @@ class ConstIntBoundAnalyzer::Impl :
const Entry& info, const Entry& info,
bool override) { bool override) {
if (!override) { if (!override) {
CHECK(!var_map_.count(var)); auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(it->second == info)
<< "var \'" << var << "\' already updated.";
}
} }
var_map_[var] = info; var_map_[var] = info;
} }
......
...@@ -7,8 +7,10 @@ ...@@ -7,8 +7,10 @@
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <algorithm>
#include "const_fold.h" #include "const_fold.h"
#include "pattern_match.h" #include "pattern_match.h"
#include "rewrite_simplify.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -39,86 +41,16 @@ using namespace ir; ...@@ -39,86 +41,16 @@ using namespace ir;
return RecursiveRewrite((ResExpr).Eval()); \ return RecursiveRewrite((ResExpr).Eval()); \
} }
// NOTE for developers: // NOTE for developers:
// //
// We mainly focus on index expression simplification. // We mainly focus on index expression simplification.
// Besides the RewriteSimplifier, some cases can be better // Besides the RewriteSimplifier, some cases can be better
// handled by CanonicalSimplifier. // handled by CanonicalSimplifier.
// //
class RewriteSimplifier::Impl : public IRMutator {
public:
explicit Impl(Analyzer* parent)
: parent_(parent) {}
void Update(const Var& var,
const Expr& info,
bool override) {
if (!override) {
CHECK(!var_map_.count(var));
}
var_map_[var] = info;
}
// Run simplification in post order // try to prove x equals val
Expr PostOrderSimplify(Expr expr, int max_iter = 2) { RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::
for (int i = 0; i < max_iter; ++i) { TryCompare(const Expr& x, int64_t val) {
Expr new_expr = this->Mutate(expr);
if (new_expr.same_as(expr)) return expr;
expr = new_expr;
}
return expr;
}
Expr Mutate_(const Add* op, const Expr& self) final;
Expr Mutate_(const Sub* op, const Expr& self) final;
Expr Mutate_(const Mul* op, const Expr& self) final;
Expr Mutate_(const Div* op, const Expr& self) final;
Expr Mutate_(const Mod* op, const Expr& self) final;
Expr Mutate_(const Min* op, const Expr& self) final;
Expr Mutate_(const Max* op, const Expr& self) final;
Expr Mutate_(const EQ* op, const Expr& self) final;
Expr Mutate_(const NE* op, const Expr& self) final;
Expr Mutate_(const LT* op, const Expr& self) final;
Expr Mutate_(const LE* op, const Expr& self) final;
Expr Mutate_(const GT* op, const Expr& self) final;
Expr Mutate_(const GE* op, const Expr& self) final;
Expr Mutate_(const And* op, const Expr& self) final;
Expr Mutate_(const Or* op, const Expr& self) final;
Expr Mutate_(const Not* op, const Expr& self) final;
Expr Mutate_(const Select* op, const Expr& self) final;
Expr Mutate_(const Ramp* op, const Expr& self) final;
private:
/*! \brief internal structure for comparison. */
enum CompareResult {
kUnknown,
kEQ,
kGT,
kLT,
kGE,
kLE,
kNE
};
// reference to the main analyzer
Analyzer* parent_;
// counter to record recursive rewrite depth.
int recur_depth_{0};
// internal variable map
std::unordered_map<Var, Expr, ExprHash, ExprEqual> var_map_;
// maximum number of recursion allowed during a single pass.
static const constexpr int kMaxRecurDepth = 5;
// Whether x >= val
bool CanProveGreaterEqual(const Expr& x, int64_t val) {
return parent_->CanProveGreaterEqual(x, val);
}
// Whether x == val
bool CanProveEqual(const Expr& x, int64_t val) {
// TODO(tqchen) refer back to super-analyzer.
return TryCompare(x, val) == kEQ;
}
// try to prove x equals val
CompareResult TryCompare(const Expr& x, int64_t val) {
Expr diff = Mutate(x); Expr diff = Mutate(x);
if (const auto* ptr = diff.as<IntImm>()) { if (const auto* ptr = diff.as<IntImm>()) {
if (ptr->value == val) { if (ptr->value == val) {
...@@ -149,24 +81,15 @@ class RewriteSimplifier::Impl : public IRMutator { ...@@ -149,24 +81,15 @@ class RewriteSimplifier::Impl : public IRMutator {
return kLE; return kLE;
} }
return kUnknown; return kUnknown;
} }
// Recursive rewrite x
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
Expr RecursiveRewrite(const Expr& x) {
if (recur_depth_ >= kMaxRecurDepth) return x;
++recur_depth_;
Expr res = Mutate(x);
--recur_depth_;
return res;
}
template<typename TA> void RewriteSimplifier::Impl::
PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) { Update(const Var& var, const Expr& info, bool override) {
return PConstWithTypeLike<TA>(pattern.derived(), 0); if (!override) {
CHECK(!var_map_.count(var));
} }
}; var_map_[var] = info;
}
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Add* op, const Expr& self) { Mutate_(const Add* op, const Expr& self) {
...@@ -1254,16 +1177,6 @@ Mutate_(const Or* op, const Expr& self) { ...@@ -1254,16 +1177,6 @@ Mutate_(const Or* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Ramp* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Ramp>();
if (is_zero(op->stride)) {
return Broadcast::make(op->base, op->lanes);
}
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const Select* op, const Expr& self) { Mutate_(const Select* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Select>(); op = ret.as<Select>();
...@@ -1275,13 +1188,30 @@ Mutate_(const Select* op, const Expr& self) { ...@@ -1275,13 +1188,30 @@ Mutate_(const Select* op, const Expr& self) {
} }
// Pattern var to match any expression // Pattern var to match any expression
PVar<Expr> x, y; PVar<Expr> x, y;
TVM_TRY_REWRITE(select(x, y, y), y); TVM_TRY_REWRITE(select(x, y, y), y);
return ret; return ret;
} }
Expr RewriteSimplifier::Impl::
Mutate_(const Call* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Call>();
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
return op->args[0];
}
return ret;
}
Expr RewriteSimplifier::operator()(const Expr& expr) { Expr RewriteSimplifier::operator()(const Expr& expr) {
return impl_->PostOrderSimplify(expr); // Run simplification in post order
Expr res = expr;
int max_iter = 2;
for (int i = 0; i < max_iter; ++i) {
Expr new_expr = impl_->Mutate(res);
if (new_expr.same_as(res)) return res;
res = new_expr;
}
return res;
} }
void RewriteSimplifier::Update(const Var& var, void RewriteSimplifier::Update(const Var& var,
...@@ -1290,7 +1220,6 @@ void RewriteSimplifier::Update(const Var& var, ...@@ -1290,7 +1220,6 @@ void RewriteSimplifier::Update(const Var& var,
impl_->Update(var, info, override); impl_->Update(var, info, override);
} }
RewriteSimplifier::RewriteSimplifier(Analyzer* parent) RewriteSimplifier::RewriteSimplifier(Analyzer* parent)
: impl_(new Impl(parent)) { : impl_(new Impl(parent)) {
} }
......
/*!
* Copyright (c) 2019 by Contributors
* \file rewrite_simplify.h
* \brief Rewrite-rule based simplification.
*/
#ifndef TVM_ARITHMETIC_REWRITE_SIMPLIFY_H_
#define TVM_ARITHMETIC_REWRITE_SIMPLIFY_H_
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include <unordered_map>
#include "const_fold.h"
#include "pattern_match.h"
namespace tvm {
namespace arith {
using namespace ir;
/*!
* \brief Rewrite-based simplifier.
*
* This class can be inheritated for other simplifiers.
*/
class RewriteSimplifier::Impl : public IRMutator {
public:
explicit Impl(Analyzer* parent)
: parent_(parent) {}
void Update(const Var& var, const Expr& info, bool override);
Expr Mutate_(const Add* op, const Expr& self) override;
Expr Mutate_(const Sub* op, const Expr& self) override;
Expr Mutate_(const Mul* op, const Expr& self) override;
Expr Mutate_(const Div* op, const Expr& self) override;
Expr Mutate_(const Mod* op, const Expr& self) override;
Expr Mutate_(const Min* op, const Expr& self) override;
Expr Mutate_(const Max* op, const Expr& self) override;
Expr Mutate_(const EQ* op, const Expr& self) override;
Expr Mutate_(const NE* op, const Expr& self) override;
Expr Mutate_(const LT* op, const Expr& self) override;
Expr Mutate_(const LE* op, const Expr& self) override;
Expr Mutate_(const GT* op, const Expr& self) override;
Expr Mutate_(const GE* op, const Expr& self) override;
Expr Mutate_(const And* op, const Expr& self) override;
Expr Mutate_(const Or* op, const Expr& self) override;
Expr Mutate_(const Not* op, const Expr& self) override;
Expr Mutate_(const Select* op, const Expr& self) override;
Expr Mutate_(const Call* op, const Expr& self) override;
protected:
/*! \brief internal structure for comparison. */
enum CompareResult {
kUnknown,
kEQ,
kGT,
kGE,
kLT,
kLE,
kNE
};
// reference to the main analyzer
Analyzer* parent_;
// counter to record recursive rewrite depth.
int recur_depth_{0};
// internal variable map
std::unordered_map<Var, Expr, ExprHash, ExprEqual> var_map_;
// maximum number of recursion allowed during a single pass.
static const constexpr int kMaxRecurDepth = 5;
/*!
* \brief try to compare x against val.
* \param x The expression to be evaluated.
* \param val The constant value.
* \return comparison result.
*/
CompareResult TryCompare(const Expr& x, int64_t val);
private:
// Whether x >= val
bool CanProveGreaterEqual(const Expr& x, int64_t val) {
return parent_->CanProveGreaterEqual(x, val);
}
// Whether x == val
bool CanProveEqual(const Expr& x, int64_t val) {
// TODO(tqchen) refer back to super-analyzer.
return TryCompare(x, val) == kEQ;
}
// Recursive rewrite x
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
Expr RecursiveRewrite(const Expr& x) {
if (recur_depth_ >= kMaxRecurDepth) return x;
++recur_depth_;
Expr res = Mutate(x);
--recur_depth_;
return res;
}
template<typename TA>
PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 0);
}
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_REWRITE_SIMPLIFY_H_
/*!
* Copyright (c) 2019 by Contributors
* \file stmt_simplify.cc
* \brief Statement simplifier based on analyzer
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <tvm/arithmetic.h>
#include "arithmetic/Simplify.h"
namespace tvm {
namespace arith {
// statement simplifier
using namespace ir;
class StmtSimplifier : public IRMutator {
public:
Stmt Mutate_(const For* op, const Stmt& s) final {
Var loop_var(op->loop_var.node_);
analyzer_.Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
return IRMutator::Mutate_(op, s);
}
// IfThenElse
Stmt Mutate_(const IfThenElse* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Stmt then_case, else_case;
{
ConstraintContext ctx(&analyzer_, condition);
then_case = this->Mutate(op->then_case);
}
if (op->else_case.defined()) {
ConstraintContext ctx(&analyzer_, Mutate(Not::make(condition)));
else_case = this->Mutate(op->else_case);
}
if (is_one(condition)) return then_case;
if (is_zero(condition)) {
if (else_case.defined()) {
return else_case;
}
return Evaluate::make(0);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
// AttrStmt
Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
if (!var_dom_.count(iv->var.get())) {
Range dom = Range::make_by_min_extent(0, op->value);
var_dom_[iv->var.get()] = dom;
analyzer_.Bind(iv->var, dom);
}
Stmt stmt = IRMutator::Mutate_(op, s);
return stmt;
} else {
return IRMutator::Mutate_(op, s);
}
}
// AssertStmt
Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message);
ConstraintContext ctx(&analyzer_, condition);
Stmt body = this->Mutate(op->body);
if (condition.same_as(op->condition) &&
message.same_as(op->message) &&
body.same_as(op->body)) {
return s;
} else {
return AssertStmt::make(condition, message, body);
}
}
protected:
Analyzer analyzer_;
// variable domain
std::unordered_map<const Variable*, Range> var_dom_;
};
class CanonicalStmtSimplifier : public StmtSimplifier {
public:
using StmtSimplifier::Mutate;
Expr Mutate(Expr expr) final {
return analyzer_.canonical_simplify(expr);
}
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
for (auto kv : vrange) {
analyzer_.Bind(kv.first, kv.second);
}
return Mutate(stmt);
}
};
} // namespace arith
namespace ir {
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::CanonicalStmtSimplifier().CanonicalSimplify(
stmt, vrange);
}
Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
arith::Analyzer analyzer;
for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second);
}
return analyzer.canonical_simplify(expr);
}
template<typename T>
T Simplify_(T a, Map<Var, Range> vrange) {
using namespace HalideIR::Internal;
Scope<Interval> rscope;
for (auto kv : vrange) {
Range r = kv.second;
rscope.push(
kv.first.get(),
Interval(r->min,
simplify(r->min + r->extent - make_const(r->min.type(), 1))));
}
return HalideIR::Internal::simplify(a, true, rscope);
}
Expr Simplify(Expr a, Map<Var, Range> vrange) {
// Simplify top level reduce.
if (const Reduce* r = a.as<Reduce>()) {
Array<Expr> new_source;
for (auto& e : r->source) {
new_source.push_back(Simplify_(e, vrange));
}
Expr new_condition = Simplify_(r->condition, vrange);
if (r->source.same_as(new_source) &&
r->condition.same_as(new_condition)) {
return a;
} else {
return Reduce::make(
r->combiner, new_source, r->axis, new_condition, r->value_index);
}
}
return Simplify_(a, vrange);
}
Stmt Simplify(Stmt a, Map<Var, Range> vrange) {
return Simplify_(a, vrange);
}
} // namespace ir
} // namespace tvm
...@@ -78,7 +78,7 @@ class ThreadAllreduceBuilder final : public IRMutator { ...@@ -78,7 +78,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
Expr Mutate_(const Load* op, const Expr& e) final { Expr Mutate_(const Load* op, const Expr& e) final {
auto it = load_remap_.find(op->buffer_var.get()); auto it = load_remap_.find(op->buffer_var.get());
if (it != load_remap_.end()) { if (it != load_remap_.end()) {
CHECK(is_zero(op->index)); CHECK(is_zero(op->index)) << e;
return it->second; return it->second;
} else { } else {
return IRMutator::Mutate_(op, e); return IRMutator::Mutate_(op, e);
......
import tvm
class CanonicalChecker:
def __init__(self):
self.analyzer = tvm.arith.Analyzer()
def verify(self, data, expected):
res = self.analyzer.canonical_simplify(data)
assert tvm.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected)
def test_mul_sum_simplify():
ck = CanonicalChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify(2 + (3 * x + z + y + 1) * 4 + x,
x * 13 + z * 4 + y * 4 +6)
ck.verify((x + y + x + y * 3) / 2, y * 2 + x)
ck.verify((x + y + x + y * 3) % 2, 0)
ck.verify(x * 3 - 4 * x + 1, 1 - x)
ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2)
def test_split_index_simplify():
ck = CanonicalChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify((x/3) *3 + x % 3, x)
ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x)
# split div const
ck.verify(((x % 16) / 2) * 2 / 4, (x % 16) / 4)
ck.verify((x % 2) / 8, 0)
ck.verify((x % 2) / 7, 0)
ck.verify(((x % 16) / 2) * 2 / 6, (x % 16) / 6)
# split mod const
ck.verify((x * 8) % 16, (x % 2) * 8)
ck.verify((x * 8) % 2, 0)
# simplify then fold
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000))
ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y)
# complex fold
ck.verify((z * 9 + y) / 2 * 2 + (z * 9 + y) % 2, z * 9 + y)
def test_div_simplify():
ck = CanonicalChecker()
x = tvm.var("x")
ck.verify((16+48*x)/16, x*3 + 1)
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x
ck.verify((17+48*x)/16, (x * 48 + 17) / 16)
# However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10))
ck.verify((17+48*x)/16, x * 3 + 1)
# Trying expressions that are not simplifiable for any values of the variables
ck.verify((17+47*x)/16, (x * 47 + 17) / 16)
def test_canonical_mixed():
ck = CanonicalChecker()
x = tvm.var("x")
z = tvm.const(3, "int32")
ck.verify(x / (z*z) - x / (z*z), 0)
ck.verify(x / (z+z) - x / (z+z), 0)
def test_reduce_combiner_simplify():
ck = CanonicalChecker()
dummy = tvm.var('dummy')
comm_reducer = tvm.comm_reducer
prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0))
sum_or_prod = comm_reducer(
lambda x, y: tvm.expr.Select(dummy < 0,
x + y, x*y),
lambda t0: tvm.expr.Select(dummy < 0,
tvm.const(0, t0), tvm.const(1, t0)))
sum_and_prod = comm_reducer(
lambda x, y: (x[0] + y[0],
x[1]*y[1]),
lambda t0, t1: (tvm.const(0, t0),
tvm.const(5, t0) - tvm.const(4, t0)))
some_reducer1 = comm_reducer(
lambda x, y: (x[0] + y[0],
x[0] + y[0] + x[1] + y[1],
x[0]*y[2] + y[0]*x[2],
x[1] + y[2],
4.0),
lambda t0, t1, t2, t3, t4: (tvm.const(0, t0),
tvm.const(1, t1),
tvm.const(2, t2),
tvm.const(3, t3),
tvm.const(4, t4)))
k = tvm.reduce_axis((0, 10), name="k")
A = tvm.placeholder((10,), name='A')
# Test that SimplifyCombiner makes use of vranges
ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, -4))
ck.verify(sum_or_prod(A[k], k), tvm.sum(A[k], k))
ck.analyzer.update(dummy, tvm.arith.ConstIntBound(5, 9), True)
ck.verify(sum_or_prod(A[k], k), prod(A[k], k))
ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, 100), True)
ck.verify(sum_and_prod((A[k], A[10-k]), k)[0], tvm.sum(A[k], k))
ck.verify(sum_and_prod((A[k], A[10-k]), k)[1], prod(A[10-k], k))
reference_simplified_sources = [[A[0]],
[A[0], A[1]],
[A[0], A[2]],
[A[0], A[1], A[2], A[3]],
[A[4]]]
for j in range(5):
# Here we use the j-th component of the result, so only it and the components it
# depends on are left.
simplified = ck.analyzer.canonical_simplify(
some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j])
# Check that the remaining components are the expected ones.
for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]):
assert tvm.ir_pass.Equal(lhs, rhs)
# Test that components with side effects are not removed
side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0)
ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0],
sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0],
tvm.sum(side_effect(A[k]), k))
def test_reduce_simplify():
ck = CanonicalChecker()
k = tvm.reduce_axis((0, 10), name="k")
j = tvm.reduce_axis((-5, 3), name="j")
A = tvm.placeholder((10,), name='A')
ck.verify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j]),
tvm.sum(k + j, [k, j]))
ck.verify(tvm.sum(A[3], []), A[3])
# The rule below is not typical, removed for now
ck.verify(tvm.sum(k / 10, k), tvm.sum(tvm.const(0, "int32"), k))
if __name__ == "__main__":
test_div_simplify()
test_reduce_simplify()
test_reduce_combiner_simplify()
test_mul_sum_simplify()
test_split_index_simplify()
test_canonical_mixed()
...@@ -21,7 +21,6 @@ def test_simplify(): ...@@ -21,7 +21,6 @@ def test_simplify():
assert zz.a == x and zz.b.value == 4 assert zz.a == x and zz.b.value == 4
n = tvm.var('n') n = tvm.var('n')
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % (-1)), tvm.const(0, "int32"))
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % 1), tvm.const(0, "int32")) assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % 1), tvm.const(0, "int32"))
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / 1), n) assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / 1), n)
tvm.ir_pass.CanonicalSimplify(n / (-1)) tvm.ir_pass.CanonicalSimplify(n / (-1))
...@@ -29,36 +28,16 @@ def test_simplify(): ...@@ -29,36 +28,16 @@ def test_simplify():
# assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)), # assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)),
# tvm.ir_pass.CanonicalSimplify(-n)) # tvm.ir_pass.CanonicalSimplify(-n))
def test_simplify_div():
x = tvm.var('x')
assert tvm.ir_pass.CanonicalSimplify((16+48*x)/16 - (1 + (x*3))).value == 0
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x
r = tvm.ir_pass.CanonicalSimplify((17+48*x)/16)
assert r.b.value == 16
assert tvm.ir_pass.CanonicalSimplify(r.a - (17 + 48*x)).value == 0
# However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
assert tvm.ir_pass.CanonicalSimplify((17+48*x)/16 - (1 + (x*3)), {x: tvm.Range(0,10)}).value == 0
# Trying expressions that are not simplifiable for any values of the variables
r = tvm.ir_pass.CanonicalSimplify((17+47*x)/16, {x: tvm.Range(0,10)})
assert r.b.value == 16
assert tvm.ir_pass.CanonicalSimplify(r.a - (17+47*x)).value == 0
r = tvm.ir_pass.CanonicalSimplify((8*x - 17)/8, {x : tvm.Range(4,10)})
assert tvm.ir_pass.CanonicalSimplify(r - (x-3)).value == 0
def test_simplify_mod(): def test_simplify_mod():
"""Not yet working, mock design"""
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
n = tvm.var('n') n = tvm.var('n')
j = tvm.var('j')
A = ib.pointer("float32", name="A") A = ib.pointer("float32", name="A")
with ib.for_range(0, 10, name="j") as j:
with ib.for_range(0, 16, name="i") as i: with ib.for_range(0, 16, name="i") as i:
A[i] = A[((n * 4 + j * 2) * 8 + i+1) % 16] A[i] = A[(j * 32 + i+1) % 16]
body = ib.get() body = ib.get()
stmt = tvm.ir_pass.CanonicalSimplify(body) stmt = tvm.ir_pass.CanonicalSimplify(body)
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16) diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16)
assert diff.value == 0 assert diff.value == 0
# if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16 # if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify( index = tvm.ir_pass.CanonicalSimplify(
...@@ -95,8 +74,8 @@ def test_modular(): ...@@ -95,8 +74,8 @@ def test_modular():
y: tvm.Range(i32_const(0), i32_const(2)), y: tvm.Range(i32_const(0), i32_const(2)),
x: tvm.Range(i32_const(0), i32_const(14))} x: tvm.Range(i32_const(0), i32_const(14))}
idx = ry * 16 + rx + y * 16 + x idx = ry * 16 + rx + y * 16 + x
z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap)
z2 = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap) z2 = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap)
z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap)
assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0 assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0
assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0 assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0
...@@ -117,10 +96,9 @@ def test_const_propagation(): ...@@ -117,10 +96,9 @@ def test_const_propagation():
if __name__ == "__main__": if __name__ == "__main__":
test_simplify_div()
test_simplify_mod()
test_modular() test_modular()
test_simplify() test_simplify()
test_mul() test_mul()
test_simplify_minmax() test_simplify_minmax()
test_const_propagation() test_const_propagation()
test_simplify_mod()
...@@ -35,109 +35,8 @@ def test_bound(): ...@@ -35,109 +35,8 @@ def test_bound():
ret = tvm.ir_pass.Simplify(m % 10, vrange) ret = tvm.ir_pass.Simplify(m % 10, vrange)
assert ret == m assert ret == m
def test_canonical():
x = tvm.var("x")
z = tvm.const(3, "int32")
ret = tvm.ir_pass.CanonicalSimplify(x / (z*z) - x / (z*z))
assert(tvm.ir_pass.Equal(ret, 0))
ret = tvm.ir_pass.CanonicalSimplify(x / (z+z) - x / (z+z))
assert(tvm.ir_pass.Equal(ret, 0))
#make sure terms are ordered based on their top operators (e.g., / always precedes %)
ret1 = tvm.ir_pass.CanonicalSimplify(x % 3 + x / 3)
ret2 = tvm.ir_pass.CanonicalSimplify(x / 3 + x % 3)
assert(tvm.ir_pass.Equal(ret1, ret2))
#when top operators match, compare string representation of terms
ret1 = tvm.ir_pass.CanonicalSimplify(x % 4 + x % 3)
ret2 = tvm.ir_pass.CanonicalSimplify(x % 3 + x % 4)
assert (tvm.ir_pass.Equal(ret1, ret2))
def test_simplify_combiner():
dummy = tvm.var('dummy')
prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0))
sum_or_prod = comm_reducer(lambda x, y: tvm.expr.Select(dummy < 0,
x + y, x*y),
lambda t0: tvm.expr.Select(dummy < 0,
tvm.const(0, t0), tvm.const(1, t0)))
sum_and_prod = comm_reducer(lambda x, y: (x[0] + y[0],
x[1]*y[1]),
lambda t0, t1: (tvm.const(0, t0),
tvm.const(5, t0) - tvm.const(4, t0)))
sum_and_prod2 = comm_reducer(lambda x, y: (x[0] + y[0],
x[1]*y[1] + 0*x[0] + y[0] - y[0]),
lambda t0, t1: (tvm.const(5, t0) - tvm.const(5, t0),
tvm.const(1, t1)))
some_reducer1 = comm_reducer(lambda x, y: (x[0] + y[0],
x[0] + y[0] + x[1] + y[1],
x[0]*y[2] + y[0]*x[2],
x[1] + y[2],
4.0),
lambda t0, t1, t2, t3, t4: (tvm.const(0, t0),
tvm.const(1, t1),
tvm.const(2, t2),
tvm.const(3, t3),
tvm.const(4, t4)))
k = tvm.reduce_axis((0, 10), name="k")
A = tvm.placeholder((10,), name='A')
# Test that SimplifyCombiner makes use of vranges
vrange = {dummy: tvm.Range(-10, -5)}
assert Equal(Simplify(sum_or_prod(A[k], k), vrange), tvm.sum(A[k], k))
vrange = {dummy: tvm.Range(5, 10)}
assert Equal(Simplify(sum_or_prod(A[k], k), vrange), prod(A[k], k))
assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k))
assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[1]), prod(A[10-k], k))
assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k))
assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[1]), prod(A[10-k], k))
reference_simplified_sources = [[A[0]],
[A[0], A[1]],
[A[0], A[2]],
[A[0], A[1], A[2], A[3]],
[A[4]]]
for j in range(5):
# Here we use the j-th component of the result, so only it and the components it
# depends on are left.
simplified = Simplify(some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j])
# Check that the remaining components are the expected ones.
for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]):
assert Equal(lhs, rhs)
# Test that components with side effects are not removed
side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0)
assert Equal(Simplify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0]),
sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
assert Equal(Simplify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0]),
tvm.sum(side_effect(A[k]), k))
def test_simplify_reduce():
k = tvm.reduce_axis((0, 10), name="k")
j = tvm.reduce_axis((-5, 3), name="j")
A = tvm.placeholder((10,), name='A')
assert Equal(Simplify(tvm.sum(k/10, k)), tvm.sum(tvm.const(0, "int32"), k))
assert Equal(Simplify(tvm.sum(A[3], [])), A[3])
assert Equal(Simplify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j])),
tvm.sum(k + j, [k, j]))
if __name__ == "__main__": if __name__ == "__main__":
test_bound() test_bound()
test_basic() test_basic()
test_simplify() test_simplify()
test_canonical()
test_simplify_combiner()
test_simplify_reduce()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment