Unverified Commit 153417a5 by Tianqi Chen Committed by GitHub

[ARITH] Revamp IntSet (#3272)

parent 9bb16872
......@@ -328,71 +328,14 @@ class ConstraintContext {
std::function<void()> exit_;
};
/*!
* \brief Analyzer that contains bunch of sub-analyzers.
*
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overrideen.
*/
class Analyzer {
public:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplify */
RewriteSimplifier rewrite_simplify;
/*! \brief sub-analyzer canonical simplify */
CanonicalSimplifier canonical_simplify;
/*! \brief constructor */
Analyzer();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const VarExpr& var, const Expr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const VarExpr& var, const Range& range);
/*!
* \brief Whether can we proof expr >= val.
* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param lower_bound The lower bound.
* \return Whether we can proof it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
};
//-----------------------------------------------
// Integer set abstraction API.
// Integer set data structure.
//
// This is a API build on top of the base
// integer analysis API to provide set analysis.
//------------------------------------------------
/*!
* \brief Sign of an expression or set.
* \brief Sign type of an integer expression.
*/
enum SignType {
kPositive,
......@@ -401,8 +344,13 @@ enum SignType {
kUnknown
};
// internal node container of int set.
struct IntSetNode;
/*!
* \brief Base class of all IntSet containers.
*/
struct IntSetNode : public Node {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
};
/*!
* \brief Integer set class, represent a set of integers in one dimension.
......@@ -424,11 +372,6 @@ class IntSet : public NodeRef {
* \return The covering range.
*/
Range cover_range(Range max_range) const;
/*!
* \brief find an interval that covers the set.
* \return The covering interval set.
*/
IntSet cover_interval() const;
/*! \return Lower bound of the set */
Expr min() const;
/*! \return upper bound of the set */
......@@ -493,33 +436,91 @@ class IntSet : public NodeRef {
};
/*!
* \brief Base class of all IntSet containers.
* \brief Integer set analyzer.
*/
struct IntSetNode : public Node {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
class IntSetAnalyzer {
public:
/*!
* \brief Find a symbolic integer set that contains all possible values of
* expr given the domain of each variables.
*
* \param expr The expression of interest.
* \param dom_map The domain map to indicate which variable to relax.
* \return the result of the analysis.
*/
IntSet operator()(const Expr& expr, const Map<Var, IntSet>& dom_map);
private:
friend class Analyzer;
explicit IntSetAnalyzer(Analyzer* parent);
~IntSetAnalyzer();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};
/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff[i] and base are invariant of var[j] for all i and j.
* \brief Analyzer that contains bunch of sub-analyzers.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overridden.
*/
class Analyzer {
public:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplify */
RewriteSimplifier rewrite_simplify;
/*! \brief sub-analyzer canonical simplify */
CanonicalSimplifier canonical_simplify;
/*! \brief sub-analyzer: int set */
IntSetAnalyzer int_set;
/*! \brief constructor */
Analyzer();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const VarExpr& var, const Expr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars);
void Bind(const VarExpr& var, const Range& range);
/*!
* \brief Whether can we prove expr >= val.
/*!
* \brief Detect if expression corresponds to clip bound of the vars
* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
* \param expr The expression.
* \param lower_bound The lower bound.
* \return Whether we can prove it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars);
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
};
//-----------------------------------------------
// Integer set legacy API.
//------------------------------------------------
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
......@@ -638,6 +639,29 @@ IntSet DeduceBound(Expr v, Expr cond,
*/
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);
// Expression pattern detector.
/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff[i] and base are invariant of var[j] for all i and j.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array<Expr> DetectLinearEquation(const Expr& e,
const Array<Var>& vars);
/*!
* \brief Detect if expression corresponds to clip bound of the vars
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
Array<Expr> DetectClipBound(const Expr& e,
const Array<Var>& vars);
// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
......
......@@ -32,21 +32,21 @@ class IntSet(NodeBase):
return _api_internal._IntSetIsEverything(self)
@register_node
@register_node("arith.IntervalSet")
class IntervalSet(IntSet):
"""Represent set of continuous interval"""
def min(self):
"""get the minimum value"""
return _api_internal._IntervalSetGetMin(self)
def max(self):
"""get the maximum value"""
return _api_internal._IntervalSetGetMax(self)
"""Represent set of continuous interval [min_value, max_value]
Parameters
----------
min_value : Expr
The minimum value in the interval.
@register_node
class StrideSet(IntSet):
"""Represent set of strided integers"""
max_value : Expr
The maximum value in the interval.
"""
def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(
_make_IntervalSet, min_value, max_value)
@register_node("arith.ModularSet")
......@@ -114,6 +114,7 @@ class Analyzer:
self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify")
self._canonical_simplify = _mod("canonical_simplify")
self._int_set = _mod("int_set")
self._enter_constraint_context = _mod("enter_constraint_context")
def const_int_bound(self, expr):
......@@ -176,6 +177,24 @@ class Analyzer:
"""
return self._canonical_simplify(expr)
def int_set(self, expr, dom_map):
"""Compute a symbolic IntSet that covers expr for all values in dom_map.
Parameters
----------
expr : tvm.Expr
The expression.
dom_map : Dict[Var, tvm.arith.IntSet]
The domain for variables to be relaxed.
Returns
-------
result : IntSet
The result.
"""
return self._int_set(expr, dom_map)
def bind(self, var, expr):
"""Bind a variable to the expression.
......
......@@ -39,6 +39,7 @@ TVM_REGISTER_API("arith.intset_vector")
TVM_REGISTER_API("arith.intset_interval")
.set_body_typed(IntSet::interval);
TVM_REGISTER_API("arith.DetectLinearEquation")
.set_body_typed(DetectLinearEquation);
......@@ -110,6 +111,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->canonical_simplify(args[0]);
});
} else if (name == "int_set") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->int_set(args[0], args[1]);
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr();
......
......@@ -31,7 +31,8 @@ Analyzer::Analyzer()
: const_int_bound(this),
modular_set(this),
rewrite_simplify(this),
canonical_simplify(this) {
canonical_simplify(this),
int_set(this) {
}
void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
......@@ -74,7 +75,7 @@ void ConstraintContext::ExitWithScope() {
bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<ir::IntImm>()) {
return ptr->value > lower_bound;
return ptr->value >= lower_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->min_value >= lower_bound) return true;
......
......@@ -30,12 +30,12 @@
#include <unordered_set>
#include <unordered_map>
#include "int_set.h"
namespace tvm {
namespace arith {
using namespace ir;
using HalideIR::Internal::Interval;
// a visitor to find the path to the target variable
// from a expression.
......@@ -293,7 +293,7 @@ IntSet DeduceBound(Expr v, Expr e,
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success) return IntSet::nothing();
Expr min = Interval::neg_inf, max = Interval::pos_inf;
Expr min = neg_inf(), max = pos_inf();
if (d.is_greater) {
min = d.result;
} else {
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file canonical_simplify.cc
* \brief Canonical form based simplification.
*/
......@@ -763,7 +762,10 @@ Mutate_(const Mod* op, const Expr& self) {
if (TryCompare(temp, cval) == kLT) {
return temp;
} else {
return SplitModConst(ToSplitExpr(temp), cval);
// contonue to use logic below.
a = extra;
psum = a.as<SumExprNode>();
CHECK(psum != nullptr);
}
}
}
......
......@@ -27,8 +27,8 @@
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <arithmetic/Interval.h>
#include <limits>
#include <algorithm>
namespace tvm {
namespace arith {
......@@ -105,12 +105,12 @@ inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
return HalideIR::Internal::Interval::make_max(a, b);
return max(a, b);
}
template<>
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
return HalideIR::Internal::Interval::make_min(a, b);
return min(a, b);
}
template<typename Op>
......
......@@ -206,6 +206,7 @@ inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
});
if (a.same_as(b)) return a;
return Expr();
}
......@@ -216,6 +217,7 @@ inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
});
if (a.same_as(b)) return a;
return Expr();
}
......@@ -307,6 +309,58 @@ inline Expr TryConstFold<ir::Not>(Expr a) {
return Expr();
}
/*! \brief Helper namespace for symbolic value limits */
struct SymbolicLimits {
/*! \brief positive infinity */
static Expr pos_inf_;
/*! \brief negative infinity */
static Expr neg_inf_;
};
/*!
* \brief Opaque expression representing positive infinity.
*
* It can can only be used as parameter of by min/max
* for integer analysis and cannot be used in normal expressions.
*
* \return positive infinity.
*/
inline Expr pos_inf() {
return SymbolicLimits::pos_inf_;
}
/*!
* \brief Check if value is positive infinity.
* \param value The value to be checked.
*
* \return The check result.
*/
inline bool is_pos_inf(const Expr& value) {
return value.same_as(SymbolicLimits::pos_inf_);
}
/*!
* \brief Opaque expression representing negative infinity.
*
* It can can only be used as parameter of by min/max
* for integer analysis and cannot be used in normal expressions.
*
* \return negative infinity.
*/
inline Expr neg_inf() {
return SymbolicLimits::neg_inf_;
}
/*!
* \brief Check if value is negative infinity.
* \param value The value to be checked.
*
* \return The check result.
*/
inline bool is_neg_inf(const Expr& value) {
return value.same_as(SymbolicLimits::neg_inf_);
}
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_CONST_FOLD_H_
......@@ -19,8 +19,8 @@
/*!
* Copyright (c) 2017 by Contributors
* \file bound_deducer.cc
* \brief Utility to deduce bound of expression
* \file detect_linear_equation.cc
* \brief Utility to detect patterns in the expression.
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
......
......@@ -18,201 +18,55 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file int_set.cc
* \brief The integer set functions
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_functor_ext.h>
#include <arithmetic/Interval.h>
#include <tvm/api_registry.h>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include "compute_expr.h"
#include "int_set_internal.h"
#include "int_set.h"
#include "pattern_match.h"
namespace tvm {
namespace arith {
using HalideIR::Internal::Interval;
using namespace ir;
Expr SymbolicLimits::pos_inf_ = Var("pos_inf", Handle());
Expr SymbolicLimits::neg_inf_ = Var("neg_inf", Handle());
inline IntSet IntSet::cover_interval() const {
if ((*this).as<IntervalSet>()) return *this;
const StrideSet* s = (*this).as<StrideSet>();
if (s) {
CHECK_NE(s->extents.size(), 0U);
Expr max = s->base.max;
for (size_t i = 0; i < s->extents.size(); ++i) {
max = max + s->extents[i] * s->strides[i] - s->strides[i];
}
return IntervalSet::make(s->base.min, Simplify(max));
}
LOG(FATAL) << "cannot convert set " << (*this)->type_key() << " to interval";
return IntSet::everything();
IntervalSet::IntervalSet(Expr min_value, Expr max_value) {
auto node = make_node<IntervalSetNode>();
node->min_value = std::move(min_value);
node->max_value = std::move(max_value);
node_ = std::move(node);
}
Range IntSet::cover_range(Range max_range) const {
IntSet temp;
const IntervalSet* s_int = (*this).as<IntervalSet>();
if (s_int == nullptr) {
temp = this->cover_interval();
s_int = temp.as<IntervalSet>();
}
if (s_int->i.is_bounded()) {
return Range::make_by_min_extent(
s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min));
}
return max_range;
IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) {
return IntervalSet(min_value, max_value);
}
Expr IntSet::min() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
CHECK(s_int);
return s_int->i.min;
}
Expr IntSet::max() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
CHECK(s_int);
return s_int->i.max;
}
TVM_REGISTER_API("arith._make_IntervalSet")
.set_body_typed(MakeIntervalSet);
bool IntSet::is_nothing() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
return (s_int && s_int->i.is_empty());
}
bool IntSet::is_everything() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
return (s_int && s_int->i.is_everything());
}
bool IntSet::is_single_point() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
return (s_int && s_int->i.is_single_point());
}
bool IntSet::can_prove_positive() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
return (s_int && is_positive_const(ir::Simplify(s_int->i.min)));
}
bool IntSet::can_prove_negative() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
return (s_int && is_negative_const(ir::Simplify(s_int->i.max)));
}
bool IntSet::can_prove_non_positive() const {
if (const IntervalSet* s_int = (*this).as<IntervalSet>()) {
auto max = ir::Simplify(s_int->i.max);
return is_zero(max) || is_negative_const(max);
}
return false;
}
bool IntSet::can_prove_non_negative() const {
if (const IntervalSet* s_int = (*this).as<IntervalSet>()) {
// Any reason why we should or should not use can_prove() to implement
// these functions?
auto min = ir::Simplify(s_int->i.min);
return is_zero(min) || is_positive_const(min);
}
return false;
}
SignType IntSet::sign_type() const {
if (can_prove_positive()) {
return kPositive;
} else if (can_prove_negative()) {
return kNegative;
} else if (is_single_point() && is_zero(point_value())) {
return kZero;
IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
Expr max_value = min(a->max_value, b->max_value);
Expr min_value = max(a->min_value, b->min_value);
if ((max_value.type().is_int() || max_value.type().is_uint()) &&
(min_value.type().is_int() || min_value.type().is_uint()) &&
analyzer->CanProveGreaterEqual(min_value - max_value, 1)) {
return IntervalSet::Empty();
} else {
return kUnknown;
}
}
Expr IntSet::point_value() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
CHECK(s_int && s_int->i.is_single_point());
return s_int->i.min;
}
IntSet IntSet::nothing() {
return IntervalSet::make(Interval::nothing());
}
IntSet IntSet::everything() {
return IntervalSet::make(Interval::everything());
}
IntSet IntSet::single_point(Expr x) {
return IntervalSet::make(Interval::single_point(x));
}
IntSet IntSet::range(Range r) {
// must make sure it can be matched back by MatchRange.
if (is_one(r->extent)) {
return IntSet::single_point(r->min);
}
if (is_positive_const(r->extent) && is_const(r->min)) {
return IntervalSet::make(
r->min, ComputeExpr<Sub>(ComputeExpr<Add>(r->extent, r->min), 1));
}
return IntervalSet::make(r->min, (r->extent + r->min) - 1);
}
IntSet IntSet::interval(Expr min, Expr max) {
if (min.same_as(max)) {
return IntSet::single_point(min);
return IntervalSet(min_value, max_value);
}
return IntervalSet::make(min, max);
}
inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
// Check if a is created from b.
bool IntSet::match_range(const Range& b) const {
const IntSet& a = *this;
const IntervalSet* a_int = a.as<IntervalSet>();
if (!a_int) return false;
const Interval& i = a_int->i;
return prove_equal(i.min, b->min) &&
prove_equal(i.max, ComputeExpr<Sub>(ComputeExpr<Add>(b->extent, b->min), 1));
}
inline bool MatchPoint(const IntSet& a,
const Expr& b) {
const IntervalSet* a_int = a.as<IntervalSet>();
if (!a_int) return false;
const Interval& i = a_int->i;
return i.is_single_point() && i.min.same_as(b);
}
IntSet Union(const Array<IntSet>& sets) {
if (sets.size() == 0) return IntSet::nothing();
if (sets.size() == 1) return sets[0];
Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < sets.size(); ++i) {
IntSet s = sets[i].cover_interval();
const Interval& y = s.as<IntervalSet>()->i;
x.include(y);
}
x.max = ir::Simplify(x.max);
x.min = ir::Simplify(x.min);
return IntervalSet::make(x);
}
IntSet Intersect(const Array<IntSet>& sets) {
Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < sets.size(); ++i) {
Interval y = sets[i].cover_interval().as<IntervalSet>()->i;
x = Interval::make_intersection(x, y);
}
return IntervalSet::make(x);
IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
Expr max_value = max(a->max_value, b->max_value);
Expr min_value = min(a->min_value, b->min_value);
return IntervalSet(min_value, max_value);
}
// type traits
......@@ -227,407 +81,623 @@ struct is_logical_op {
static const bool value = true; \
};
// interval related.
template<typename OP>
inline IntSet CombineInterval(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<OP>(a.min, b.min));
}
LOG(WARNING) << "Return Everything in CombineInterval " << OP::_type_key;
return IntSet::everything();
TVM_DECLARE_LOGICAL_OP(And);
TVM_DECLARE_LOGICAL_OP(Or);
TVM_DECLARE_LOGICAL_OP(EQ);
TVM_DECLARE_LOGICAL_OP(NE);
TVM_DECLARE_LOGICAL_OP(GE);
TVM_DECLARE_LOGICAL_OP(GT);
TVM_DECLARE_LOGICAL_OP(LE);
TVM_DECLARE_LOGICAL_OP(LT);
TVM_DECLARE_LOGICAL_OP(Not);
/*!
* \brief Combine two interval set under arithmetic operations.
* \note this can possibly relax the set.
*/
template<typename Op>
inline IntervalSet Combine(Analyzer* analyzer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
Expr res = TryConstFold<Op>(a->min_value, b->min_value);
if (!res.defined()) res = Op::make(a->min_value, b->min_value);
return IntervalSet::SinglePoint(res);
}
if (is_logical_op<Op>::value) {
return IntervalSet(make_const(a->min_value.type(), 0),
make_const(a->min_value.type(), 1));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
if (a->IsEverything()) return a;
if (b->IsEverything()) return b;
return IntervalSet::Everything();
}
template<>
inline IntSet CombineInterval<Add>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Add>(a.min, b.min));
}
Interval r = Interval::everything();
if (a.has_lower_bound() && b.has_lower_bound()) {
r.min = ComputeExpr<Add>(a.min, b.min);
}
if (a.has_upper_bound() && b.has_upper_bound()) {
r.max = ComputeExpr<Add>(a.max, b.max);
}
return IntervalSet::make(r);
inline IntervalSet Combine<ir::Add>(Analyzer* analyer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value + b->min_value);
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
Expr min_value =
a->HasLowerBound() && b->HasLowerBound() ?
a->min_value + b->min_value : neg_inf();
Expr max_value =
a->HasUpperBound() && b->HasUpperBound() ?
a->max_value + b->max_value : pos_inf();
return IntervalSet(min_value, max_value);
}
template<>
inline IntSet CombineInterval<Sub>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Sub>(a.min, b.min));
}
Interval r = Interval::everything();
if (a.has_lower_bound() && b.has_upper_bound()) {
r.min = ComputeExpr<Sub>(a.min, b.max);
}
if (a.has_upper_bound() && b.has_lower_bound()) {
r.max = ComputeExpr<Sub>(a.max, b.min);
inline IntervalSet Combine<ir::Sub>(Analyzer* analyer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value - b->min_value);
}
return IntervalSet::make(r);
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
Expr min_value =
a->HasLowerBound() && b->HasUpperBound() ?
a->min_value - b->max_value : neg_inf();
Expr max_value =
a->HasUpperBound() && b->HasLowerBound() ?
a->max_value - b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
}
template<>
inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Mul>(a.min, b.min));
}
if (a.is_single_point() && !b.is_single_point()) {
inline IntervalSet Combine<ir::Mul>(Analyzer* analyzer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value * b->min_value);
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
if (a->IsSinglePoint()) {
std::swap(a, b);
}
if (b.is_single_point()) {
if (is_zero(b.min)) return IntSet::single_point(0);
if (is_one(b.min)) return IntervalSet::make(a);
Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min;
Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max;
// no relaxation is needed in here due to set is inclusive
// TODO(tqchen): consider convert to StrideSet.
if (is_positive_const(b.min)) {
return IntervalSet::make(e1, e2);
} else if (is_negative_const(b.min)) {
return IntervalSet::make(e2, e1);
} else if (a.is_bounded()) {
if (b->IsSinglePoint()) {
if (is_zero(b->min_value)) return b;
if (is_one(b->min_value)) return a;
if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
Expr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf();
Expr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
} else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
Expr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf();
Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::Select;
Expr cmp = b.min >= make_zero(b.min.type().element_of());
return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1));
Expr sign = b->min_value >= make_zero(b->min_value.type().element_of());
Expr e1 = a->min_value * b->min_value;
Expr e2 = a->max_value * b->min_value;
return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
}
}
LOG(WARNING) << "Return Everything in CombineInterval Mul";
return IntSet::everything();
DLOG(WARNING) << "Return Everything in CombineInterval Mul";
return IntervalSet::Everything();
}
template<>
inline IntSet CombineInterval<Div>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Div>(a.min, b.min));
}
if (b.is_single_point()) {
if (is_zero(b.min)) {
inline IntervalSet Combine<ir::Div>(Analyzer* analyzer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value / b->min_value);
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
if (b->IsSinglePoint()) {
if (is_zero(b->min_value)) {
LOG(FATAL) << "Divide by zero in CombineInterval Div";
}
if (is_one(b.min)) return IntervalSet::make(a);
Expr e1 = a.has_lower_bound() ? ComputeExpr<Div>(a.min, b.min) : a.min;
Expr e2 = a.has_upper_bound() ? ComputeExpr<Div>(a.max, b.min) : a.max;
if (is_one(b->min_value)) return a;
// no relaxation is needed in here due to set is inclusive
if (is_positive_const(b.min)) {
return IntervalSet::make(e1, e2);
} else if (is_negative_const(b.min)) {
return IntervalSet::make(e2, e1);
} else if (a.is_bounded()) {
if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
Expr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf();
Expr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
} else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
Expr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf();
Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::Select;
Expr cmp = b.min >= make_zero(b.min.type().element_of());
return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1));
Expr sign = b->min_value >= make_zero(b->min_value.type().element_of());
Expr e1 = a->min_value / b->min_value;
Expr e2 = a->max_value / b->min_value;
return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
}
}
LOG(WARNING) << "Return Everything in CombineInterval Div";
return IntSet::everything();
DLOG(WARNING) << "Return Everything in CombineInterval Div";
return IntervalSet::Everything();
}
template<>
inline IntSet CombineInterval<Mod>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Mod>(a.min, b.min));
inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value % b->min_value);
}
if (b.is_single_point()) {
Expr divisor = b.min;
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
if (b->IsSinglePoint()) {
const Expr& divisor = b->min_value;
if (is_zero(divisor)) {
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
return IntervalSet::make(make_zero(divisor.type()), divisor - 1);
// We need to add more bound constraints throughout the code.
// The logic below assumes a is non-negative, which usually
// is the case of our application.
// TODO(tqchen): add bound constraints for a.
if (analyzer->CanProveGreaterEqual(divisor, 0)) {
return IntervalSet(make_zero(divisor.type()), divisor - 1);
} else {
Expr bound = abs(divisor) - 1;
return IntervalSet(-bound, bound);
}
LOG(WARNING) << "Return Everything in CombineInterval Mod";
return IntSet::everything();
}
template<>
inline IntSet CombineInterval<Max>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Max>(a.min, b.min));
}
return IntervalSet::make(Interval::make_max(a.min, b.min),
Interval::make_max(a.max, b.max));
DLOG(WARNING) << "Return Everything in CombineInterval Mod";
return IntervalSet::Everything();
}
template<>
inline IntSet CombineInterval<Min>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Min>(a.min, b.min));
inline IntervalSet Combine<ir::Max>(Analyzer* analzyer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(max(a->min_value, b->min_value));
}
return IntervalSet::make(Interval::make_min(a.min, b.min),
Interval::make_min(a.max, b.max));
}
template<typename OP>
inline IntSet CombineInterval_(IntSet a, IntSet b) {
return CombineInterval<OP>(
a.as<IntervalSet>()->i, b.as<IntervalSet>()->i);
}
// stride related
inline IntSet AsStrideSet(IntSet a) {
if (a.as<StrideSet>()) return a;
const IntervalSet* s = a.as<IntervalSet>();
CHECK(s->i.is_bounded());
NodePtr<StrideSet> n = make_node<StrideSet>();
n->base = s->i;
return IntSet(n);
}
template<typename OP>
inline IntSet CombineSets(IntSet a, IntSet b) {
return CombineInterval_<OP>(a.cover_interval(), b.cover_interval());
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
return IntervalSet(max(a->min_value, b->min_value),
max(a->max_value, b->max_value));
}
template<>
inline IntSet CombineSets<Add>(IntSet a, IntSet b) {
const IntervalSet* a_int = a.as<IntervalSet>();
const IntervalSet* b_int = b.as<IntervalSet>();
if (a_int && is_zero(a_int->i.min)) return b;
if (b_int && is_zero(b_int->i.min)) return a;
a = AsStrideSet(a);
b = AsStrideSet(b);
const StrideSet* a_stride = a.as<StrideSet>();
const StrideSet* b_stride = b.as<StrideSet>();
auto n = make_node<StrideSet>(*a_stride);
for (size_t i = 0; i < b_stride->extents.size(); ++i) {
n->extents.push_back(b_stride->extents[i]);
n->strides.push_back(b_stride->strides[i]);
}
n->base = CombineInterval<Add>(
a_stride->base, b_stride->base).as<IntervalSet>()->i;
return IntSet(n);
}
inline IntSet NegateSet(IntSet a) {
const IntervalSet* a_int = a.as<IntervalSet>();
if (a_int) {
if (a_int->i.is_single_point()) {
return IntSet::single_point(-a_int->i.min);
} else {
Interval r = Interval::everything();
if (a_int->i.has_upper_bound()) {
r.min = -(a_int->i.max);
}
if (a_int->i.has_lower_bound()) {
r.max = -(a_int->i.min);
}
return IntervalSet::make(r);
inline IntervalSet Combine<ir::Min>(Analyzer* analzyer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
}
} else {
return NegateSet(a.cover_interval());
}
}
template<>
inline IntSet CombineSets<Sub>(IntSet a, IntSet b) {
return CombineSets<Add>(a, NegateSet(b));
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
return IntervalSet(min(a->min_value, b->min_value),
min(a->max_value, b->max_value));
}
TVM_DECLARE_LOGICAL_OP(And);
TVM_DECLARE_LOGICAL_OP(Or);
TVM_DECLARE_LOGICAL_OP(EQ);
TVM_DECLARE_LOGICAL_OP(NE);
TVM_DECLARE_LOGICAL_OP(GE);
TVM_DECLARE_LOGICAL_OP(GT);
TVM_DECLARE_LOGICAL_OP(LE);
TVM_DECLARE_LOGICAL_OP(LT);
TVM_DECLARE_LOGICAL_OP(Not);
// generic combine operations of two sets
template<typename OP>
inline IntSet Combine(const IntSet& a, const IntSet &b) {
if (is_logical_op<OP>::value) {
return IntervalSet::make(0, 1);
}
const IntervalSet* a_int = a.as<IntervalSet>();
const IntervalSet* b_int = b.as<IntervalSet>();
if (a_int && a_int->i.is_everything()) return a;
if (b_int && b_int->i.is_everything()) return b;
if (a_int && b_int) {
return CombineInterval<OP>(a_int->i, b_int->i);
}
if (a_int && !(a_int->i.is_bounded())) {
return CombineInterval_<OP>(a, b.cover_interval());
// internal helper function to get an interval set
IntervalSet ToIntervalSet(IntSet set) {
if (auto* node = set.as<IntervalSetNode>()) {
return GetRef<IntervalSet>(node);
}
if (b_int && !(b_int->i.is_bounded())) {
return CombineInterval_<OP>(a.cover_interval(), b);
}
return CombineSets<OP>(a, b);
DLOG(INFO) << "cannot resolve int set " << set;
return IntervalSet::Everything();
}
class IntSetEvaluator :
public ExprFunctor<IntSet(const Expr&, const Expr&)> {
using namespace ir;
// Simplified version of int set evaluator that operates on IntervalSet
// We might use better set analysis in the future to replace the intervalset.
class IntervalSetEvaluator :
public ExprFunctor<IntervalSet(const Expr&)> {
public:
explicit IntSetEvaluator(
const std::unordered_map<const Variable*, IntSet>& dom_map,
IntervalSetEvaluator(Analyzer* analyzer,
const Map<Var, IntSet>& dom_map,
bool eval_vec = false)
: dom_map_(dom_map), eval_vec_(eval_vec) {}
// Evaluate.
IntSet Eval(const Expr& e) {
return this->VisitExpr(e, e);
: analyzer_(analyzer),
dom_map_(dom_map),
eval_vec_(eval_vec) {
}
IntSet VisitExpr_(const IntImm* op, const Expr& e) final {
return IntSet::single_point(e);
IntervalSet Eval(const Expr& val) {
return this->VisitExpr(val);
}
IntSet VisitExpr_(const UIntImm* op, const Expr& e) final {
return IntSet::single_point(e);
IntervalSet VisitExpr_(const IntImm* op) final {
return IntervalSet::SinglePoint(GetRef<Expr>(op));
}
IntervalSet VisitExpr_(const UIntImm* op) final {
return IntervalSet::SinglePoint(GetRef<Expr>(op));
}
IntSet VisitExpr_(const Variable* op, const Expr& e) final {
auto it = dom_map_.find(op);
IntervalSet VisitExpr_(const Variable* op) final {
Var var = GetRef<Var>(op);
auto it = dom_map_.find(var);
if (it != dom_map_.end()) {
return it->second;
return ToIntervalSet((*it).second);
} else {
return IntSet::single_point(e);
return IntervalSet::SinglePoint(var);
}
}
IntSet VisitExpr_(const Add* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const Add* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const Sub* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const Sub* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const Mul* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const Mul* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const Div* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const Div* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const Mod* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const Mod* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const Min* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const Min* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const Max* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const Max* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const EQ* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const EQ* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const NE* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const NE* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const LT* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const LT* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const LE* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const LE* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const GT* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const GT* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const GE* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const GE* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const And* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const And* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const Or* op, const Expr& e) final {
return Binary(op, e);
IntervalSet VisitExpr_(const Or* op) final {
return VisitBinaryExpr_(op);
}
IntSet VisitExpr_(const Ramp* op, const Expr& e) final {
IntervalSet VisitExpr_(const Ramp* op) final {
CHECK(eval_vec_);
IntSet base = Eval(op->base);
int vstride;
if (GetConstInt(op->stride, &vstride)) {
IntervalSet base = Eval(op->base);
PVar<Integer> stride;
if (stride.Match(op->stride)) {
Type t = op->base.type();
if (vstride > 0) {
int64_t vstride = stride.Eval()->value;
if (vstride> 0) {
return Combine<Add>(
analyzer_,
base,
IntSet::interval(make_zero(t),
make_const(t, vstride * op->lanes -1)));
IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1)));
} else {
return Combine<Add>(
analyzer_,
base,
IntSet::interval(make_const(t, vstride * op->lanes + 1),
make_zero(t)));
IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)));
}
}
LOG(WARNING) << "cannot evaluate set on expression " << e;
return IntSet::everything();
DLOG(WARNING) << "cannot evaluate set on expression " << GetRef<Expr>(op);
return IntervalSet::Everything();
}
IntSet VisitExpr_(const Broadcast* op, const Expr& e) final {
IntervalSet VisitExpr_(const Broadcast* op) final {
CHECK(eval_vec_);
return Eval(op->value);
return VisitExpr(op->value);
}
IntSet VisitExpr_(const Select* op, const Expr& e) final {
IntSet true_set = this->Eval(op->true_value);
IntSet false_set = this->Eval(op->false_value);
return Union({false_set, true_set});
IntervalSet VisitExpr_(const Select* op) final {
IntervalSet true_set = this->Eval(op->true_value);
IntervalSet false_set = this->Eval(op->false_value);
return Union(analyzer_, false_set, true_set);
}
IntSet VisitExprDefault_(const Node* op, const Expr& e) final {
LOG(WARNING) << "cannot evaluate set type " << e->type_key();
return IntSet::everything();
IntervalSet VisitExprDefault_(const Node* op) final {
DLOG(WARNING) << "cannot evaluate set type " << op->type_key();
return IntervalSet::Everything();
}
private:
// whether set is exactly single point that equals value.
bool MatchPoint(const IntervalSet& set,
const Expr& value) const {
return set->min_value.same_as(value) && set->max_value.same_as(value);
}
template<typename T>
inline IntSet Binary(const T* op, const Expr& e) {
IntSet a = this->Eval(op->a);
IntSet b = this->Eval(op->b);
inline IntervalSet VisitBinaryExpr_(const T* op) {
IntervalSet a = this->Eval(op->a);
IntervalSet b = this->Eval(op->b);
if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
return IntSet::single_point(e);
return IntervalSet::SinglePoint(GetRef<Expr>(op));
}
return Combine<T>(a, b);
return Combine<T>(analyzer_, a, b);
}
const std::unordered_map<const Variable*, IntSet>& dom_map_;
Analyzer* analyzer_;
const Map<Var, IntSet>& dom_map_;
bool eval_vec_{false};
};
IntSet EvalSet(Expr e,
class IntSetAnalyzer::Impl {
public:
explicit Impl(Analyzer* analyzer)
: analyzer_(analyzer) {
}
IntSet Eval(const Expr& expr, const Map<Var, IntSet>& dom_map) const {
return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr);
}
private:
Analyzer* analyzer_;
};
IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent)
: impl_(new Impl(parent)) {
}
IntSetAnalyzer::~IntSetAnalyzer() {
delete impl_;
}
IntSet IntSetAnalyzer::operator()(const Expr& expr,
const Map<Var, IntSet>& dom_map) {
return impl_->Eval(expr, dom_map);
}
// Quickly adapt to IntSet interface
// TODO(tqchen): revisit IntSet interface as well.
Range IntSet::cover_range(Range max_range) const {
IntSet temp;
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
CHECK(s_int != nullptr);
if (s_int->HasUpperBound() && s_int->HasLowerBound()) {
return Range::make_by_min_extent(
s_int->min_value, Simplify(s_int->max_value + 1 - s_int->min_value));
}
return max_range;
}
Expr IntSet::min() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
CHECK(s_int);
return s_int->min_value;
}
Expr IntSet::max() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
CHECK(s_int);
return s_int->max_value;
}
bool IntSet::is_nothing() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
return (s_int && s_int->IsEmpty());
}
bool IntSet::is_everything() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
return (s_int && s_int->IsEverything());
}
bool IntSet::is_single_point() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
return (s_int && s_int->IsSinglePoint());
}
bool IntSet::can_prove_positive() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
return (s_int && is_positive_const(ir::Simplify(s_int->min_value)));
}
bool IntSet::can_prove_negative() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
return (s_int && is_negative_const(ir::Simplify(s_int->max_value)));
}
bool IntSet::can_prove_non_positive() const {
if (const auto* s_int = (*this).as<IntervalSetNode>()) {
auto max = ir::Simplify(s_int->max_value);
return is_zero(max) || is_negative_const(max);
}
return false;
}
bool IntSet::can_prove_non_negative() const {
if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) {
auto min = ir::Simplify(s_int->min_value);
return is_zero(min) || is_positive_const(min);
}
return false;
}
SignType IntSet::sign_type() const {
if (can_prove_positive()) {
return kPositive;
} else if (can_prove_negative()) {
return kNegative;
} else if (is_single_point() && is_zero(point_value())) {
return kZero;
} else {
return kUnknown;
}
}
Expr IntSet::point_value() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
CHECK(s_int && s_int->IsSinglePoint());
return s_int->min_value;
}
IntSet IntSet::nothing() {
return IntervalSet::Empty();
}
IntSet IntSet::everything() {
return IntervalSet::Everything();
}
IntSet IntSet::single_point(Expr x) {
return IntervalSet::SinglePoint(x);
}
IntSet IntSet::interval(Expr min, Expr max) {
if (min.same_as(max)) {
return IntSet::single_point(min);
}
return IntervalSet(min, max);
}
// Range related code
inline bool ProveEqual(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
IntSet IntSet::range(Range r) {
// must make sure it can be matched back by MatchRange.
if (is_one(r->extent)) {
return IntSet::single_point(r->min);
}
return IntervalSet(r->min, r->extent + r->min - 1);
}
bool IntSet::match_range(const Range& b) const {
const IntSet& a = *this;
const IntervalSetNode* a_int = a.as<IntervalSetNode>();
if (!a_int) return false;
return ProveEqual(a_int->min_value, b->min) &&
ProveEqual(a_int->max_value, b->extent + b->min - 1);
}
IntSet Union(const Array<IntSet>& sets) {
if (sets.size() == 0) return IntSet::nothing();
if (sets.size() == 1) return sets[0];
Analyzer ana;
IntervalSet x = ToIntervalSet(sets[0]);
for (size_t i = 1; i < sets.size(); ++i) {
x = Union(&ana, x, ToIntervalSet(sets[i]));
}
return IntervalSet(ir::Simplify(x->min_value),
ir::Simplify(x->max_value));
}
IntSet Intersect(const Array<IntSet>& sets) {
if (sets.size() == 0) return IntSet::nothing();
if (sets.size() == 1) return sets[0];
Analyzer ana;
IntervalSet x = ToIntervalSet(sets[0]);
for (size_t i = 1; i < sets.size(); ++i) {
x = Intersect(&ana, x, ToIntervalSet(sets[i]));
}
return IntervalSet(ir::Simplify(x->min_value),
ir::Simplify(x->max_value));
}
Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) {
Map<Var, IntSet> dmap;
for (auto kv : dom_map) {
dmap.Set(kv.first->var, kv.second);
}
return dmap;
}
Map<Var, IntSet> ConvertDomMap(
const std::unordered_map<const Variable*, IntSet>& dom_map) {
return IntSetEvaluator(dom_map, false).Eval(e);
Map<Var, IntSet> dmap;
for (auto kv : dom_map) {
dmap.Set(GetRef<Var>(kv.first), kv.second);
}
return dmap;
}
IntSet EvalSet(Expr e,
const Map<Var, IntSet>& dom_map) {
Analyzer ana;
return IntervalSetEvaluator(&ana, dom_map, false).Eval(e);
}
IntSet IntSet::vector(Expr x) {
std::unordered_map<const Variable*, IntSet> dmap;
return IntSetEvaluator(dmap, true).Eval(x);
Analyzer ana;
Map<Var, IntSet> dmap;
return IntervalSetEvaluator(&ana, dmap, true).Eval(x);
}
IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
dmap[kv.first->var.as<Variable>()] = kv.second;
}
return EvalSet(e, dmap);
return EvalSet(e, ConvertDomMap(dom_map));
}
IntSet EvalSet(Range r,
IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
IntSetEvaluator m(dom_map);
IntSet min_set = m.Eval(r->min).cover_interval();
return EvalSet(e, ConvertDomMap(dom_map));
}
IntSet EvalSet(Range r,
const Map<Var, IntSet>& dom_map) {
Analyzer ana;
IntervalSetEvaluator m(&ana, dom_map);
IntervalSet min_set = m.Eval(r->min);
// Simplifying first can give tighter bounds if r->min and r->extent share variables
Expr sum = ComputeExpr<Sub>(ComputeExpr<Add>(r->min, r->extent), 1);
IntSet max_set = m.Eval(Simplify(sum)).cover_interval();
const Interval& ni = min_set.as<IntervalSet>()->i;
const Interval& xi = max_set.as<IntervalSet>()->i;
if (!ni.has_lower_bound()) return IntSet::everything();
if (!xi.has_upper_bound()) return IntSet::everything();
return IntervalSet::make(ni.min, xi.max);
Expr sum = r->min + r->extent - 1;
IntervalSet max_set = m.Eval(Simplify(sum));
if (!min_set->HasLowerBound()) return IntSet::everything();
if (!max_set->HasUpperBound()) return IntSet::everything();
return IntervalSet(min_set->min_value, max_set->max_value);
}
IntSet EvalSet(IntSet s,
IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
IntSetEvaluator m(dom_map);
s = s.cover_interval();
const IntervalSet* s_int = s.as<IntervalSet>();
Expr vmax = s_int->i.has_upper_bound() ?
m.Eval(s_int->i.max).cover_interval().max() : s_int->i.max;
Expr vmin = s_int->i.has_lower_bound() ?
m.Eval(s_int->i.min).cover_interval().min() : s_int->i.min;
return IntervalSet::make(vmin, vmax);
return EvalSet(r, ConvertDomMap(dom_map));
}
class SubExprIntSetEvaluator : public IntSetEvaluator {
IntSet EvalSet(IntSet s,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
Analyzer ana;
auto dmap = ConvertDomMap(dom_map);
IntervalSetEvaluator m(&ana, dmap);
const IntervalSetNode* s_int = s.as<IntervalSetNode>();
Expr vmax = s_int->HasUpperBound() ?
m.Eval(s_int->max_value).max() : s_int->max_value;
Expr vmin = s_int->HasLowerBound() ?
m.Eval(s_int->min_value).min() : s_int->min_value;
return IntervalSet(vmin, vmax);
}
class SubExprIntervalSetEvaluator : public IntervalSetEvaluator {
public:
explicit SubExprIntSetEvaluator(
const std::unordered_map<const Variable*, IntSet>& dom_map)
: IntSetEvaluator(dom_map) {}
explicit SubExprIntervalSetEvaluator(
Analyzer* analyzer,
const Map<Var, IntSet>& dom_map)
: IntervalSetEvaluator(analyzer, dom_map) {}
IntSet VisitExpr(const Expr& n, const Expr& e) final {
IntSet ret = IntSetEvaluator::VisitExpr(n, e);
IntervalSet VisitExpr(const Expr& n) final {
IntervalSet ret = IntervalSetEvaluator::VisitExpr(n);
expr_map[n] = ret;
return ret;
}
......@@ -635,28 +705,26 @@ class SubExprIntSetEvaluator : public IntSetEvaluator {
ExprIntSetMap expr_map;
};
ExprIntSetMap EvalSetForEachSubExpr(Expr e,
ExprIntSetMap EvalSetForEachSubExpr(
Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
SubExprIntSetEvaluator m(dom_map);
Analyzer ana;
auto dmap = ConvertDomMap(dom_map);
SubExprIntervalSetEvaluator m(&ana, dmap);
m.Eval(e);
return m.expr_map;
}
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
dmap[kv.first->var.as<Variable>()] = kv.second;
}
return EvalSet(r, dmap);
return EvalSet(r, ConvertDomMap(dom_map));
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
p->stream << "interval-set"
<< "[" << op->i.min << ", "
<< op->i.max << ']';
.set_dispatch<IntervalSetNode>([](const IntervalSetNode *op, IRPrinter *p) {
p->stream << "IntervalSet"
<< "[" << op->min_value << ", "
<< op->max_value << ']';
});
} // namespace arith
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file int_set.h
* \brief Internal data structure for integer set.
*/
#ifndef TVM_ARITHMETIC_INT_SET_H_
#define TVM_ARITHMETIC_INT_SET_H_
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <limits>
#include "const_fold.h"
namespace tvm {
namespace arith {
/*!
* \brief Symbolic interval set.
*
* \note We intentionally keep the internal of IntSet private,
as we might change it later.
*/
class IntervalSetNode : public IntSetNode {
public:
/*! \brief Minimum value in the interval. */
Expr min_value;
/*! \brief Maximum value in the interval. */
Expr max_value;
// visitor overload.
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("min_value", &min_value);
v->Visit("max_value", &max_value);
}
/*! \return Whether the interval has upper bound. */
bool HasUpperBound() const {
return !is_pos_inf(max_value) && !IsEmpty();
}
/*! \return Whether the interval has lower bound. */
bool HasLowerBound() const {
return !is_neg_inf(min_value) && !IsEmpty();
}
/*! \return Whether the interval is a single point. */
bool IsSinglePoint() const {
return min_value.same_as(max_value);
}
/*! \return whether interval represent nothing */
bool IsEmpty() const {
// during computations, either extreme could occur.
return is_pos_inf(min_value) || is_neg_inf(max_value);
}
/*! \return whether interval represent everything */
bool IsEverything() const {
return is_neg_inf(min_value) && is_pos_inf(max_value);
}
static constexpr const char* _type_key = "arith.IntervalSet";
TVM_DECLARE_NODE_TYPE_INFO(IntervalSetNode, IntSetNode);
};
/*!
* \brief Interval set used for symbolic integer analysis.
* \sa IntervalSetNode
*/
class IntervalSet : public IntSet {
public:
/*!
* \brief Make a new instance of interval set.
* \param min_value The minimum value in the interval.
* \param max_value The maximum value in the interval.
* \return The created set.
*/
TVM_DLL IntervalSet(Expr min_value, Expr max_value);
/*!
* \brief Create an IntervalSet that represents a single point.
* \param value The value to be represented.
* \return The result set.
*/
static IntervalSet SinglePoint(Expr value) {
return IntervalSet(value, value);
}
/*!
* \brief Create an IntervalSet that represents everything.
* \param value The value to be represented.
* \return The result set.
*/
static IntervalSet Everything() {
return IntervalSet(neg_inf(), pos_inf());
}
/*!
* \brief Create an empty eet.
* \return The result set.
*/
static IntervalSet Empty() {
return IntervalSet(pos_inf(), neg_inf());
}
TVM_DEFINE_NODE_REF_COW(IntervalSetNode);
TVM_DEFINE_NODE_REF_METHODS(IntervalSet, IntSet, IntervalSetNode);
};
/*!
* \brief Create union of two IntervalSets.
* \param analyzer The analyzer for simplification analysis.
* \param a The first set.
* \param b The second set.
* \return The result set.
*/
TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b);
/*!
* \brief Create insersection of two IntervalSets.
* \param analzyer The analyzer for simplification analysis.
* \param a The first set.
* \param b The second set.
* \return The result set.
*/
TVM_DLL IntervalSet Intersect(Analyzer *analzyer, IntervalSet a, IntervalSet b);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_INT_SET_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2017 by Contributors
* \file int_set_internal.h
* \brief Implementations of integer set
*/
#ifndef TVM_ARITHMETIC_INT_SET_INTERNAL_H_
#define TVM_ARITHMETIC_INT_SET_INTERNAL_H_
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
namespace tvm {
namespace arith {
using HalideIR::Internal::Interval;
/*! \brief Set of continuous interval */
struct IntervalSet : public IntSetNode {
/*! \brief the internal interval*/
Interval i;
static IntSet make(Interval i) {
NodePtr<IntervalSet> n =
make_node<IntervalSet>();
n->i = i;
return IntSet(n);
}
static IntSet make(Expr min, Expr max) {
NodePtr<IntervalSet> n =
make_node<IntervalSet>();
n->i.min = min;
n->i.max = max;
return IntSet(n);
}
static constexpr const char* _type_key = "IntervalSet";
TVM_DECLARE_NODE_TYPE_INFO(IntervalSet, IntSetNode);
};
/*!
* \brief set represented by strided integers
* Reserved for cases where strided access is supported.
*/
struct StrideSet : public IntSetNode {
/*! \brief the base inetrval */
Interval base;
/*! \brief additional extents in positive number */
Array<Expr> extents;
/*! \brief additional strides in positive number */
Array<Expr> strides;
static constexpr const char* _type_key = "StrideSet";
TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_INT_SET_INTERNAL_H_
......@@ -188,7 +188,15 @@ Expr operator%(Expr a, Expr b) {
return ir::Mod::make(a, b);
}
Expr min(Expr a, Expr b) {
// inf-aware simplificaiton
using arith::is_pos_inf;
using arith::is_neg_inf;
if (is_pos_inf(a)) return b;
if (is_neg_inf(a)) return a;
if (is_pos_inf(b)) return a;
if (is_neg_inf(b)) return b;
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::Min>(a, b);
if (ret.defined()) return ret;
......@@ -196,6 +204,13 @@ Expr min(Expr a, Expr b) {
}
Expr max(Expr a, Expr b) {
// inf-aware simplificaiton
using arith::is_pos_inf;
using arith::is_neg_inf;
if (is_pos_inf(a)) return a;
if (is_neg_inf(a)) return b;
if (is_pos_inf(b)) return b;
if (is_neg_inf(b)) return a;
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::Max>(a, b);
if (ret.defined()) return ret;
......
......@@ -28,7 +28,7 @@
#include <tvm/arithmetic.h>
#include <unordered_map>
#include <unordered_set>
#include "../arithmetic/int_set_internal.h"
#include "../arithmetic/int_set.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
......@@ -366,7 +366,7 @@ class LoopPartitioner : public IRMutator {
std::pair<IntSet, std::unordered_set<const Node*>>
GetIntervalAndCondset(const Partition &partitions,
const arith::Interval &for_interval,
const arith::IntervalSet &for_interval,
bool cond_value);
inline Stmt MakeFor(const Node* op, Expr extent, Stmt body);
......@@ -374,6 +374,7 @@ class LoopPartitioner : public IRMutator {
/* Candidate IRs that may be partitioned potentially */
std::unordered_map<const Variable*, IntSet> hint_map_;
std::unordered_map<const Variable*, IntSet> relax_map_;
arith::Analyzer analyzer_;
CandidateSelector selector;
};
......@@ -381,16 +382,17 @@ class LoopPartitioner : public IRMutator {
// given in the second component provably have value given by cond_value
std::pair<IntSet, std::unordered_set<const Node*>>
LoopPartitioner::GetIntervalAndCondset(const Partition &partitions,
const arith::Interval &for_interval,
const arith::IntervalSet &for_interval,
bool cond_value) {
Array<IntSet> sets;
std::unordered_set<const Node*> cond_set;
for (const auto &kv : partitions) {
if (kv.first.second == cond_value) {
arith::Interval interval = kv.second.as<arith::IntervalSet>()->i;
arith::Interval intersection = arith::Interval::make_intersection(interval, for_interval);
if (!intersection.is_empty()) {
arith::IntervalSet interval = Downcast<arith::IntervalSet>(kv.second);
arith::IntervalSet intersection = arith::Intersect(
&analyzer_, interval, for_interval);
if (!intersection->IsEmpty()) {
sets.push_back(kv.second);
cond_set.insert(kv.first.first);
}
......@@ -463,11 +465,12 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr max,
Stmt body,
bool partition_thread_scope) {
using namespace arith;
PartitionFinder finder(var, hint_map_, relax_map_);
finder.Visit(body);
if (finder.partitions.empty()) return Stmt();
arith::Interval for_interval(min, max);
arith::IntervalSet for_interval(min, max);
bool cond_value;
IntSet middle_interval;
std::unordered_set<const Node*> cond_set;
......@@ -488,7 +491,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
cond_value = true;
}
arith::Interval middle_interval_i = middle_interval.as<arith::IntervalSet>()->i;
IntervalSet middle_interval_i = Downcast<IntervalSet>(middle_interval);
// middle_interval is the subrange of the loop variable range for which a
// set of conditions are true (or false resp.)
// The part of the loop variable range that is before (after resp.) that
......@@ -499,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr body_begin;
Stmt pre_stmt;
bool pre_stmt_recurse = true;
if (middle_interval_i.has_lower_bound()) {
if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min());
if (!can_prove(body_begin == min)) {
Expr cond = (body_begin - min >= 0);
......@@ -524,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr post_doubt_begin;
Stmt post_stmt;
bool post_stmt_recurse = true;
if (middle_interval_i.has_upper_bound()) {
if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
if (!can_prove(middle_interval.max() == max)) {
// require the extent to be non-negative
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
def test_deduce():
a = tvm.var('a')
b = tvm.var('b')
c = tvm.var('c')
d = tvm.var('d')
b_s = tvm.arith.IntervalSet(2, 3)
c_s = tvm.arith.IntervalSet(10, 15)
d_s = tvm.arith.IntervalSet(-3, -1)
zero = tvm.const(0, "int32")
e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((d - c) /(b*-1))
assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
# expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
e0 = d*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((0-c)/d + 1)
assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
# expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = (((c - b) + -1)/4)
assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1)
# expression containing variable a is on rhs
e1 = (c > a*4+b)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1)
e2 = (tvm.max(5, a * 4) < 0)
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max_value) == "neg_inf"
assert str(res2.min_value) == "pos_inf"
# expression containing variable a is on rhs
e2 = (zero < tvm.max(5, a * 4))
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max_value) == "neg_inf"
assert str(res2.min_value) == "pos_inf"
e3 = (-b)+a*c-d
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = 2/c+1
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
def test_check():
a = tvm.var('a')
b = tvm.var('b')
c = tvm.var('c')
d = tvm.var('d')
b_s = tvm.arith.IntervalSet(2, 3)
c_s = tvm.arith.IntervalSet(5, 7)
d_s = tvm.arith.IntervalSet(-3, -1)
# no compare operator
res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
assert res1.is_nothing()
# multiple compare operators
res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
assert res2.is_nothing()
# multiple target variable
res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
assert res2.is_nothing()
def test_deduce_basic():
def test_basic(a1, a2, coff):
a = tvm.var('a')
b = tvm.var('b')
b_s = tvm.arith.IntervalSet(a1, a2)
e0 = b + a*coff + 3
res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
test_basic(0, 4, 4)
test_basic(1, 5, 4)
test_basic(2, 6, 4)
test_basic(0, 4, -4)
test_basic(1, 5, -4)
test_basic(2, 6, -4)
def test_deduce_complex():
def test_complex(a1, a2, coff):
a = tvm.var('a')
b = tvm.var('b')
b_s = tvm.arith.IntervalSet(a1, a2)
e0 = (b*3 + a* coff) * 4
res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
test_complex(0, 4, 4)
test_complex(0, 4, -4)
test_complex(2, 6, 4)
test_complex(0, 4, -4)
test_complex(1, 5, -4)
test_complex(2, 6, -4)
if __name__ == "__main__":
test_check()
test_deduce_basic()
test_deduce_complex()
......@@ -16,168 +16,87 @@
# under the License.
import tvm
class IntSetChecker:
def __init__(self):
self.analyzer = tvm.arith.Analyzer()
def verify(self, data, dmap, expected):
res = self.analyzer.int_set(data, dmap)
def err_msg():
return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected)
def equal(x, y):
res = self.analyzer.canonical_simplify(x - y)
return tvm.ir_pass.Equal(res, 0)
assert equal(res.min_value, expected[0]), err_msg()
assert equal(res.max_value, expected[1]), err_msg()
def test_basic():
s = tvm.arith.intset_interval(2, 3)
assert s.min().value == 2
assert s.max().value == 3
s = tvm.arith.IntervalSet(2, 3)
assert s.min_value.value == 2
assert s.max_value.value == 3
def test_vector():
base = 10
stride = 3
lanes = 2
s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes))
assert s.min().value == base
assert s.max().value == base + stride * lanes - 1
def test_deduce():
a = tvm.var('a')
b = tvm.var('b')
c = tvm.var('c')
d = tvm.var('d')
b_s = tvm.arith.intset_interval(2, 3)
c_s = tvm.arith.intset_interval(10, 15)
d_s = tvm.arith.intset_interval(-3, -1)
zero = tvm.const(0, "int32")
e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((d - c) /(b*-1))
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
# expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
e0 = d*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((0-c)/d + 1)
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
# expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = (((c - b) + -1)/4)
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
# expression containing variable a is on rhs
e1 = (c > a*4+b)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
e2 = (tvm.max(5, a * 4) < 0)
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max()) == "neg_inf"
assert str(res2.min()) == "pos_inf"
# expression containing variable a is on rhs
e2 = (zero < tvm.max(5, a * 4))
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max()) == "neg_inf"
assert str(res2.min()) == "pos_inf"
e3 = (-b)+a*c-d
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = 2/c+1
assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)
def test_check():
a = tvm.var('a')
b = tvm.var('b')
c = tvm.var('c')
d = tvm.var('d')
b_s = tvm.arith.intset_interval(2, 3)
c_s = tvm.arith.intset_interval(5, 7)
d_s = tvm.arith.intset_interval(-3, -1)
# no compare operator
res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
assert res1.is_nothing()
# multiple compare operators
res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
assert res2.is_nothing()
# multiple target variable
res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
assert res2.is_nothing()
def test_deduce_basic():
def test_basic(a1, a2, coff):
a = tvm.var('a')
b = tvm.var('b')
b_s = tvm.arith.intset_interval(a1, a2)
e0 = b + a*coff + 3
res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
test_basic(0, 4, 4)
test_basic(1, 5, 4)
test_basic(2, 6, 4)
test_basic(0, 4, -4)
test_basic(1, 5, -4)
test_basic(2, 6, -4)
def test_deduce_complex():
def test_complex(a1, a2, coff):
a = tvm.var('a')
b = tvm.var('b')
b_s = tvm.arith.intset_interval(a1, a2)
e0 = (b*3 + a* coff) * 4
res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
test_complex(0, 4, 4)
test_complex(0, 4, -4)
test_complex(2, 6, 4)
test_complex(0, 4, -4)
test_complex(1, 5, -4)
test_complex(2, 6, -4)
assert s.min_value.value == base
assert s.max_value.value == base + stride * lanes - 1
def test_add_sub():
ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y")
ck.verify(x + y, {x : tvm.arith.IntervalSet(0, 10)}, (y, 10 + y))
ck.verify(x + y,
{x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)},
(1, 21))
ck.verify(x - y,
{x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)},
(-11, 9))
def test_mul_div():
ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y")
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20))
ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2))
ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y))
ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
def test_mod():
ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y")
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
def test_max_min():
ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y")
ck.verify(tvm.max(x, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (1, 11))
ck.verify(tvm.min(x - 1, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (-1, 9))
ck.verify(tvm.min(x, y), {}, (tvm.min(x, y), tvm.min(x, y)))
ck.verify(tvm.max(x, y), {}, (tvm.max(x, y), tvm.max(x, y)))
def test_select():
ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y")
ck.verify(tvm.expr.Select(x > 0, x - 1, x + 1),
{x : tvm.arith.IntervalSet(0, 10)}, (-1, 11))
if __name__ == "__main__":
test_basic()
test_vector()
test_deduce()
test_check()
test_deduce_basic()
test_deduce_complex()
test_add_sub()
test_mul_div()
test_max_min()
test_select()
test_mod()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment