Unverified Commit 1eb1dac4 by Tianqi Chen Committed by GitHub

[ARITH] Analyzer Infra, ConstIntBound, Modular (#2668)

parent 84590063
...@@ -9,14 +9,282 @@ ...@@ -9,14 +9,282 @@
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <memory> #include <memory>
#include <limits>
#include "expr.h" #include "expr.h"
namespace tvm { namespace tvm {
// forward declare Tensor
class Tensor; class Tensor;
/*! \brief namespace of arithmetic */ /*! \brief namespace of arithmetic */
namespace arith { namespace arith {
//-------------------------------------------------------
// Base integer analysis API.
//
// We have multiple types of analyzers to do relaxed
// integer set analysis(bound analysis, modulo) and
// equivalence checking and simplification.
//
// Importantly, each analyzer may need result from
// another analyzer.
//-------------------------------------------------------
// Forward declare Analyzer
class Analyzer;
/*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class ConstIntBound;
/*!
* \brief Constant integer up and lower bound(inclusive).
* Useful for value bound analysis.
*
* set = [min_value, max_value]
*/
class ConstIntBoundNode : public Node {
 public:
  /*! \brief Lower bound of the set (inclusive). */
  int64_t min_value;
  /*! \brief Upper bound of the set (inclusive). */
  int64_t max_value;
  /*! \brief Expose both bounds to the attribute visitor under their field names. */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("min_value", &min_value);
    v->Visit("max_value", &max_value);
  }
  /*!
   * \brief Create a ConstIntBound reference.
   * \param min_value The lower bound.
   * \param max_value The upper bound.
   * \return The created bound.
   */
  TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value);
  /*! \brief Number to represent +inf */
  static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
  /*!
   * \brief Number to represent -inf
   * \note We can make use of the fact that -kPosInf == kNegInf in the project.
   *       kNegInf is therefore one above the smallest int64_t value.
   */
  static const constexpr int64_t kNegInf = -kPosInf;
  static constexpr const char* _type_key = "arith.ConstIntBound";
  TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node);
};
TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode);
/*!
* \brief Analyzer to get constant integer bound over expression.
*/
class ConstIntBoundAnalyzer {
 public:
  /*!
   * \brief analyze the expr
   * \param expr The expression of interest.
   * \return the result of the analysis.
   */
  ConstIntBound operator()(const Expr& expr);
  /*!
   * \brief Update constant int bound information of var.
   *
   * \param var The variable of interest.
   * \param info The bound information.
   * \param override Whether we allow overriding existing information.
   */
  void Update(const Var& var,
              const ConstIntBound& info,
              bool override = false);
  /*!
   * \brief Bind variable to a range.
   *
   * \param var The variable.
   * \param range The range we bind to.
   */
  void Bind(const Var& var, const Range& range);
  // impl_ is owned through a raw pointer and released in the destructor,
  // so a member-wise copy would delete the same Impl twice.
  ConstIntBoundAnalyzer(const ConstIntBoundAnalyzer&) = delete;
  ConstIntBoundAnalyzer& operator=(const ConstIntBoundAnalyzer&) = delete;

 private:
  friend class Analyzer;
  friend class ConstraintContext;
  explicit ConstIntBoundAnalyzer(Analyzer* parent);
  ~ConstIntBoundAnalyzer();
  /*!
   * \brief Update the internal state to enter constraint.
   * \param constraint A constraint expression.
   *
   * \return an exit function that must be called to clean up the constraint;
   *         can be nullptr.
   */
  std::function<void()> EnterConstraint(const Expr& constraint);
  struct Entry;
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};
/*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class ModularSet;
/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { coeff * x + base | x in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
class ModularSetNode : public Node {
 public:
  /*! \brief linear co-efficient */
  int64_t coeff;
  /*! \brief The base */
  int64_t base;
  /*! \brief Expose coeff and base to the attribute visitor under their field names. */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("coeff", &coeff);
    v->Visit("base", &base);
  }
  /*!
   * \brief Create a ModularSet reference.
   * \param coeff The linear co-efficient.
   * \param base The base.
   * \return The created set.
   */
  TVM_DLL static ModularSet make(int64_t coeff, int64_t base);
  static constexpr const char* _type_key = "arith.ModularSet";
  TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node);
};
TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode);
/*!
* \brief Analyzer to get modular information over expression.
*/
class ModularSetAnalyzer {
 public:
  /*!
   * \brief analyze the expr
   * \param expr The expression of interest.
   * \return the result of the analysis.
   */
  ModularSet operator()(const Expr& expr);
  /*!
   * \brief Update modular set information of var.
   *
   * \param var The variable of interest.
   * \param info The modular set information.
   * \param override Whether we allow overriding existing information.
   */
  void Update(const Var& var,
              const ModularSet& info,
              bool override = false);
  // impl_ is owned through a raw pointer and released in the destructor,
  // so a member-wise copy would delete the same Impl twice.
  ModularSetAnalyzer(const ModularSetAnalyzer&) = delete;
  ModularSetAnalyzer& operator=(const ModularSetAnalyzer&) = delete;

 private:
  friend class Analyzer;
  friend class ConstraintContext;
  explicit ModularSetAnalyzer(Analyzer* parent);
  ~ModularSetAnalyzer();
  /*!
   * \brief Update the internal state to enter constraint.
   * \param constraint A constraint expression.
   *
   * \return an exit function that must be called to clean up the constraint;
   *         can be nullptr.
   */
  std::function<void()> EnterConstraint(const Expr& constraint);
  struct Entry;
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};
/*!
* \brief A RAII constraint context.
*
* \code
*
* Var x("x");
* arith::Analyzer analyzer;
* {
* arith::ConstraintContext cctx(&analyzer, x % 3 == 0);
* CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
* }
* // constraint no longer in effect.
* CHECK_NE(analyzer.modular_set(x)->coeff, 3);
*
* \endcode
*/
class ConstraintContext {
 public:
  /*!
   * \brief Construct a constraint context.
   * \param analyzer The analyzer.
   * \param constraint The constraint to be applied.
   */
  ConstraintContext(Analyzer* analyzer, const Expr& constraint) DMLC_THROW_EXCEPTION;
  /*!
   * \brief destructor, runs the stored recovery function.
   * \note NOTE(review): no copy control is declared here, so a copied
   *       context would run the same recovery function again on destruction.
   */
  ~ConstraintContext() DMLC_THROW_EXCEPTION {
    exit_();
  }
 private:
  /*! \brief function to be called in recovery */
  std::function<void()> exit_;
};
/*!
* \brief Analyzer that contains bunch of sub-analyzers.
*
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overridden.
*/
class Analyzer {
 public:
  /*! \brief sub-analyzer: const integer bound */
  ConstIntBoundAnalyzer const_int_bound;
  /*! \brief sub-analyzer: modular set */
  ModularSetAnalyzer modular_set;
  /*! \brief constructor */
  Analyzer();
  /*!
   * \brief Notify all the sub-analyzers that var
   *        is created and bound to expr.
   *
   *  Each var can only be bound once.
   *
   * \param var The variable.
   * \param expr The expression we bind to.
   */
  void Bind(const VarExpr& var, const Expr& expr);
  /*!
   * \brief Notify all the sub-analyzers that var
   *        is created and bound to a range.
   *
   *  Each var can only be bound once.
   *
   * \param var The variable.
   * \param range The range we bind to.
   */
  void Bind(const VarExpr& var, const Range& range);
  /*!
   * \brief Whether we can prove expr >= val.
   *
   *  Non-negative proof is very useful in integer analysis
   *  to lower divisions and mods given difference in trunc and ceil mode.
   *
   * \param expr The expression.
   * \param lower_bound The lower bound.
   * \return Whether we can prove it.
   *
   * \note Analyzer will call into sub-analyzers to get the result.
   */
  bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
};
//-----------------------------------------------
// Integer set abstraction API.
//
// This is a API build on top of the base
// integer analysis API to provide set analysis.
//------------------------------------------------
/*! /*!
* \brief Sign of an expression or set. * \brief Sign of an expression or set.
*/ */
...@@ -119,42 +387,6 @@ class IntSet : public NodeRef { ...@@ -119,42 +387,6 @@ class IntSet : public NodeRef {
}; };
/*! /*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { coeff * x + base | x in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
struct ModularEntry {
  /*! \brief linear co-efficient, defaults to 1 */
  int coeff{1};
  /*! \brief The base, defaults to 0 */
  int base{0};
  /*!
   * \brief Entry that covers every integer.
   * \return entry with coeff = 1 and base = 0, i.e. the set { 0 + x }.
   */
  static ModularEntry everything() {
    // The default member initializers already describe 0 + 1 * x.
    return ModularEntry();
  }
  /*!
   * \brief Combine two modular entries by addition.
   * \param a The left operand.
   * \param b The right operand.
   * \return The modular entry of the sum.
   */
  static ModularEntry Add(const ModularEntry& a,
                          const ModularEntry& b);
};
/*!
* \brief Base class of all IntSet containers. * \brief Base class of all IntSet containers.
*/ */
struct IntSetNode : public Node { struct IntSetNode : public Node {
...@@ -300,24 +532,6 @@ IntSet DeduceBound(Expr v, Expr cond, ...@@ -300,24 +532,6 @@ IntSet DeduceBound(Expr v, Expr cond,
*/ */
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides); Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);
/*!
* \brief Evaluate the expression with modular analysis
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return The ModularEntry covering all possible value of e.
*/
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map);
/*!
* \brief Same as EvalModular, used by front-end.
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return A ModularSet covering all possible value of e.
*/
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map);
// implementation // implementation
inline const IntSetNode* IntSet::operator->() const { inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get()); return static_cast<const IntSetNode*>(node_.get());
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
using HalideIR::Internal::BaseExprNode;
using HalideIR::Internal::ExprNode; using HalideIR::Internal::ExprNode;
using HalideIR::Internal::StmtNode; using HalideIR::Internal::StmtNode;
using HalideIR::Internal::IRNodeType; using HalideIR::Internal::IRNodeType;
......
...@@ -33,9 +33,162 @@ class StrideSet(IntSet): ...@@ -33,9 +33,162 @@ class StrideSet(IntSet):
"""Represent set of strided integers""" """Represent set of strided integers"""
@register_node @register_node("arith.ModularSet")
class ModularSet(IntSet): class ModularSet(NodeBase):
"""Represent range of (coeff * x + base) for x in Z """ """Represent range of (coeff * x + base) for x in Z """
def __init__(self, coeff, base):
self.__init_handle_by_constructor__(
_make_ModularSet, coeff, base)
@register_node("arith.ConstIntBound")
class ConstIntBound(NodeBase):
    """Constant integer bound of an expression.

    Parameters
    ----------
    min_value : int
        The minimum value of the bound.

    max_value : int
        The maximum value of the bound.
    """
    # Largest signed 64-bit value stands for +inf; its negation stands for -inf.
    POS_INF = 2 ** 63 - 1
    NEG_INF = -POS_INF

    def __init__(self, min_value, max_value):
        self.__init_handle_by_constructor__(_make_ConstIntBound, min_value, max_value)
class ConstraintScope:
    """Context manager that applies a constraint for the duration of a block.

    Parameters
    ----------
    fenter : function
        Called with no arguments on entry; its return value is the
        callable that is invoked on exit.

    Note
    ----
    Do not create object directly, use Analyzer.constraint_scope
    """
    def __init__(self, fenter):
        self._fenter, self._fexit = fenter, None

    def __enter__(self):
        # Entering yields the callback that undoes the constraint.
        exit_callback = self._fenter()
        self._fexit = exit_callback

    def __exit__(self, ptype, value, trace):
        self._fexit()
class Analyzer:
    """Integer arithmetic analyzer

    This is a stateful analyzer class that can
    be used to perform various symbolic integer analysis.
    """
    def __init__(self):
        # _mod maps an entry-point name to the callable bound to one analyzer instance.
        _mod = _CreateAnalyzer()
        self._const_int_bound = _mod("const_int_bound")
        self._const_int_bound_update = _mod("const_int_bound_update")
        self._bind = _mod("bind")
        self._modular_set = _mod("modular_set")
        self._enter_constraint_context = _mod("enter_constraint_context")

    def const_int_bound(self, expr):
        """Find constant integer bound for expr.

        Parameters
        ----------
        expr : tvm.Expr
            The expression.

        Returns
        -------
        bound : ConstIntBound
            The result bound
        """
        return self._const_int_bound(expr)

    def modular_set(self, expr):
        """Find a modular set that expr belongs to.

        Parameters
        ----------
        expr : tvm.Expr
            The expression.

        Returns
        -------
        result : ModularSet
            The result.
        """
        return self._modular_set(expr)

    def bind(self, var, expr):
        """Bind a variable to the expression.

        Parameters
        ----------
        var : tvm.Var
            The variable.

        expr : tvm.Expr
            The expression.
        """
        return self._bind(var, expr)

    def constraint_scope(self, constraint):
        """Create a constraint scope.

        Parameters
        ----------
        constraint : tvm.Expr
            The constraint expression.

        Returns
        -------
        scope : ConstraintScope
            The constraint scope

        Examples
        --------
        .. code-block:: python

          x = tvm.var("x")
          analyzer = tvm.arith.Analyzer()
          with analyzer.constraint_scope(x % 3 == 0):
              # constraint in effect
              assert analyzer.modular_set(x).coeff == 3
          # constraint no longer in effect
          assert analyzer.modular_set(x).coeff != 3
        """
        # Entering is deferred until the with-block actually starts.
        def _fenter():
            return self._enter_constraint_context(constraint)
        return ConstraintScope(_fenter)

    def update(self, var, info, override=False):
        """Update information about var

        Only ConstIntBound information is handled; any other type
        raises TypeError.

        Parameters
        ----------
        var : tvm.Var
            The variable.

        info : tvm.NodeBase
            Related information.

        override : bool
            Whether allow override.
        """
        if isinstance(info, ConstIntBound):
            self._const_int_bound_update(var, info, override)
        else:
            raise TypeError(
                "Do not know how to handle type {}".format(type(info)))
_init_api("tvm.arith") _init_api("tvm.arith")
...@@ -26,11 +26,6 @@ TVM_REGISTER_API("arith.intset_interval") ...@@ -26,11 +26,6 @@ TVM_REGISTER_API("arith.intset_interval")
*ret = IntSet::interval(args[0], args[1]); *ret = IntSet::interval(args[0], args[1]);
}); });
// Front-end entry for modular evaluation of args[0].
// The variable map passed to EvalModular is always empty.
TVM_REGISTER_API("arith.EvalModular")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = EvalModular(args[0], Map<Var, IntSet>());
  });
TVM_REGISTER_API("arith.DetectLinearEquation") TVM_REGISTER_API("arith.DetectLinearEquation")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DetectLinearEquation(args[0], args[1]); *ret = DetectLinearEquation(args[0], args[1]);
...@@ -75,5 +70,56 @@ TVM_REGISTER_API("_IntSetIsEverything") ...@@ -75,5 +70,56 @@ TVM_REGISTER_API("_IntSetIsEverything")
*ret = args[0].operator IntSet().is_everything(); *ret = args[0].operator IntSet().is_everything();
}); });
// Constructor hook: args[0] is min_value, args[1] is max_value.
TVM_REGISTER_API("arith._make_ConstIntBound")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = ConstIntBoundNode::make(args[0], args[1]);
  });
// Constructor hook: args[0] is coeff, args[1] is base.
TVM_REGISTER_API("arith._make_ModularSet")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = ModularSetNode::make(args[0], args[1]);
  });
// Creates one Analyzer and returns a lookup function that maps an
// entry-point name to a PackedFunc operating on that analyzer.
// Every returned function holds a shared_ptr copy of the analyzer.
TVM_REGISTER_API("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    using runtime::PackedFunc;
    using runtime::TypedPackedFunc;
    auto self = std::make_shared<Analyzer>();
    auto f = [self](std::string name) -> PackedFunc {
      if (name == "const_int_bound") {
        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
            *ret = self->const_int_bound(args[0]);
          });
      } else if (name == "modular_set") {
        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
            *ret = self->modular_set(args[0]);
          });
      } else if (name == "const_int_bound_update") {
        // args: var, bound info, override flag.
        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
            self->const_int_bound.Update(args[0], args[1], args[2]);
          });
      } else if (name == "bind") {
        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
            // Pick the Range or Expr overload of Bind from the node type of args[1].
            auto& sptr = args[1].node_sptr();
            if (sptr->is_type<Range::ContainerType>()) {
              self->Bind(args[0], args[1].operator Range());
            } else {
              self->Bind(args[0], args[1].operator Expr());
            }
          });
      } else if (name == "enter_constraint_context") {
        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
            // The context is kept alive by the returned exit function;
            // calling that function drops this reference to the context.
            auto ctx = std::make_shared<ConstraintContext>(self.get(), args[0]);
            auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
              ctx.reset();
            };
            *ret = PackedFunc(fexit);
          });
      }
      // Unknown names yield a default-constructed PackedFunc.
      return PackedFunc();
    };
    *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
});
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/arithmetic/analyzer.cc
*/
#include <tvm/arithmetic.h>
namespace tvm {
namespace arith {
// Each sub-analyzer is constructed with a pointer back to this Analyzer.
Analyzer::Analyzer()
    : const_int_bound(this),
      modular_set(this) {
}
void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
Var var(v.node_);
this->const_int_bound.Update(var, this->const_int_bound(expr));
this->modular_set.Update(var, this->modular_set(expr));
}
// Bind v to a range. Only the const-int-bound analyzer is informed;
// the modular-set analyzer receives nothing for a range binding.
void Analyzer::Bind(const VarExpr& v, const Range& range) {
  Var var(v.node_);
  this->const_int_bound.Bind(var, range);
  // skip modular_set
}
// Enter the constraint in every sub-analyzer and remember how to leave it.
ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) {
  // Either sub-analyzer may return an empty function when it has nothing to undo.
  std::function<void()> leave_bound =
      analyzer->const_int_bound.EnterConstraint(constraint);
  std::function<void()> leave_modular =
      analyzer->modular_set.EnterConstraint(constraint);
  // Recovery runs in the reverse order of entering.
  exit_ = [leave_bound, leave_modular]() {
    if (leave_modular) leave_modular();
    if (leave_bound) leave_bound();
  };
}
bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
auto bd = this->const_int_bound(expr);
if (bd->min_value >= lower_bound) return true;
return false;
}
} // namespace arith
} // namespace tvm
/*!
* Copyright (c) 2019 by Contributors
* \file int_op_overflow.h
* \brief Utility functions to detect if an integer op will overflow.
*/
#ifndef TVM_ARITHMETIC_INT_OP_OVERFLOW_H_
#define TVM_ARITHMETIC_INT_OP_OVERFLOW_H_
#include <limits>
namespace tvm {
namespace arith {
/*!
 * \brief Check if an integer op with operand x, y will overflow.
 *
 *  This primary template reports no overflow for any operator;
 *  operator-specific checks are provided as explicit specializations.
 *
 * \param x The left operand.
 * \param y The right operand.
 * \param min_value The minimum value of the domain.
 * \param max_value The maximum value of the domain.
 * \return Whether overflow can happen.
 * \tparam Op The integer operator.
 */
template<typename Op>
inline bool WillOverflow(int64_t x,
                         int64_t y,
                         int64_t min_value,
                         int64_t max_value) {
  return false;
}
/*!
 * \brief Whether x + y leaves [min_value, max_value].
 *
 *  The comparison is rearranged so that the sum itself is never computed.
 *
 * \note Declared inline: an explicit specialization of a function template
 *       is not implicitly inline, and this one is defined in a header.
 */
template<>
inline bool WillOverflow<ir::Add>(int64_t x,
                                  int64_t y,
                                  int64_t min_value,
                                  int64_t max_value) {
  if ((y > 0) && (x > max_value - y)) return true;
  if ((y < 0) && (x < min_value - y)) return true;
  return false;
}
/*!
 * \brief Whether x - y leaves [min_value, max_value].
 *
 *  The comparison is rearranged so that the difference itself is never computed.
 *
 * \note Declared inline: an explicit specialization of a function template
 *       is not implicitly inline, and this one is defined in a header.
 */
template<>
inline bool WillOverflow<ir::Sub>(int64_t x,
                                  int64_t y,
                                  int64_t min_value,
                                  int64_t max_value) {
  if ((y > 0) && (x < min_value + y)) return true;
  if ((y < 0) && (x > max_value + y)) return true;
  return false;
}
/*!
 * \brief Whether x * y leaves [min_value, max_value].
 *
 *  The bounds are divided by y instead of multiplying x by y, so the
 *  product itself is never computed.
 *
 * \note Declared inline: an explicit specialization of a function template
 *       is not implicitly inline, and this one is defined in a header.
 */
template<>
inline bool WillOverflow<ir::Mul>(int64_t x,
                                  int64_t y,
                                  int64_t min_value,
                                  int64_t max_value) {
  if (y == 0) return false;
  if (y > 0) {
    if (x < min_value / y) return true;
    if (x > max_value / y) return true;
  } else if (y == -1) {
    // Handled apart from other negative y: dividing a bound by -1 overflows
    // when that bound is the smallest int64_t, so negate x instead.
    if (x == std::numeric_limits<int64_t>::min()) return true;
    const int64_t negated = -x;
    if (negated < min_value) return true;
    if (negated > max_value) return true;
  } else {
    // Dividing by a negative y flips the direction of both comparisons.
    if (x > min_value / y) return true;
    if (x < max_value / y) return true;
  }
  return false;
}
/*!
 * \brief Modulo is only reported as overflowing when the divisor is zero.
 *
 * \note Declared inline: an explicit specialization of a function template
 *       is not implicitly inline, and this one is defined in a header.
 */
template<>
inline bool WillOverflow<ir::Mod>(int64_t x,
                                  int64_t y,
                                  int64_t min_value,
                                  int64_t max_value) {
  return y == 0;
}
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_INT_OP_OVERFLOW_H_
...@@ -54,23 +54,6 @@ struct StrideSet : public IntSetNode { ...@@ -54,23 +54,6 @@ struct StrideSet : public IntSetNode {
TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode); TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
}; };
/*!
* \brief Set represented by range of ModularEntry.
* Used for front-end modular analysis.
*/
struct ModularSet : public IntSetNode {
  /*! \brief Internal modular entry */
  ModularEntry e;
  /*! \brief Expose the entry's fields to the attribute visitor as "base" and "coeff". */
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("base", &(e.base));
    v->Visit("coeff", &(e.coeff));
  }
  static constexpr const char* _type_key = "ModularSet";
  TVM_DECLARE_NODE_TYPE_INFO(ModularSet, IntSetNode);
};
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
......
/*!
* Copyright (c) 2017 by Contributors
* \file modular.cc
* \brief Modular analysis
*/
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h>
#include <tvm/arithmetic.h>
#include <limits>
#include "int_set_internal.h"
namespace tvm {
namespace arith {
using namespace ir;
/*!
 * \brief Expression visitor that computes a ModularEntry
 *        { coeff * x + base | x in Z } covering every value of an expression.
 *
 *  Expressions without a dedicated rule evaluate to ModularEntry::everything().
 */
class ModularEvaluator
    : public ExprFunctor<ModularEntry(const Expr&)> {
 public:
  /*!
   * \brief Constructor.
   * \param mod_map Known modular entries of variables. It is stored by
   *        reference, so it must outlive the evaluator.
   */
  explicit ModularEvaluator(
      const std::unordered_map<
      const Variable*, ModularEntry>& mod_map)
      : mod_map_(mod_map) {
  }
  /*!
   * \brief Evaluate an expression.
   * \param e The expression.
   * \return The modular entry covering e.
   */
  ModularEntry Eval(const Expr& e) {
    return VisitExpr(e);
  }
  // default
  ModularEntry VisitExprDefault_(const Node*) final {
    return ModularEntry::everything();
  }
  // override combination rules.
  ModularEntry VisitExpr_(const IntImm* op) final {
    // A constant is { 0 * x + value }. The entry stores int, so values
    // outside the int range (on either side) fall back to everything()
    // instead of being truncated by the cast.
    if (op->value >= std::numeric_limits<int>::min() &&
        op->value < std::numeric_limits<int>::max()) {
      ModularEntry ret;
      ret.base = static_cast<int>(op->value);
      ret.coeff = 0;
      return ret;
    } else {
      return ModularEntry::everything();
    }
  }
  ModularEntry VisitExpr_(const UIntImm* op) final {
    if (op->value < static_cast<uint64_t>(
            std::numeric_limits<int>::max())) {
      ModularEntry ret;
      ret.base = static_cast<int>(op->value);
      ret.coeff = 0;
      return ret;
    } else {
      return ModularEntry::everything();
    }
  }
  ModularEntry VisitExpr_(const Variable* op) final {
    // Variables without a recorded entry can take any value.
    auto it = mod_map_.find(op);
    if (it != mod_map_.end()) {
      return it->second;
    } else {
      return ModularEntry::everything();
    }
  }
  ModularEntry VisitExpr_(const Add* op) final {
    ModularEntry a = Eval(op->a);
    ModularEntry b = Eval(op->b);
    ModularEntry ret;
    ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
    ret.base = BaseSimplify(a.base + b.base, ret.coeff);
    return ret;
  }
  ModularEntry VisitExpr_(const Sub* op) final {
    ModularEntry a = Eval(op->a);
    ModularEntry b = Eval(op->b);
    ModularEntry ret;
    ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
    ret.base = BaseSimplify(a.base - b.base, ret.coeff);
    return ret;
  }
  ModularEntry VisitExpr_(const Mul* op) final {
    ModularEntry a = Eval(op->a);
    ModularEntry b = Eval(op->b);
    // Simplification rule, x, y, z are in Z
    // (p x + n) (q y + m)
    // -> pq xy + pm x + qn y + mn
    // -> pq z + pm x + qn y + mn
    int pq = a.coeff * b.coeff;
    int pm = a.coeff * b.base;
    int qn = a.base * b.coeff;
    ModularEntry ret;
    ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
    ret.base = BaseSimplify(a.base * b.base, ret.coeff);
    return ret;
  }
  ModularEntry VisitExpr_(const Div* op) final {
    // a c x / c -> a x
    // We cannot do cases where offset is non-zero
    // because of different integer rounding in pos/neg
    ModularEntry a = Eval(op->a);
    ModularEntry b = Eval(op->b);
    if (b.coeff == 0 &&
        a.base == 0) {
      CHECK_NE(b.base, 0);
      if (a.coeff % b.base == 0) {
        ModularEntry ret;
        ret.coeff = a.coeff / b.base;
        ret.base = 0;
        return ret;
      }
    }
    return ModularEntry::everything();
  }

 private:
  /*! \brief Known modular entries of variables (not owned). */
  const std::unordered_map<
    const Variable*, ModularEntry>& mod_map_;
  friend struct ModularEntry;
  // simplify the base by putting it in range.
  static int BaseSimplify(int base, int coeff) {
    if (coeff == 0) return base;
    base = base % coeff;
    if (base < 0) base += coeff;
    return base;
  }
  // GCD that treats 0 as the identity: gcd(a, 0) == |a|.
  // Signs are dropped first because products with negative constants
  // (see the Mul rule) legitimately produce negative coefficients.
  static int ZeroAwareGCD(int a, int b) {
    if (a < 0) a = -a;
    if (b < 0) b = -b;
    if (a < b) std::swap(a, b);
    if (b == 0) return a;
    // perform GCD (greatest common divisor)
    // ax + by = gcd(a, b) z if a != 0, b != 0
    while (a % b != 0) {
      a = a % b;
      std::swap(a, b);
    }
    return b;
  }
};
// Sum of two modular entries: the coefficient is the zero-aware GCD of
// both coefficients and the base is the summed base reduced by it.
ModularEntry ModularEntry::Add(const ModularEntry& a,
                               const ModularEntry& b) {
  const int gcd = ModularEvaluator::ZeroAwareGCD(a.coeff, b.coeff);
  ModularEntry sum;
  sum.coeff = gcd;
  sum.base = ModularEvaluator::BaseSimplify(a.base + b.base, gcd);
  return sum;
}
// Evaluate e with a fresh ModularEvaluator seeded with mod_map.
ModularEntry EvalModular(
    const Expr& e,
    const std::unordered_map<const Variable*, ModularEntry>& mod_map) {
  ModularEvaluator evaluator(mod_map);
  return evaluator(e);
}
// Front-end variant: unwraps the ModularSet of every variable, evaluates e,
// and wraps the resulting entry in a new ModularSet node.
IntSet EvalModular(const Expr& e,
                   const Map<Var, IntSet>& mod_map) {
  std::unordered_map<const Variable*, ModularEntry> entries;
  for (const auto& kv : mod_map) {
    // Every value in the map must be a ModularSet.
    const ModularSet* mset = kv.second.as<ModularSet>();
    CHECK(mset) << "Need to pass ModularSet for Modular Analysis";
    entries[kv.first.get()] = mset->e;
  }
  NodePtr<ModularSet> result = make_node<ModularSet>();
  result->e = ModularEvaluator(entries)(e);
  return IntSet(result);
}
} // namespace arith
} // namespace tvm
/*!
* Copyright (c) 2019 by Contributors
* \file modular_set.cc
* \brief Modular set analysis
*/
#include <tvm/arithmetic.h>
#include <tvm/ir_operator.h>
#include <tvm/ir_functor_ext.h>
#include <limits>
#include "pattern_match.h"
namespace tvm {
namespace arith {
using namespace ir;
TVM_REGISTER_NODE_TYPE(ModularSetNode);
// Allocate a ModularSetNode, fill in both fields and wrap it in a reference.
ModularSet ModularSetNode::make(int64_t coeff, int64_t base) {
  auto n = make_node<ModularSetNode>();
  n->base = base;
  n->coeff = coeff;
  return ModularSet(n);
}
// Printer: renders a node as "ModularSet(coeff=<coeff>, base=<base>)".
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ModularSetNode>([](const ModularSetNode *op, IRPrinter *p) {
    p->stream << "ModularSet("
              << "coeff=" << op->coeff << ", base="
              << op->base << ')';
  });
// internal entry for modular set: { coeff * x + base | x in Z }.
// Defaults to coeff = 1, base = 0.
struct ModularSetAnalyzer::Entry {
  int64_t coeff{1};
  int64_t base{0};
  // A zero coefficient leaves only the single value `base`.
  bool is_const() const {
    return coeff == 0;
  }
};
/*!
 * \brief Implementation of the modular set analysis.
 *
 *  Visits an expression and returns an Entry { coeff * x + base | x in Z }
 *  that covers every value the expression can take. Expressions without a
 *  dedicated rule evaluate to Everything().
 */
class ModularSetAnalyzer::Impl :
      public ExprFunctor<ModularSetAnalyzer::Entry(const Expr&)> {
 public:
  /*!
   * \brief Constructor.
   * \param parent The parent analyzer, used to query other sub-analyzers.
   */
  explicit Impl(Analyzer* parent)
      : parent_(parent) {}

  /*!
   * \brief Record the modular set of a variable.
   * \param var The variable.
   * \param info The modular set information.
   * \param override Whether an existing record may be replaced; when false,
   *        a CHECK fails if var already has a record.
   */
  void Update(const Var& var,
              const ModularSet& info,
              bool override) {
    if (!override) {
      CHECK(!var_map_.count(var));
    }
    Entry e;
    e.coeff = info->coeff;
    e.base = info->base;
    var_map_[var] = e;
  }

  /*!
   * \brief Detect useful constraints and use them in the analysis scope.
   * \param constraint The constraint expression.
   * \return The recovery function, or nullptr when the constraint does not
   *         have the form (var % coeff) == base.
   */
  std::function<void()> EnterConstraint(const Expr& constraint) {
    PVar<Var> var;
    PVar<Integer> coeff, base;
    // pattern match interesting constraints
    if (((var % coeff) == base).Match(constraint)) {
      Entry entry;
      entry.coeff = coeff.Eval()->value;
      entry.base = base.Eval()->value;
      return UpdateByIntersect(var.Eval(), entry);
    }
    return nullptr;
  }

  // Override visitor behaviors
  Entry VisitExprDefault_(const Node* op) final {
    return Everything();
  }

  Entry VisitExpr_(const Cast* op) final {
    // The cast itself is ignored; the operand's set is reported.
    return VisitExpr(op->value);
  }

  Entry VisitExpr_(const IntImm* op) final {
    // A constant is the set { 0 * x + value }.
    Entry ret;
    ret.base = op->value;
    ret.coeff = 0;
    return ret;
  }

  Entry VisitExpr_(const UIntImm* op) final {
    // Only values below the int64_t maximum are kept as constants.
    if (op->value < static_cast<uint64_t>(
            std::numeric_limits<int64_t>::max())) {
      Entry ret;
      // The range check above is 64-bit, so the cast must be 64-bit as well;
      // narrowing to int would truncate large constants.
      ret.base = static_cast<int64_t>(op->value);
      ret.coeff = 0;
      return ret;
    } else {
      return Everything();
    }
  }

  Entry VisitExpr_(const Add* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    Entry ret;
    ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
    ret.base = BaseSimplify(a.base + b.base, ret.coeff);
    return ret;
  }

  Entry VisitExpr_(const Sub* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    Entry ret;
    ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
    ret.base = BaseSimplify(a.base - b.base, ret.coeff);
    return ret;
  }

  Entry VisitExpr_(const Mul* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    // Simplification rule, x, y, z are in Z
    // (p x + n) (q y + m)
    // -> pq xy + pm x + qn y + mn
    // -> pq z + pm x + qn y + mn
    int64_t pq = a.coeff * b.coeff;
    int64_t pm = a.coeff * b.base;
    int64_t qn = a.base * b.coeff;
    Entry ret;
    ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
    ret.base = BaseSimplify(a.base * b.base, ret.coeff);
    return ret;
  }

  /*!
   * \brief Modular set of lhs / val for a constant, non-zero val.
   * \param lhs The dividend expression.
   * \param val The constant divisor.
   * \param round_down Whether the division is known to round down.
   * \return The resulting set, or Everything() when no rule applies.
   */
  Entry DivByConst(const Expr& lhs,
                   int64_t val,
                   bool round_down) {
    Entry a = VisitExpr(lhs);
    CHECK_NE(val, 0);
    if (a.coeff % val == 0) {
      Entry ret;
      if (a.base == 0) {
        // a c x / c -> a x
        ret.coeff = std::abs(a.coeff / val);
        ret.base = 0;
        return ret;
      }
      // positive division have a clear rounding mode.
      // Only handle case where we clearly know we need to round down.
      if (a.base > 0 && val > 0 &&
          (round_down || parent_->CanProveGreaterEqual(lhs, 0))) {
        ret.coeff = a.coeff / val;
        ret.base = a.base / val;
        return ret;
      }
    }
    return Everything();
  }

  Entry VisitExpr_(const Div* op) final {
    // Only division by a constant is handled.
    Entry b = VisitExpr(op->b);
    if (b.is_const()) {
      return DivByConst(op->a, b.base, false);
    }
    return Everything();
  }

  Entry VisitExpr_(const Min* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    return Union(a, b);
  }

  Entry VisitExpr_(const Max* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    return Union(a, b);
  }

  Entry VisitExpr_(const Select* op) final {
    // The condition is not inspected; either branch may be the result.
    Entry a = VisitExpr(op->true_value);
    Entry b = VisitExpr(op->false_value);
    return Union(a, b);
  }

  Entry VisitExpr_(const Call* op) final {
    // only special handle >> which can be
    // used for index calculation.
    if (op->is_intrinsic(Call::shift_right)) {
      return VisitRightShift(op);
    } else {
      return Everything();
    }
  }

  Entry VisitExpr_(const Variable* op) final {
    // Variables without a recorded entry can take any value.
    Var v = GetRef<Var>(op);
    auto it = var_map_.find(v);
    if (it != var_map_.end()) {
      return it->second;
    } else {
      return Everything();
    }
  }

  /*!
   * \brief Handle args[0] >> args[1] as a rounding-down division by 2^args[1].
   * \param op The shift_right call.
   * \return The resulting set, or Everything() when the shift amount is not
   *         a constant in [0, 63).
   */
  Entry VisitRightShift(const Call* op) {
    Entry b = VisitExpr(op->args[1]);
    // a c x / c -> a x
    // The divisor is built with a 64-bit shift, and only for shift counts
    // whose power of two is representable as a positive int64_t.
    if (b.is_const() && b.base >= 0 && b.base < 63) {
      return DivByConst(op->args[0], static_cast<int64_t>(1) << b.base, true);
    }
    return Everything();
  }

 private:
  /*! \brief pointer to parent. */
  Analyzer* parent_{nullptr};
  // internal variable map
  std::unordered_map<Var, Entry, ExprHash, ExprEqual> var_map_;
  /*!
   * \brief Update var by intersecting entry with var's current set.
   * \param var The variable.
   * \param entry The entry to be updated.
   * \return The recovery function of the scope.
   */
  std::function<void()> UpdateByIntersect(const Var& var, Entry entry) {
    Entry old = Everything();
    auto it = var_map_.find(var);
    if (it != var_map_.end()) {
      old = it->second;
    }
    var_map_[var] = Intersect(old, entry);
    // recover function: writes the previous entry back.
    return [this, old, var]() {
      var_map_[var] = old;
    };
  }
  /*!
   * \brief Create union of two sets.
   * \param a The left operand.
   * \param b the right operand.
   * \return A set covering both operands.
   */
  static Entry Union(Entry a, Entry b) {
    // {ax + y} \cup {bz + h} => {gcd(a, b) x + {y or h}}
    int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
    if (coeff == 0) {
      // Both operands are constants.
      if (a.base == b.base) return a;
      return Everything();
    }
    int64_t base0 = a.base % coeff;
    int64_t base1 = b.base % coeff;
    Entry ret;
    if (base0 == base1) {
      ret.coeff = coeff;
      ret.base = base0;
      return ret;
    } else {
      // Different residues: fall back to a common divisor of both
      // residues and the coefficient, with base 0.
      ret.coeff = ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff);
      ret.base = 0;
      return ret;
    }
  }
  /*!
   * \brief Create intersect of two sets.
   * \param a The left operand.
   * \param b the right operand.
   * \return One of the two operands (see the rule below).
   */
  static Entry Intersect(Entry a, Entry b) {
    // simple rule for now: pick higher constraints.
    // TODO(team-team): Use extended euclidean algorithm.
    if (a.coeff == 0) return a;
    if (b.coeff == 0) return b;
    if (a.coeff >= b.coeff) return a;
    return b;
  }
  /*!
   * \brief Simplify base so that it is in [0, coeff) when coeff != 0.
   * \param base The base value.
   * \param coeff The coeff value.
   * \return The simplified base.
   */
  static int64_t BaseSimplify(int64_t base, int64_t coeff) {
    if (coeff == 0) return base;
    base = base % coeff;
    if (base < 0) base += coeff;
    return base;
  }
  /*!
   * \brief Take GCD of a and b.
   *
   *  Signs are ignored and gcd(a, 0) == |a|.
   *
   * \param a The first operand.
   * \param b The second operand.
   * \return The result.
   */
  static int64_t ZeroAwareGCD(int64_t a, int64_t b) {
    if (a < 0) a = -a;
    if (b < 0) b = -b;
    if (a < b) std::swap(a, b);
    if (b == 0) return a;
    // perform GCD (greatest common divisor)
    // ax + by = gcd(a, b) z if a != 0, b != 0
    while (a % b != 0) {
      a = a % b;
      std::swap(a, b);
    }
    return b;
  }
  /*!
   * \brief The set of all integers.
   * \return Entry with coeff = 1 and base = 0.
   */
  static Entry Everything() {
    Entry ret;
    ret.coeff = 1; ret.base = 0;
    return ret;
  }
};
// Run the analysis on expr and convert the internal entry to a ModularSet node.
ModularSet ModularSetAnalyzer::operator()(const Expr& expr) {
  const Entry result = impl_->VisitExpr(expr);
  return ModularSetNode::make(result.coeff, result.base);
}
// Forward the update of var's modular set to the implementation.
void ModularSetAnalyzer::Update(const Var& var,
                                const ModularSet& info,
                                bool override) {
  impl_->Update(var, info, override);
}
// Forward to the implementation; the returned recovery function may be nullptr.
std::function<void()> ModularSetAnalyzer::EnterConstraint(const Expr& constraint) {
  return impl_->EnterConstraint(constraint);
}
// Allocate the implementation object; it is released in the destructor.
ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent)
    : impl_(new Impl(parent)) {
}
// Release the implementation object held in impl_.
ModularSetAnalyzer::~ModularSetAnalyzer() {
  delete impl_;
}
} // namespace arith
} // namespace tvm
...@@ -25,6 +25,17 @@ ...@@ -25,6 +25,17 @@
* // The filled value is valid until the next call to Match. * // The filled value is valid until the next call to Match.
* return (max(x, y) + z).Eval(); * return (max(x, y) + z).Eval();
* } * }
*
* tvm::Var tx, ty;
* arith::PVar<Integer> c;
* arith::PVar<Var> v;
* // We can match integer and Var, both of which are
* // special case container of Expr
* CHECK((v * c).Match(tx * 3));
* CHECK_EQ(c.Eval()->value, 3);
* // cannot match c to ty
* CHECK(!(v * c).Match(tx * ty));
*
* \endcode * \endcode
* *
* \note The pattern matcher is not threadsafe, * \note The pattern matcher is not threadsafe,
...@@ -109,6 +120,22 @@ class PEqualChecker<Expr> { ...@@ -109,6 +120,22 @@ class PEqualChecker<Expr> {
} }
}; };
/*!
 * \brief Equality check specialized for Integer.
 *
 *  Two Integer references are considered equal when the value fields
 *  of the nodes they point to are equal.
 */
template<>
class PEqualChecker<Integer> {
public:
bool operator()(const Integer& lhs, const Integer& rhs) const {
return lhs->value == rhs->value;
}
};
/*!
 * \brief Equality check specialized for Var.
 *
 *  Two Var references are considered equal only when same_as reports
 *  that they are the same; no comparison of their contents is made here.
 */
template<>
class PEqualChecker<Var> {
public:
bool operator()(const Var& lhs, const Var& rhs) const {
return lhs.same_as(rhs);
}
};
/*! /*!
* \brief Pattern variable container. * \brief Pattern variable container.
* *
...@@ -123,7 +150,7 @@ template<typename T> ...@@ -123,7 +150,7 @@ template<typename T>
class PVar : public Pattern<PVar<T> > { class PVar : public Pattern<PVar<T> > {
public: public:
// Store PVars by reference in the expression. // Store PVars by reference in the expression.
using Nested = const PVar&; using Nested = const PVar<T>&;
void InitMatch_() const { void InitMatch_() const {
filled_ = false; filled_ = false;
...@@ -139,12 +166,23 @@ class PVar : public Pattern<PVar<T> > { ...@@ -139,12 +166,23 @@ class PVar : public Pattern<PVar<T> > {
} }
} }
/*!
 * \brief Match against a reference whose type is a base class of T.
 *
 *  The value is first checked to hold a node of T's container type.
 *  If it does, it is converted to T and handed to the Match_ overload
 *  taking T; otherwise the match fails.
 *
 * \param value The reference to match.
 * \return Whether the match succeeded.
 * \tparam NodeRefType A base class of T (enforced through enable_if).
 */
template<typename NodeRefType,
         typename = typename std::enable_if<
           std::is_base_of<NodeRefType, T>::value>::type>
bool Match_(const NodeRefType& value) const {
  const auto* node = value.template as<typename T::ContainerType>();
  if (node == nullptr) return false;
  return Match_(GetRef<T>(node));
}
T Eval() const { T Eval() const {
CHECK(filled_); CHECK(filled_);
return value_; return value_;
} }
private: protected:
/*! \brief The matched value */ /*! \brief The matched value */
mutable T value_; mutable T value_;
/*! \brief whether the variable has been filled */ /*! \brief whether the variable has been filled */
...@@ -171,6 +209,7 @@ class PConst : public Pattern<PConst<T> > { ...@@ -171,6 +209,7 @@ class PConst : public Pattern<PConst<T> > {
T Eval() const { T Eval() const {
return value_; return value_;
} }
private: private:
const T value_; const T value_;
}; };
......
/*!
* Copyright (c) 2018 by Contributors
* \file codegen_common.h
* \brief Common utility for codegen.
*/
#ifndef TVM_CODEGEN_CODEGEN_COMMON_H_
#define TVM_CODEGEN_CODEGEN_COMMON_H_
#include <tvm/arithmetic.h>
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
/*!
 * \brief Visit the body of an AssertStmt, refining align_map from the
 *  asserted condition while the body is being visited.
 *
 *  When the condition has the form (var % factor == offset) with constant
 *  factor and offset, and factor is larger than the coeff currently stored
 *  for var, the entry of var is replaced by (factor, offset) for the
 *  duration of the visit and put back afterwards. In every other case the
 *  body is visited with align_map left as it is.
 *
 * \param op The AssertStmt
 * \param align_map The alignment map, updated in place.
 * \param fvisit The recursive visitor, called once with op->body.
 * \tparam FVisit the type of the recursive visitor
 */
template<typename FVisit>
inline void VisitAssert(
    const ir::AssertStmt* op,
    std::unordered_map<const Variable*, arith::ModularEntry>* align_map,
    FVisit fvisit) {
  using namespace ir;
  auto& align_map_ = *align_map;
  // Detect useful invariant pattern and use them to visit child.
  // Pattern: Var % const == 0
  // TODO(tqchen) merge these pattern to a generic scope info visitor.
  int64_t factor = 0, offset = 0;
  const Variable* var = nullptr;
  const EQ* eq = op->condition.as<EQ>();
  const Mod* mod = (eq != nullptr) ? eq->a.as<Mod>() : nullptr;
  // The right hand side of the equality has to be a constant.
  if (mod != nullptr && arith::GetConst(eq->b, &offset)) {
    var = mod->a.as<Variable>();
  }
  // The modulus has to be a constant as well.
  if (var != nullptr && arith::GetConst(mod->b, &factor)) {
    const arith::ModularEntry saved = align_map_[var];
    if (factor > saved.coeff) {
      arith::ModularEntry refined;
      refined.coeff = static_cast<int>(factor);
      refined.base = static_cast<int>(offset);
      // new alignment info is only valid inside the body
      align_map_[var] = refined;
      fvisit(op->body);
      // restore old info
      align_map_[var] = saved;
      return;
    }
  }
  fvisit(op->body);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_COMMON_H_
...@@ -9,7 +9,6 @@ ...@@ -9,7 +9,6 @@
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include "codegen_llvm.h" #include "codegen_llvm.h"
#include "codegen_cpu.h" #include "codegen_cpu.h"
#include "../codegen_common.h"
#include "../../pass/ir_util.h" #include "../../pass/ir_util.h"
#include "../../arithmetic/compute_expr.h" #include "../../arithmetic/compute_expr.h"
...@@ -84,9 +83,9 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) { ...@@ -84,9 +83,9 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
void CodeGenLLVM::InitFuncState() { void CodeGenLLVM::InitFuncState() {
var_map_.clear(); var_map_.clear();
alias_var_set_.clear(); alias_var_set_.clear();
align_map_.clear();
alloc_storage_info_.clear(); alloc_storage_info_.clear();
volatile_buf_.clear(); volatile_buf_.clear();
analyzer_.reset(new arith::Analyzer());
} }
void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) { void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
...@@ -381,14 +380,16 @@ void CodeGenLLVM::GetAlignment(Type t, ...@@ -381,14 +380,16 @@ void CodeGenLLVM::GetAlignment(Type t,
*p_native_bits = native_vector_bits_; *p_native_bits = native_vector_bits_;
} }
arith::ModularEntry me = arith::EvalModular(index, align_map_); arith::ModularSet me = analyzer_->modular_set(index);
int64_t base = me->base;
int64_t coeff = me->coeff;
int align_bits = t.bits(); int align_bits = t.bits();
while (align_bits < max_align_bits && while (align_bits < max_align_bits &&
me.base % 2 == 0 && base % 2 == 0 &&
me.coeff % 2 == 0) { coeff % 2 == 0) {
me.base = me.base / 2; base = base / 2;
me.coeff = me.coeff / 2; coeff = coeff / 2;
align_bits *= 2; align_bits *= 2;
} }
if (align_bits < 8) { if (align_bits < 8) {
...@@ -874,7 +875,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) { ...@@ -874,7 +875,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
CHECK(!var_map_.count(op->var.get())); CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = MakeValue(op->value); var_map_[op->var.get()] = MakeValue(op->value);
align_map_[op->var.get()] = EvalModular(op->value, align_map_); analyzer_->Bind(op->var, op->value);
return MakeValue(op->body); return MakeValue(op->body);
} }
...@@ -998,6 +999,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) { ...@@ -998,6 +999,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
void CodeGenLLVM::VisitStmt_(const For* op) { void CodeGenLLVM::VisitStmt_(const For* op) {
CHECK(is_zero(op->min)); CHECK(is_zero(op->min));
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
if (op->for_type == ForType::Unrolled) { if (op->for_type == ForType::Unrolled) {
LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, " LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
<< " consider set unroll_explicit=True"; << " consider set unroll_explicit=True";
...@@ -1078,6 +1080,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { ...@@ -1078,6 +1080,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
if (iv->thread_tag.length() != 0) { if (iv->thread_tag.length() != 0) {
if (!var_map_.count(iv->var.get())) { if (!var_map_.count(iv->var.get())) {
var_map_[iv->var.get()] = GetThreadIndex(iv); var_map_[iv->var.get()] = GetThreadIndex(iv);
analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
} }
} }
} else if (op->attr_key == ir::attr::storage_scope) { } else if (op->attr_key == ir::attr::storage_scope) {
...@@ -1099,21 +1102,19 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { ...@@ -1099,21 +1102,19 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
} }
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) { void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
VisitAssert(op, &align_map_, [this](const Stmt& body) { arith::ConstraintContext cctx(analyzer_.get(), op->condition);
this->VisitStmt(body); this->VisitStmt(op->body);
});
} }
void CodeGenLLVM::VisitStmt_(const LetStmt* op) { void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
CHECK(!var_map_.count(op->var.get())); CHECK(!var_map_.count(op->var.get()));
CHECK(!align_map_.count(op->var.get()));
if (op->var.type().is_handle()) { if (op->var.type().is_handle()) {
if (!is_restricted_) { if (!is_restricted_) {
alias_var_set_.insert(op->var.get()); alias_var_set_.insert(op->var.get());
} }
} }
var_map_[op->var.get()] = MakeValue(op->value); var_map_[op->var.get()] = MakeValue(op->value);
align_map_[op->var.get()] = EvalModular(op->value, align_map_); analyzer_->Bind(op->var, op->value);
this->VisitStmt(op->body); this->VisitStmt(op->body);
} }
......
...@@ -23,7 +23,6 @@ namespace codegen { ...@@ -23,7 +23,6 @@ namespace codegen {
using namespace ir; using namespace ir;
/*! /*!
* \brief A base class to generate a LLVM. * \brief A base class to generate a LLVM.
*/ */
...@@ -267,8 +266,8 @@ class CodeGenLLVM : ...@@ -267,8 +266,8 @@ class CodeGenLLVM :
std::unordered_map<std::string, llvm::Constant*> str_map_; std::unordered_map<std::string, llvm::Constant*> str_map_;
// Whether current function is restricted // Whether current function is restricted
bool is_restricted_{true}; bool is_restricted_{true};
// The alignment information // The analyzer information
std::unordered_map<const Variable*, arith::ModularEntry> align_map_; std::unique_ptr<arith::Analyzer> analyzer_;
// set of var that are not restricted(can alias) // set of var that are not restricted(can alias)
std::unordered_set<const Variable*> alias_var_set_; std::unordered_set<const Variable*> alias_var_set_;
// set of volatile buffer. // set of volatile buffer.
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <string> #include <string>
#include "../codegen_common.h" #include "../../arithmetic/compute_expr.h"
#include "codegen_spirv.h" #include "codegen_spirv.h"
namespace tvm { namespace tvm {
...@@ -66,7 +66,7 @@ void CodeGenSPIRV::InitFuncState() { ...@@ -66,7 +66,7 @@ void CodeGenSPIRV::InitFuncState() {
std::fill(workgroup_size_, workgroup_size_ + 3, 1); std::fill(workgroup_size_, workgroup_size_ + 3, 1);
var_map_.clear(); var_map_.clear();
storage_info_.clear(); storage_info_.clear();
align_map_.clear(); analyzer_.reset(new arith::Analyzer());
builder_.reset(new spirv::IRBuilder()); builder_.reset(new spirv::IRBuilder());
builder_->InitHeader(); builder_->InitHeader();
} }
...@@ -217,7 +217,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Select* op) { ...@@ -217,7 +217,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Select* op) {
spirv::Value CodeGenSPIRV::VisitExpr_(const Let* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const Let* op) {
CHECK(!var_map_.count(op->var.get())); CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = MakeValue(op->value); var_map_[op->var.get()] = MakeValue(op->value);
align_map_[op->var.get()] = EvalModular(op->value, align_map_); analyzer_->Bind(op->var, op->value);
return MakeValue(op->body); return MakeValue(op->body);
} }
...@@ -378,9 +378,9 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) { ...@@ -378,9 +378,9 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) {
if (const Ramp* ramp = op->index.as<Ramp>()) { if (const Ramp* ramp = op->index.as<Ramp>()) {
if (is_one(ramp->stride)) { if (is_one(ramp->stride)) {
CHECK_EQ(ramp->lanes, op->type.lanes()); CHECK_EQ(ramp->lanes, op->type.lanes());
arith::ModularEntry me = arith::EvalModular(ramp->base, align_map_); arith::ModularSet me = analyzer_->modular_set(ramp->base);
CHECK((me.coeff % ramp->lanes) == 0 && CHECK((me->coeff % ramp->lanes) == 0 &&
(me.base % ramp->lanes) == 0) (me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV"; << "Only aligned vector access is allowed in SPIRV";
Expr vec_index = ir::Simplify( Expr vec_index = ir::Simplify(
ramp->base / make_const(ramp->base.type(), ramp->lanes)); ramp->base / make_const(ramp->base.type(), ramp->lanes));
...@@ -458,9 +458,9 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) { ...@@ -458,9 +458,9 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) {
if (const Ramp* ramp = op->index.as<Ramp>()) { if (const Ramp* ramp = op->index.as<Ramp>()) {
if (is_one(ramp->stride)) { if (is_one(ramp->stride)) {
CHECK_EQ(ramp->lanes, op->value.type().lanes()); CHECK_EQ(ramp->lanes, op->value.type().lanes());
arith::ModularEntry me = arith::EvalModular(ramp->base, align_map_); arith::ModularSet me = analyzer_->modular_set(ramp->base);
CHECK((me.coeff % ramp->lanes) == 0 && CHECK((me->coeff % ramp->lanes) == 0 &&
(me.base % ramp->lanes) == 0) (me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV"; << "Only aligned vector access is allowed in SPIRV";
Expr vec_index = ir::Simplify( Expr vec_index = ir::Simplify(
ramp->base / make_const(ramp->base.type(), ramp->lanes)); ramp->base / make_const(ramp->base.type(), ramp->lanes));
...@@ -477,6 +477,7 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) { ...@@ -477,6 +477,7 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) {
void CodeGenSPIRV::VisitStmt_(const For* op) { void CodeGenSPIRV::VisitStmt_(const For* op) {
CHECK(is_zero(op->min)); CHECK(is_zero(op->min));
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
spirv::Value init_value = MakeValue(op->min); spirv::Value init_value = MakeValue(op->min);
spirv::Value extent_value = MakeValue(op->extent); spirv::Value extent_value = MakeValue(op->extent);
// Must get init label after making value(to make sure they are correct) // Must get init label after making value(to make sure they are correct)
...@@ -589,6 +590,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { ...@@ -589,6 +590,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
if (iv->thread_tag.length() != 0) { if (iv->thread_tag.length() != 0) {
if (!var_map_.count(iv->var.get())) { if (!var_map_.count(iv->var.get())) {
var_map_[iv->var.get()] = GetThreadIndex(iv, op->value); var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
} }
} }
} else if (op->attr_key == ir::attr::storage_scope) { } else if (op->attr_key == ir::attr::storage_scope) {
...@@ -605,17 +607,15 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { ...@@ -605,17 +607,15 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
} }
void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) { void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) {
VisitAssert(op, &align_map_, [this](const Stmt& body) { arith::ConstraintContext cctx(analyzer_.get(), op->condition);
this->VisitStmt(body); this->VisitStmt(op->body);
});
} }
void CodeGenSPIRV::VisitStmt_(const LetStmt* op) { void CodeGenSPIRV::VisitStmt_(const LetStmt* op) {
CHECK(!var_map_.count(op->var.get())); CHECK(!var_map_.count(op->var.get()));
CHECK(!align_map_.count(op->var.get()));
CHECK(!op->var.type().is_handle()); CHECK(!op->var.type().is_handle());
var_map_[op->var.get()] = MakeValue(op->value); var_map_[op->var.get()] = MakeValue(op->value);
align_map_[op->var.get()] = EvalModular(op->value, align_map_); analyzer_->Bind(op->var, op->value);
this->VisitStmt(op->body); this->VisitStmt(op->body);
} }
......
...@@ -122,8 +122,8 @@ class CodeGenSPIRV: ...@@ -122,8 +122,8 @@ class CodeGenSPIRV:
std::unordered_map<const Variable*, StorageInfo> storage_info_; std::unordered_map<const Variable*, StorageInfo> storage_info_;
// The definition of local variable. // The definition of local variable.
std::unordered_map<const Variable*, spirv::Value> var_map_; std::unordered_map<const Variable*, spirv::Value> var_map_;
// The alignment information // The analyzer.
std::unordered_map<const Variable*, arith::ModularEntry> align_map_; std::unique_ptr<arith::Analyzer> analyzer_;
}; };
} // namespace codegen } // namespace codegen
......
...@@ -936,10 +936,8 @@ class VectorAllocRewriter : public IRMutator { ...@@ -936,10 +936,8 @@ class VectorAllocRewriter : public IRMutator {
tvec[0].lanes() != op->type.lanes()) { tvec[0].lanes() != op->type.lanes()) {
int factor = tvec[0].lanes() / op->type.lanes(); int factor = tvec[0].lanes() / op->type.lanes();
Array<Expr> extents = op->extents; Array<Expr> extents = op->extents;
arith::ModularEntry me = EvalModular( arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]);
extents[extents.size() - 1], if (me->base % factor == 0 && me->coeff % factor == 0) {
std::unordered_map<const Variable*, arith::ModularEntry>());
if (me.base % factor == 0 && me.coeff % factor == 0) {
extents.Set(extents.size() - 1, extents.Set(extents.size() - 1,
extents[extents.size() - 1] / make_const(extents[0].type(), factor)); extents[extents.size() - 1] / make_const(extents[0].type(), factor));
return Allocate::make( return Allocate::make(
...@@ -959,6 +957,8 @@ class VectorAllocRewriter : public IRMutator { ...@@ -959,6 +957,8 @@ class VectorAllocRewriter : public IRMutator {
// Internal access map // Internal access map
std::unordered_map<const Variable*, std::vector<Type> > acc_map_; std::unordered_map<const Variable*, std::vector<Type> > acc_map_;
// internal analyzer
arith::Analyzer analyzer_;
}; };
......
...@@ -107,6 +107,23 @@ TEST(Pattern, Basic) { ...@@ -107,6 +107,23 @@ TEST(Pattern, Basic) {
} }
} }
// Checks that typed pattern variables only bind to operands of their own
// kind: PVar<Integer> to an integer constant and PVar<Var> to a variable.
TEST(Pattern, Integer) {
using namespace tvm;
tvm::Var tx, ty;
arith::PVar<Integer> c;
arith::PVar<Var> v;
{
// We can match integer and Var, both of which are
// special case container of Expr
CHECK((v * c).Match(tx * 3));
// After a successful match the captured constant can be read back.
CHECK_EQ(c.Eval()->value, 3);
}
// cannot match c to ty
CHECK(!(v * c).Match(tx * ty));
// cannot match tx + 1 to v
CHECK(!(v * c).Match((tx + 1) * 3));
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
# Directory with the gtest libraries; passed to the linker with -L below.
GTEST_LIB=$(GTEST_PATH)/lib/
# Directory with the gtest headers; passed to the compiler with -I below.
GTEST_INC=$(GTEST_PATH)/include/
# All tests/cpp/*_test.cc sources and the binary that is built from each.
TEST_SRC = $(wildcard tests/cpp/*_test.cc)
TEST = $(patsubst tests/cpp/%_test.cc, tests/cpp/%_test, $(TEST_SRC))
# Build one test binary from its source; it also depends on lib/libtvm.so.
# The first command writes the header dependencies of the source to
# tests/cpp/<stem>.d, the second one compiles the test and links it
# against gtest and libtvm.
# NOTE(review): -MT names tests/cpp/<stem> while the real target is
# tests/cpp/<stem>_test, so the generated dependency rule may not attach
# to this target -- confirm.
tests/cpp/%_test: tests/cpp/%_test.cc lib/libtvm.so
$(CXX) -std=c++11 $(CFLAGS) -MM -MT tests/cpp/$* $< >tests/cpp/$*.d
$(CXX) -std=c++11 $(CFLAGS) -I$(GTEST_INC) -o $@ $(filter %.cc %.a, $^) \
-L$(GTEST_LIB) $(LDFLAGS) -lgtest -Llib -ltvm
# Pull in the generated dependency files; the leading '-' makes missing
# files a non-error.
-include tests/cpp/*.d
import tvm
def test_dtype_bound():
    """A bare variable is bounded by the value range of its dtype."""
    analyzer = tvm.arith.Analyzer()

    # int64 covers the whole range, which is reported as -inf / +inf.
    x = tvm.var("x", dtype="int64")
    bound = analyzer.const_int_bound(x)
    assert bound.min_value == bound.NEG_INF
    assert bound.max_value == bound.POS_INF

    x = tvm.var("x", dtype="int8")
    bound = analyzer.const_int_bound(x)
    assert bound.min_value == -128
    assert bound.max_value == 127

    x = tvm.var("x", dtype="uint8")
    bound = analyzer.const_int_bound(x)
    assert bound.min_value == 0
    assert bound.max_value == 255
def test_cast_bound():
    """Bounds of x % 3 seen through an unsigned cast and a float round trip."""
    analyzer = tvm.arith.Analyzer()
    x = tvm.var("x", dtype="int8")

    bound = analyzer.const_int_bound((x % 3).astype("uint32"))
    assert bound.min_value == 0
    assert bound.max_value == 2

    bound = analyzer.const_int_bound(
        (x % 3).astype("float32").astype("int32"))
    assert bound.min_value == -2
    assert bound.max_value == 2
def test_add_sub_bound():
    """Bounds of sums and differences, including one-sided infinite ranges."""
    analyzer = tvm.arith.Analyzer()
    make_bound = tvm.arith.ConstIntBound
    x, y = tvm.var("x", "int64"), tvm.var("y", "int64")

    # Nothing has been recorded about x and y yet.
    bound = analyzer.const_int_bound(x + y)
    assert bound.min_value == bound.NEG_INF
    assert bound.max_value == bound.POS_INF

    analyzer.update(x, make_bound(0, 4))
    analyzer.update(y, make_bound(1, 10))
    bound = analyzer.const_int_bound(x + y)
    assert bound.min_value == 1
    assert bound.max_value == 14

    bound = analyzer.const_int_bound(x - y)
    assert bound.min_value == -10
    assert bound.max_value == 3

    # x is now bounded from below only.
    analyzer.update(x, make_bound(0, bound.POS_INF), override=True)
    bound = analyzer.const_int_bound(x - y)
    assert bound.min_value == -10
    assert bound.max_value == bound.POS_INF

    bound = analyzer.const_int_bound(1 - x)
    assert bound.min_value == bound.NEG_INF
    assert bound.max_value == 1
def test_mul_bound():
    """Bounds of products for mixed-sign and unbounded operands."""
    analyzer = tvm.arith.Analyzer()
    make_bound = tvm.arith.ConstIntBound
    x, y = tvm.var("x"), tvm.var("y")

    analyzer.update(x, make_bound(-2, 4))
    analyzer.update(y, make_bound(4, 10))
    bound = analyzer.const_int_bound(x * y + 20)
    assert bound.min_value == 0
    assert bound.max_value == 60

    # Both operands may be negative.
    analyzer.update(x, make_bound(-3, 4), override=True)
    analyzer.update(y, make_bound(-8, 2), override=True)
    bound = analyzer.const_int_bound(x * y)
    assert bound.min_value == -32
    assert bound.max_value == 24

    # x has no lower bound.
    analyzer.update(x, make_bound(bound.NEG_INF, 4), override=True)
    analyzer.update(y, make_bound(-8, 2), override=True)
    bound = analyzer.const_int_bound(x * y)
    assert bound.min_value == bound.NEG_INF
    assert bound.max_value == bound.POS_INF
def test_div_bound():
    """Bounds of quotients, including divisors whose range touches zero."""
    analyzer = tvm.arith.Analyzer()
    make_bound = tvm.arith.ConstIntBound
    x, y = tvm.var("x"), tvm.var("y")

    analyzer.update(x, make_bound(-9, 4))
    analyzer.update(y, make_bound(4, 10))
    bound = analyzer.const_int_bound(x / y)
    assert bound.min_value == -2

    # The divisor range is [-2, 0].
    analyzer.update(x, make_bound(-9, 4), override=True)
    analyzer.update(y, make_bound(-2, 0), override=True)
    bound = analyzer.const_int_bound(x / y)
    assert bound.min_value == -4
    assert bound.max_value == 9

    # Unbounded dividend with a divisor range that spans zero.
    analyzer.update(x, make_bound(bound.NEG_INF, 4), override=True)
    analyzer.update(y, make_bound(-2, 1), override=True)
    bound = analyzer.const_int_bound(x / y)
    assert bound.min_value == bound.NEG_INF
    assert bound.max_value == bound.POS_INF
def test_mod_bound():
    """Bounds of remainders for bounded, unbounded and positive dividends."""
    analyzer = tvm.arith.Analyzer()
    make_bound = tvm.arith.ConstIntBound
    x, y = tvm.var("x"), tvm.var("y")

    analyzer.update(x, make_bound(-9, 4))
    analyzer.update(y, make_bound(4, 10))
    bound = analyzer.const_int_bound(x % y)
    assert bound.min_value == -9
    assert bound.max_value == 4

    # Fully unbounded dividend.
    analyzer.update(x, make_bound(bound.NEG_INF, bound.POS_INF), override=True)
    analyzer.update(y, make_bound(4, 10), override=True)
    bound = analyzer.const_int_bound(x % y)
    assert bound.min_value == -9
    assert bound.max_value == 9

    # Strictly positive dividend.
    analyzer.update(x, make_bound(1, bound.POS_INF), override=True)
    analyzer.update(y, make_bound(4, 10), override=True)
    bound = analyzer.const_int_bound(x % y)
    assert bound.min_value == 0
    assert bound.max_value == 9
def test_min_max_bound():
    """Bounds of tvm.min and tvm.max for bounded and unbounded operands."""
    analyzer = tvm.arith.Analyzer()
    make_bound = tvm.arith.ConstIntBound
    x, y = tvm.var("x"), tvm.var("y")

    analyzer.update(x, make_bound(-9, 11))
    analyzer.update(y, make_bound(4, 10))
    bound = analyzer.const_int_bound(tvm.min(x, y))
    assert bound.min_value == -9
    assert bound.max_value == 10

    # Fully unbounded x.
    analyzer.update(x, make_bound(bound.NEG_INF, bound.POS_INF), override=True)
    analyzer.update(y, make_bound(4, 10), override=True)
    bound = analyzer.const_int_bound(tvm.min(x, y))
    assert bound.min_value == bound.NEG_INF
    assert bound.max_value == 10

    bound = analyzer.const_int_bound(tvm.max(x, y))
    assert bound.min_value == 4
    assert bound.max_value == bound.POS_INF

    # x bounded from below only.
    analyzer.update(x, make_bound(1, bound.POS_INF), override=True)
    analyzer.update(y, make_bound(4, 10), override=True)
    bound = analyzer.const_int_bound(tvm.max(x, y))
    assert bound.min_value == 4
    assert bound.max_value == bound.POS_INF
def test_select_bound():
    """The bound of a Select covers both of its value branches."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")
    analyzer.update(x, tvm.arith.ConstIntBound(-9, 11))
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10))

    bound = analyzer.const_int_bound(
        tvm.expr.Select(x > 1, (y < 0).astype("int32"), y + 1))
    assert bound.min_value == 0
    assert bound.max_value == 11
def test_shift_and_bound():
    """Bounds of right shift and bitwise and."""
    analyzer = tvm.arith.Analyzer()
    make_bound = tvm.arith.ConstIntBound
    x, y = tvm.var("x"), tvm.var("y")
    analyzer.update(x, make_bound(-9, 11))
    analyzer.update(y, make_bound(2, 10))

    bound = analyzer.const_int_bound(x >> y)
    assert bound.min_value == -3
    assert bound.max_value == 2

    bound = analyzer.const_int_bound(x & y)
    assert bound.min_value == 0
    assert bound.max_value == 10

    # Same expression once x is known to be positive.
    analyzer.update(x, make_bound(10, 11), override=True)
    bound = analyzer.const_int_bound(x & y)
    assert bound.min_value == 0
    assert bound.max_value == 10
def test_mix_index_bound():
    """Bounds of typical index expressions that mix %, / and *."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")
    analyzer.update(x, tvm.arith.ConstIntBound(0, 24 - 1))
    analyzer.update(y, tvm.arith.ConstIntBound(0, 3 - 1))

    bound = analyzer.const_int_bound((x % 8) + (x / 8) * 8)
    assert bound.min_value == 0
    assert bound.max_value == 24 - 1

    bound = analyzer.const_int_bound(y + x * 3)
    assert bound.min_value == 0
    assert bound.max_value == 24 * 3 - 1

    # 7 does not divide 24, so the upper bound is looser than 23.
    bound = analyzer.const_int_bound((x % 7) + (x / 7) * 7)
    assert bound.min_value == 0
    assert bound.max_value == (23 // 7) * 7 + 6
if __name__ == "__main__":
    # Run every test of this file, in declaration order.
    for run_test in (
            test_dtype_bound,
            test_cast_bound,
            test_add_sub_bound,
            test_mul_bound,
            test_div_bound,
            test_mod_bound,
            test_min_max_bound,
            test_select_bound,
            test_shift_and_bound,
            test_mix_index_bound,
    ):
        run_test()
import tvm
def test_basic():
    """EvalModular results (coeff, base) for a set of index expressions."""
    eval_modular = tvm.arith.EvalModular
    a = tvm.var()
    b = tvm.var()

    mod_set = eval_modular(a * 4 + b * 6 + 7)
    assert mod_set.coeff == 2
    assert mod_set.base == 1

    mod_set = eval_modular((a * 4 + 1) * (b * 8 + 3))
    assert mod_set.coeff == 4
    assert mod_set.base == 3

    mod_set = eval_modular((a * 4 + 1) / (b * 8 + 3))
    assert mod_set.coeff == 1
    assert mod_set.base == 0

    mod_set = eval_modular((a * 4 + 1) * (b * 8 / 4))
    assert mod_set.coeff == 2
    assert mod_set.base == 0

    mod_set = eval_modular((a * 12 + 1) - (b * 3 * 7 + 2))
    assert mod_set.coeff == 3
    assert mod_set.base == 2

    mod_set = eval_modular(a * 12 + tvm.min(b * 3 * 7, 2))
    assert mod_set.coeff == 1
    assert mod_set.base == 0
if __name__ == "__main__":
    # Script entry point: run the only test of this file.
    test_basic()
import tvm
def test_cast():
    """Modular set of expressions seen through dtype casts."""
    analyzer = tvm.arith.Analyzer()
    x = tvm.var("x", dtype="int8")

    mod_set = analyzer.modular_set((x * 3).astype("uint32"))
    assert mod_set.coeff == 3
    assert mod_set.base == 0

    mod_set = analyzer.modular_set(
        (x * 3 + 1).astype("float32").astype("int32"))
    assert mod_set.coeff == 3
    assert mod_set.base == 1
def test_add_sub():
    """Modular set of sums and differences, also through a bound variable."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x", "int64"), tvm.var("y", "int64")

    mod_set = analyzer.modular_set(x * 6 + y * 4)
    assert mod_set.coeff == 2
    assert mod_set.base == 0

    # After binding y to x * 4 + 1 the expression 1 - y is analyzed again.
    analyzer.bind(y, x * 4 + 1)
    mod_set = analyzer.modular_set(1 - y)
    assert mod_set.coeff == 4
    assert mod_set.base == 0
def test_mul():
    """Modular set of a product of two affine expressions."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    mod_set = analyzer.modular_set((x * 4 + 2) * (y * 6 + 1))
    assert mod_set.coeff == 4
    assert mod_set.base == 2
def test_div_shift():
    """Modular set of division and right shift by a constant."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    # not sure if x is non-negative
    mod_set = analyzer.modular_set((x * 4 + 2) / 2)
    assert mod_set.coeff == 1
    assert mod_set.base == 0

    # right shift always round down so it is fine
    mod_set = analyzer.modular_set((x * 4 + 2) >> 1)
    assert mod_set.coeff == 2
    assert mod_set.base == 1

    # x is non-negative
    analyzer.update(x, tvm.arith.ConstIntBound(0, 100))
    mod_set = analyzer.modular_set((x * 4 + 2) / 2)
    assert mod_set.coeff == 2
    assert mod_set.base == 1
def test_min_max_select():
    """Modular set of min, max and Select over two affine expressions."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    mod_set = analyzer.modular_set(tvm.min(x * 3, y * 9))
    assert mod_set.coeff == 3
    assert mod_set.base == 0

    mod_set = analyzer.modular_set(tvm.max(x * 3 + 1, y * 9 + 4))
    assert mod_set.coeff == 3
    assert mod_set.base == 1

    # The two branches share no common residue, so nothing is known.
    mod_set = analyzer.modular_set(
        tvm.expr.Select(x > 0, x * 3 + 1, y * 9 + 2))
    assert mod_set.coeff == 1
    assert mod_set.base == 0
def test_mix_index():
    """Modular set (coeff, base) for a set of mixed index expressions."""
    a = tvm.var("a")
    b = tvm.var("b")
    analyzer = tvm.arith.Analyzer()

    mod_set = analyzer.modular_set(a * 4 + b * 6 + 7)
    assert mod_set.coeff == 2
    assert mod_set.base == 1

    mod_set = analyzer.modular_set((a * 4 + 1) * (b * 8 + 3))
    assert mod_set.coeff == 4
    assert mod_set.base == 3

    mod_set = analyzer.modular_set((a * 4 + 1) / (b * 8 + 3))
    assert mod_set.coeff == 1
    assert mod_set.base == 0

    mod_set = analyzer.modular_set((a * 4 + 1) * (b * 8 / 4))
    assert mod_set.coeff == 2
    assert mod_set.base == 0

    mod_set = analyzer.modular_set((a * 12 + 1) - (b * 3 * 7 + 2))
    assert mod_set.coeff == 3
    assert mod_set.base == 2

    mod_set = analyzer.modular_set(a * 12 + tvm.min(b * 3 * 7, 2))
    assert mod_set.coeff == 1
    assert mod_set.base == 0
def test_constraint_scope():
    """Constraints apply inside their scope only and can be nested."""
    a = tvm.var("a")
    b = tvm.var("b")
    analyzer = tvm.arith.Analyzer()

    with analyzer.constraint_scope(b % 4 == 2):
        mod_set = analyzer.modular_set(b + 1)
        assert mod_set.coeff == 4
        assert mod_set.base == 3

        # Inner scope: both constraints are active.
        with analyzer.constraint_scope(a % 2 == 1):
            mod_set = analyzer.modular_set(b + a * 2)
            assert mod_set.coeff == 4
            assert mod_set.base == 0

        # Inner scope left: only the constraint on b remains.
        mod_set = analyzer.modular_set(b + a * 2)
        assert mod_set.coeff == 2
        assert mod_set.base == 0

    # Outside of every scope nothing is known about b any more.
    mod_set = analyzer.modular_set(b + 1)
    assert mod_set.coeff == 1
    assert mod_set.base == 0
if __name__ == "__main__":
    # Run every test of this file, in declaration order.
    for run_test in (
            test_cast,
            test_add_sub,
            test_mul,
            test_div_shift,
            test_min_max_select,
            test_mix_index,
            test_constraint_scope,
    ):
        run_test()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment