Unverified Commit 153417a5 by Tianqi Chen Committed by GitHub

[ARITH] Revamp IntSet (#3272)

parent 9bb16872
...@@ -328,71 +328,14 @@ class ConstraintContext { ...@@ -328,71 +328,14 @@ class ConstraintContext {
std::function<void()> exit_; std::function<void()> exit_;
}; };
/*!
* \brief Analyzer that contains bunch of sub-analyzers.
*
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overrideen.
*/
class Analyzer {
public:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplify */
RewriteSimplifier rewrite_simplify;
/*! \brief sub-analyzer canonical simplify */
CanonicalSimplifier canonical_simplify;
/*! \brief constructor */
Analyzer();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const VarExpr& var, const Expr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const VarExpr& var, const Range& range);
/*!
* \brief Whether can we proof expr >= val.
* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param lower_bound The lower bound.
* \return Whether we can proof it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
};
//----------------------------------------------- //-----------------------------------------------
// Integer set abstraction API. // Integer set data structure.
// //
// This is a API build on top of the base // This is a API build on top of the base
// integer analysis API to provide set analysis. // integer analysis API to provide set analysis.
//------------------------------------------------ //------------------------------------------------
/*! /*!
* \brief Sign of an expression or set. * \brief Sign type of an integer expression.
*/ */
enum SignType { enum SignType {
kPositive, kPositive,
...@@ -401,8 +344,13 @@ enum SignType { ...@@ -401,8 +344,13 @@ enum SignType {
kUnknown kUnknown
}; };
// internal node container of int set. /*!
struct IntSetNode; * \brief Base class of all IntSet containers.
*/
struct IntSetNode : public Node {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
};
/*! /*!
* \brief Integer set class, represent a set of integers in one dimension. * \brief Integer set class, represent a set of integers in one dimension.
...@@ -424,11 +372,6 @@ class IntSet : public NodeRef { ...@@ -424,11 +372,6 @@ class IntSet : public NodeRef {
* \return The covering range. * \return The covering range.
*/ */
Range cover_range(Range max_range) const; Range cover_range(Range max_range) const;
/*!
* \brief find an interval that covers the set.
* \return The covering interval set.
*/
IntSet cover_interval() const;
/*! \return Lower bound of the set */ /*! \return Lower bound of the set */
Expr min() const; Expr min() const;
/*! \return upper bound of the set */ /*! \return upper bound of the set */
...@@ -493,33 +436,91 @@ class IntSet : public NodeRef { ...@@ -493,33 +436,91 @@ class IntSet : public NodeRef {
}; };
/*! /*!
* \brief Base class of all IntSet containers. * \brief Integer set analyzer.
*/ */
struct IntSetNode : public Node { class IntSetAnalyzer {
static constexpr const char* _type_key = "IntSet"; public:
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); /*!
* \brief Find a symbolic integer set that contains all possible values of
* expr given the domain of each variables.
*
* \param expr The expression of interest.
* \param dom_map The domain map to indicate which variable to relax.
* \return the result of the analysis.
*/
IntSet operator()(const Expr& expr, const Map<Var, IntSet>& dom_map);
private:
friend class Analyzer;
explicit IntSetAnalyzer(Analyzer* parent);
~IntSetAnalyzer();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
}; };
/*! /*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n] * \brief Analyzer that contains bunch of sub-analyzers.
* Where coeff[i] and base are invariant of var[j] for all i and j.
* *
* \param e The expression to be detected. * Each sub-analyzer can make use of another sub-analyzer
* \param vars List of variables to be used in detection. * by weak reference of this.
* \return [coeff[i]] if it is possible, empty array if it is not. *
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overridden.
*/
class Analyzer {
public:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplify */
RewriteSimplifier rewrite_simplify;
/*! \brief sub-analyzer canonical simplify */
CanonicalSimplifier canonical_simplify;
/*! \brief sub-analyzer: int set */
IntSetAnalyzer int_set;
/*! \brief constructor */
Analyzer();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const VarExpr& var, const Expr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/ */
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars); void Bind(const VarExpr& var, const Range& range);
/*!
* \brief Whether can we prove expr >= val.
/*! * Non-negative proof is very useful in integer analysis
* \brief Detect if expression corresponds to clip bound of the vars * to lower divisions and mods given difference in trunc and ceil mode.
* *
* \param e The expression to be detected. * \param expr The expression.
* \param vars List of variables to be used in detection. * \param lower_bound The lower bound.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value * \return Whether we can prove it.
* return empty if the e does not match the pattern. *
* \note Analyzer will call into sub-analyzers to get the result.
*/ */
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars); bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
};
//-----------------------------------------------
// Integer set legacy API.
//------------------------------------------------
/*! /*!
* \brief Find an symbolic integer set that contains all possible values of * \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables. * e given the domain of each iteration variables.
...@@ -638,6 +639,29 @@ IntSet DeduceBound(Expr v, Expr cond, ...@@ -638,6 +639,29 @@ IntSet DeduceBound(Expr v, Expr cond,
*/ */
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides); Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);
// Expression pattern detector.
/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff[i] and base are invariant of var[j] for all i and j.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array<Expr> DetectLinearEquation(const Expr& e,
const Array<Var>& vars);
/*!
* \brief Detect if expression corresponds to clip bound of the vars
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
Array<Expr> DetectClipBound(const Expr& e,
const Array<Var>& vars);
// implementation // implementation
inline const IntSetNode* IntSet::operator->() const { inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get()); return static_cast<const IntSetNode*>(node_.get());
......
...@@ -32,21 +32,21 @@ class IntSet(NodeBase): ...@@ -32,21 +32,21 @@ class IntSet(NodeBase):
return _api_internal._IntSetIsEverything(self) return _api_internal._IntSetIsEverything(self)
@register_node @register_node("arith.IntervalSet")
class IntervalSet(IntSet): class IntervalSet(IntSet):
"""Represent set of continuous interval""" """Represent set of continuous interval [min_value, max_value]
def min(self):
"""get the minimum value"""
return _api_internal._IntervalSetGetMin(self)
def max(self):
"""get the maximum value"""
return _api_internal._IntervalSetGetMax(self)
Parameters
----------
min_value : Expr
The minimum value in the interval.
@register_node max_value : Expr
class StrideSet(IntSet): The maximum value in the interval.
"""Represent set of strided integers""" """
def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(
_make_IntervalSet, min_value, max_value)
@register_node("arith.ModularSet") @register_node("arith.ModularSet")
...@@ -114,6 +114,7 @@ class Analyzer: ...@@ -114,6 +114,7 @@ class Analyzer:
self._modular_set = _mod("modular_set") self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify") self._rewrite_simplify = _mod("rewrite_simplify")
self._canonical_simplify = _mod("canonical_simplify") self._canonical_simplify = _mod("canonical_simplify")
self._int_set = _mod("int_set")
self._enter_constraint_context = _mod("enter_constraint_context") self._enter_constraint_context = _mod("enter_constraint_context")
def const_int_bound(self, expr): def const_int_bound(self, expr):
...@@ -176,6 +177,24 @@ class Analyzer: ...@@ -176,6 +177,24 @@ class Analyzer:
""" """
return self._canonical_simplify(expr) return self._canonical_simplify(expr)
def int_set(self, expr, dom_map):
"""Compute a symbolic IntSet that covers expr for all values in dom_map.
Parameters
----------
expr : tvm.Expr
The expression.
dom_map : Dict[Var, tvm.arith.IntSet]
The domain for variables to be relaxed.
Returns
-------
result : IntSet
The result.
"""
return self._int_set(expr, dom_map)
def bind(self, var, expr): def bind(self, var, expr):
"""Bind a variable to the expression. """Bind a variable to the expression.
......
...@@ -39,6 +39,7 @@ TVM_REGISTER_API("arith.intset_vector") ...@@ -39,6 +39,7 @@ TVM_REGISTER_API("arith.intset_vector")
TVM_REGISTER_API("arith.intset_interval") TVM_REGISTER_API("arith.intset_interval")
.set_body_typed(IntSet::interval); .set_body_typed(IntSet::interval);
TVM_REGISTER_API("arith.DetectLinearEquation") TVM_REGISTER_API("arith.DetectLinearEquation")
.set_body_typed(DetectLinearEquation); .set_body_typed(DetectLinearEquation);
...@@ -110,6 +111,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer") ...@@ -110,6 +111,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->canonical_simplify(args[0]); *ret = self->canonical_simplify(args[0]);
}); });
} else if (name == "int_set") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->int_set(args[0], args[1]);
});
} else if (name == "bind") { } else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr(); auto& sptr = args[1].node_sptr();
......
...@@ -31,7 +31,8 @@ Analyzer::Analyzer() ...@@ -31,7 +31,8 @@ Analyzer::Analyzer()
: const_int_bound(this), : const_int_bound(this),
modular_set(this), modular_set(this),
rewrite_simplify(this), rewrite_simplify(this),
canonical_simplify(this) { canonical_simplify(this),
int_set(this) {
} }
void Analyzer::Bind(const VarExpr& v, const Expr& expr) { void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
...@@ -74,7 +75,7 @@ void ConstraintContext::ExitWithScope() { ...@@ -74,7 +75,7 @@ void ConstraintContext::ExitWithScope() {
bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<ir::IntImm>()) { if (const auto* ptr = expr.as<ir::IntImm>()) {
return ptr->value > lower_bound; return ptr->value >= lower_bound;
} }
auto bd = this->const_int_bound(this->rewrite_simplify(expr)); auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->min_value >= lower_bound) return true; if (bd->min_value >= lower_bound) return true;
......
...@@ -30,12 +30,12 @@ ...@@ -30,12 +30,12 @@
#include <unordered_set> #include <unordered_set>
#include <unordered_map> #include <unordered_map>
#include "int_set.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using namespace ir; using namespace ir;
using HalideIR::Internal::Interval;
// a visitor to find the path to the target variable // a visitor to find the path to the target variable
// from a expression. // from a expression.
...@@ -293,7 +293,7 @@ IntSet DeduceBound(Expr v, Expr e, ...@@ -293,7 +293,7 @@ IntSet DeduceBound(Expr v, Expr e,
BoundDeducer d(v, e, hint_map, relax_map); BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce(); d.Deduce();
if (!d.success) return IntSet::nothing(); if (!d.success) return IntSet::nothing();
Expr min = Interval::neg_inf, max = Interval::pos_inf; Expr min = neg_inf(), max = pos_inf();
if (d.is_greater) { if (d.is_greater) {
min = d.result; min = d.result;
} else { } else {
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file canonical_simplify.cc * \file canonical_simplify.cc
* \brief Canonical form based simplification. * \brief Canonical form based simplification.
*/ */
...@@ -763,7 +762,10 @@ Mutate_(const Mod* op, const Expr& self) { ...@@ -763,7 +762,10 @@ Mutate_(const Mod* op, const Expr& self) {
if (TryCompare(temp, cval) == kLT) { if (TryCompare(temp, cval) == kLT) {
return temp; return temp;
} else { } else {
return SplitModConst(ToSplitExpr(temp), cval); // contonue to use logic below.
a = extra;
psum = a.as<SumExprNode>();
CHECK(psum != nullptr);
} }
} }
} }
......
...@@ -27,8 +27,8 @@ ...@@ -27,8 +27,8 @@
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_ #define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <arithmetic/Interval.h>
#include <limits> #include <limits>
#include <algorithm>
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -105,12 +105,12 @@ inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) { ...@@ -105,12 +105,12 @@ inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
template<> template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) { inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
return HalideIR::Internal::Interval::make_max(a, b); return max(a, b);
} }
template<> template<>
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) { inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
return HalideIR::Internal::Interval::make_min(a, b); return min(a, b);
} }
template<typename Op> template<typename Op>
......
...@@ -206,6 +206,7 @@ inline Expr TryConstFold<ir::Min>(Expr a, Expr b) { ...@@ -206,6 +206,7 @@ inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value)); if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
}); });
if (a.same_as(b)) return a;
return Expr(); return Expr();
} }
...@@ -216,6 +217,7 @@ inline Expr TryConstFold<ir::Max>(Expr a, Expr b) { ...@@ -216,6 +217,7 @@ inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value)); if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
}); });
if (a.same_as(b)) return a;
return Expr(); return Expr();
} }
...@@ -307,6 +309,58 @@ inline Expr TryConstFold<ir::Not>(Expr a) { ...@@ -307,6 +309,58 @@ inline Expr TryConstFold<ir::Not>(Expr a) {
return Expr(); return Expr();
} }
/*! \brief Helper namespace for symbolic value limits */
struct SymbolicLimits {
/*! \brief positive infinity */
static Expr pos_inf_;
/*! \brief negative infinity */
static Expr neg_inf_;
};
/*!
* \brief Opaque expression representing positive infinity.
*
* It can can only be used as parameter of by min/max
* for integer analysis and cannot be used in normal expressions.
*
* \return positive infinity.
*/
inline Expr pos_inf() {
return SymbolicLimits::pos_inf_;
}
/*!
* \brief Check if value is positive infinity.
* \param value The value to be checked.
*
* \return The check result.
*/
inline bool is_pos_inf(const Expr& value) {
return value.same_as(SymbolicLimits::pos_inf_);
}
/*!
* \brief Opaque expression representing negative infinity.
*
* It can can only be used as parameter of by min/max
* for integer analysis and cannot be used in normal expressions.
*
* \return negative infinity.
*/
inline Expr neg_inf() {
return SymbolicLimits::neg_inf_;
}
/*!
* \brief Check if value is negative infinity.
* \param value The value to be checked.
*
* \return The check result.
*/
inline bool is_neg_inf(const Expr& value) {
return value.same_as(SymbolicLimits::neg_inf_);
}
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
#endif // TVM_ARITHMETIC_CONST_FOLD_H_ #endif // TVM_ARITHMETIC_CONST_FOLD_H_
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file bound_deducer.cc * \file detect_linear_equation.cc
* \brief Utility to deduce bound of expression * \brief Utility to detect patterns in the expression.
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file int_set.h
* \brief Internal data structure for integer set.
*/
#ifndef TVM_ARITHMETIC_INT_SET_H_
#define TVM_ARITHMETIC_INT_SET_H_
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <limits>
#include "const_fold.h"
namespace tvm {
namespace arith {
/*!
* \brief Symbolic interval set.
*
* \note We intentionally keep the internal of IntSet private,
as we might change it later.
*/
class IntervalSetNode : public IntSetNode {
public:
/*! \brief Minimum value in the interval. */
Expr min_value;
/*! \brief Maximum value in the interval. */
Expr max_value;
// visitor overload.
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("min_value", &min_value);
v->Visit("max_value", &max_value);
}
/*! \return Whether the interval has upper bound. */
bool HasUpperBound() const {
return !is_pos_inf(max_value) && !IsEmpty();
}
/*! \return Whether the interval has lower bound. */
bool HasLowerBound() const {
return !is_neg_inf(min_value) && !IsEmpty();
}
/*! \return Whether the interval is a single point. */
bool IsSinglePoint() const {
return min_value.same_as(max_value);
}
/*! \return whether interval represent nothing */
bool IsEmpty() const {
// during computations, either extreme could occur.
return is_pos_inf(min_value) || is_neg_inf(max_value);
}
/*! \return whether interval represent everything */
bool IsEverything() const {
return is_neg_inf(min_value) && is_pos_inf(max_value);
}
static constexpr const char* _type_key = "arith.IntervalSet";
TVM_DECLARE_NODE_TYPE_INFO(IntervalSetNode, IntSetNode);
};
/*!
* \brief Interval set used for symbolic integer analysis.
* \sa IntervalSetNode
*/
class IntervalSet : public IntSet {
public:
/*!
* \brief Make a new instance of interval set.
* \param min_value The minimum value in the interval.
* \param max_value The maximum value in the interval.
* \return The created set.
*/
TVM_DLL IntervalSet(Expr min_value, Expr max_value);
/*!
* \brief Create an IntervalSet that represents a single point.
* \param value The value to be represented.
* \return The result set.
*/
static IntervalSet SinglePoint(Expr value) {
return IntervalSet(value, value);
}
/*!
* \brief Create an IntervalSet that represents everything.
* \param value The value to be represented.
* \return The result set.
*/
static IntervalSet Everything() {
return IntervalSet(neg_inf(), pos_inf());
}
/*!
* \brief Create an empty eet.
* \return The result set.
*/
static IntervalSet Empty() {
return IntervalSet(pos_inf(), neg_inf());
}
TVM_DEFINE_NODE_REF_COW(IntervalSetNode);
TVM_DEFINE_NODE_REF_METHODS(IntervalSet, IntSet, IntervalSetNode);
};
/*!
* \brief Create union of two IntervalSets.
* \param analyzer The analyzer for simplification analysis.
* \param a The first set.
* \param b The second set.
* \return The result set.
*/
TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b);
/*!
* \brief Create insersection of two IntervalSets.
* \param analzyer The analyzer for simplification analysis.
* \param a The first set.
* \param b The second set.
* \return The result set.
*/
TVM_DLL IntervalSet Intersect(Analyzer *analzyer, IntervalSet a, IntervalSet b);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_INT_SET_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2017 by Contributors
* \file int_set_internal.h
* \brief Implementations of integer set
*/
#ifndef TVM_ARITHMETIC_INT_SET_INTERNAL_H_
#define TVM_ARITHMETIC_INT_SET_INTERNAL_H_
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
namespace tvm {
namespace arith {
using HalideIR::Internal::Interval;
/*! \brief Set of continuous interval */
struct IntervalSet : public IntSetNode {
/*! \brief the internal interval*/
Interval i;
static IntSet make(Interval i) {
NodePtr<IntervalSet> n =
make_node<IntervalSet>();
n->i = i;
return IntSet(n);
}
static IntSet make(Expr min, Expr max) {
NodePtr<IntervalSet> n =
make_node<IntervalSet>();
n->i.min = min;
n->i.max = max;
return IntSet(n);
}
static constexpr const char* _type_key = "IntervalSet";
TVM_DECLARE_NODE_TYPE_INFO(IntervalSet, IntSetNode);
};
/*!
* \brief set represented by strided integers
* Reserved for cases where strided access is supported.
*/
struct StrideSet : public IntSetNode {
/*! \brief the base inetrval */
Interval base;
/*! \brief additional extents in positive number */
Array<Expr> extents;
/*! \brief additional strides in positive number */
Array<Expr> strides;
static constexpr const char* _type_key = "StrideSet";
TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_INT_SET_INTERNAL_H_
...@@ -188,7 +188,15 @@ Expr operator%(Expr a, Expr b) { ...@@ -188,7 +188,15 @@ Expr operator%(Expr a, Expr b) {
return ir::Mod::make(a, b); return ir::Mod::make(a, b);
} }
Expr min(Expr a, Expr b) { Expr min(Expr a, Expr b) {
// inf-aware simplificaiton
using arith::is_pos_inf;
using arith::is_neg_inf;
if (is_pos_inf(a)) return b;
if (is_neg_inf(a)) return a;
if (is_pos_inf(b)) return a;
if (is_neg_inf(b)) return b;
BinaryOpMatchTypes(a, b); BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::Min>(a, b); Expr ret = arith::TryConstFold<ir::Min>(a, b);
if (ret.defined()) return ret; if (ret.defined()) return ret;
...@@ -196,6 +204,13 @@ Expr min(Expr a, Expr b) { ...@@ -196,6 +204,13 @@ Expr min(Expr a, Expr b) {
} }
Expr max(Expr a, Expr b) { Expr max(Expr a, Expr b) {
// inf-aware simplificaiton
using arith::is_pos_inf;
using arith::is_neg_inf;
if (is_pos_inf(a)) return a;
if (is_neg_inf(a)) return b;
if (is_pos_inf(b)) return b;
if (is_neg_inf(b)) return a;
BinaryOpMatchTypes(a, b); BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::Max>(a, b); Expr ret = arith::TryConstFold<ir::Max>(a, b);
if (ret.defined()) return ret; if (ret.defined()) return ret;
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "../arithmetic/int_set_internal.h" #include "../arithmetic/int_set.h"
#include "../runtime/thread_storage_scope.h" #include "../runtime/thread_storage_scope.h"
namespace tvm { namespace tvm {
...@@ -366,7 +366,7 @@ class LoopPartitioner : public IRMutator { ...@@ -366,7 +366,7 @@ class LoopPartitioner : public IRMutator {
std::pair<IntSet, std::unordered_set<const Node*>> std::pair<IntSet, std::unordered_set<const Node*>>
GetIntervalAndCondset(const Partition &partitions, GetIntervalAndCondset(const Partition &partitions,
const arith::Interval &for_interval, const arith::IntervalSet &for_interval,
bool cond_value); bool cond_value);
inline Stmt MakeFor(const Node* op, Expr extent, Stmt body); inline Stmt MakeFor(const Node* op, Expr extent, Stmt body);
...@@ -374,6 +374,7 @@ class LoopPartitioner : public IRMutator { ...@@ -374,6 +374,7 @@ class LoopPartitioner : public IRMutator {
/* Candidate IRs that may be partitioned potentially */ /* Candidate IRs that may be partitioned potentially */
std::unordered_map<const Variable*, IntSet> hint_map_; std::unordered_map<const Variable*, IntSet> hint_map_;
std::unordered_map<const Variable*, IntSet> relax_map_; std::unordered_map<const Variable*, IntSet> relax_map_;
arith::Analyzer analyzer_;
CandidateSelector selector; CandidateSelector selector;
}; };
...@@ -381,16 +382,17 @@ class LoopPartitioner : public IRMutator { ...@@ -381,16 +382,17 @@ class LoopPartitioner : public IRMutator {
// given in the second component provably have value given by cond_value // given in the second component provably have value given by cond_value
std::pair<IntSet, std::unordered_set<const Node*>> std::pair<IntSet, std::unordered_set<const Node*>>
LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, LoopPartitioner::GetIntervalAndCondset(const Partition &partitions,
const arith::Interval &for_interval, const arith::IntervalSet &for_interval,
bool cond_value) { bool cond_value) {
Array<IntSet> sets; Array<IntSet> sets;
std::unordered_set<const Node*> cond_set; std::unordered_set<const Node*> cond_set;
for (const auto &kv : partitions) { for (const auto &kv : partitions) {
if (kv.first.second == cond_value) { if (kv.first.second == cond_value) {
arith::Interval interval = kv.second.as<arith::IntervalSet>()->i; arith::IntervalSet interval = Downcast<arith::IntervalSet>(kv.second);
arith::Interval intersection = arith::Interval::make_intersection(interval, for_interval); arith::IntervalSet intersection = arith::Intersect(
if (!intersection.is_empty()) { &analyzer_, interval, for_interval);
if (!intersection->IsEmpty()) {
sets.push_back(kv.second); sets.push_back(kv.second);
cond_set.insert(kv.first.first); cond_set.insert(kv.first.first);
} }
...@@ -463,11 +465,12 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -463,11 +465,12 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr max, Expr max,
Stmt body, Stmt body,
bool partition_thread_scope) { bool partition_thread_scope) {
using namespace arith;
PartitionFinder finder(var, hint_map_, relax_map_); PartitionFinder finder(var, hint_map_, relax_map_);
finder.Visit(body); finder.Visit(body);
if (finder.partitions.empty()) return Stmt(); if (finder.partitions.empty()) return Stmt();
arith::Interval for_interval(min, max); arith::IntervalSet for_interval(min, max);
bool cond_value; bool cond_value;
IntSet middle_interval; IntSet middle_interval;
std::unordered_set<const Node*> cond_set; std::unordered_set<const Node*> cond_set;
...@@ -488,7 +491,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -488,7 +491,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
cond_value = true; cond_value = true;
} }
arith::Interval middle_interval_i = middle_interval.as<arith::IntervalSet>()->i; IntervalSet middle_interval_i = Downcast<IntervalSet>(middle_interval);
// middle_interval is the subrange of the loop variable range for which a // middle_interval is the subrange of the loop variable range for which a
// set of conditions are true (or false resp.) // set of conditions are true (or false resp.)
// The part of the loop variable range that is before (after resp.) that // The part of the loop variable range that is before (after resp.) that
...@@ -499,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -499,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr body_begin; Expr body_begin;
Stmt pre_stmt; Stmt pre_stmt;
bool pre_stmt_recurse = true; bool pre_stmt_recurse = true;
if (middle_interval_i.has_lower_bound()) { if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min()); body_begin = ir::Simplify(middle_interval.min());
if (!can_prove(body_begin == min)) { if (!can_prove(body_begin == min)) {
Expr cond = (body_begin - min >= 0); Expr cond = (body_begin - min >= 0);
...@@ -524,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -524,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr post_doubt_begin; Expr post_doubt_begin;
Stmt post_stmt; Stmt post_stmt;
bool post_stmt_recurse = true; bool post_stmt_recurse = true;
if (middle_interval_i.has_upper_bound()) { if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1); post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
if (!can_prove(middle_interval.max() == max)) { if (!can_prove(middle_interval.max() == max)) {
// require the extent to be non-negative // require the extent to be non-negative
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
def test_deduce():
a = tvm.var('a')
b = tvm.var('b')
c = tvm.var('c')
d = tvm.var('d')
b_s = tvm.arith.IntervalSet(2, 3)
c_s = tvm.arith.IntervalSet(10, 15)
d_s = tvm.arith.IntervalSet(-3, -1)
zero = tvm.const(0, "int32")
e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((d - c) /(b*-1))
assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
# expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
e0 = d*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((0-c)/d + 1)
assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
# expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = (((c - b) + -1)/4)
assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1)
# expression containing variable a is on rhs
e1 = (c > a*4+b)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1)
e2 = (tvm.max(5, a * 4) < 0)
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max_value) == "neg_inf"
assert str(res2.min_value) == "pos_inf"
# expression containing variable a is on rhs
e2 = (zero < tvm.max(5, a * 4))
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max_value) == "neg_inf"
assert str(res2.min_value) == "pos_inf"
e3 = (-b)+a*c-d
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = 2/c+1
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
def test_check():
a = tvm.var('a')
b = tvm.var('b')
c = tvm.var('c')
d = tvm.var('d')
b_s = tvm.arith.IntervalSet(2, 3)
c_s = tvm.arith.IntervalSet(5, 7)
d_s = tvm.arith.IntervalSet(-3, -1)
# no compare operator
res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
assert res1.is_nothing()
# multiple compare operators
res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
assert res2.is_nothing()
# multiple target variable
res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
assert res2.is_nothing()
def test_deduce_basic():
def test_basic(a1, a2, coff):
a = tvm.var('a')
b = tvm.var('b')
b_s = tvm.arith.IntervalSet(a1, a2)
e0 = b + a*coff + 3
res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
test_basic(0, 4, 4)
test_basic(1, 5, 4)
test_basic(2, 6, 4)
test_basic(0, 4, -4)
test_basic(1, 5, -4)
test_basic(2, 6, -4)
def test_deduce_complex():
def test_complex(a1, a2, coff):
a = tvm.var('a')
b = tvm.var('b')
b_s = tvm.arith.IntervalSet(a1, a2)
e0 = (b*3 + a* coff) * 4
res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
test_complex(0, 4, 4)
test_complex(0, 4, -4)
test_complex(2, 6, 4)
test_complex(0, 4, -4)
test_complex(1, 5, -4)
test_complex(2, 6, -4)
if __name__ == "__main__":
test_check()
test_deduce_basic()
test_deduce_complex()
...@@ -16,168 +16,87 @@ ...@@ -16,168 +16,87 @@
# under the License. # under the License.
import tvm import tvm
class IntSetChecker:
def __init__(self):
self.analyzer = tvm.arith.Analyzer()
def verify(self, data, dmap, expected):
res = self.analyzer.int_set(data, dmap)
def err_msg():
return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected)
def equal(x, y):
res = self.analyzer.canonical_simplify(x - y)
return tvm.ir_pass.Equal(res, 0)
assert equal(res.min_value, expected[0]), err_msg()
assert equal(res.max_value, expected[1]), err_msg()
def test_basic(): def test_basic():
s = tvm.arith.intset_interval(2, 3) s = tvm.arith.IntervalSet(2, 3)
assert s.min().value == 2 assert s.min_value.value == 2
assert s.max().value == 3 assert s.max_value.value == 3
def test_vector(): def test_vector():
base = 10 base = 10
stride = 3 stride = 3
lanes = 2 lanes = 2
s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes)) s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes))
assert s.min().value == base assert s.min_value.value == base
assert s.max().value == base + stride * lanes - 1 assert s.max_value.value == base + stride * lanes - 1
def test_deduce():
a = tvm.var('a') def test_add_sub():
b = tvm.var('b') ck = IntSetChecker()
c = tvm.var('c') x, y = tvm.var("x"), tvm.var("y")
d = tvm.var('d') ck.verify(x + y, {x : tvm.arith.IntervalSet(0, 10)}, (y, 10 + y))
ck.verify(x + y,
b_s = tvm.arith.intset_interval(2, 3) {x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)},
c_s = tvm.arith.intset_interval(10, 15) (1, 21))
d_s = tvm.arith.intset_interval(-3, -1) ck.verify(x - y,
zero = tvm.const(0, "int32") {x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)},
(-11, 9))
e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) def test_mul_div():
ans0 = ((d - c) /(b*-1)) ck = IntSetChecker()
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) x, y = tvm.var("x"), tvm.var("y")
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
# expression containing variable a is on rhs ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20))
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2))
ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y))
e0 = d*a+c-d ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((0-c)/d + 1)
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) def test_mod():
ck = IntSetChecker()
# expression containing variable a is on rhs x, y = tvm.var("x"), tvm.var("y")
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) def test_max_min():
ans1 = (((c - b) + -1)/4) ck = IntSetChecker()
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) x, y = tvm.var("x"), tvm.var("y")
ck.verify(tvm.max(x, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (1, 11))
# expression containing variable a is on rhs ck.verify(tvm.min(x - 1, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (-1, 9))
e1 = (c > a*4+b) ck.verify(tvm.min(x, y), {}, (tvm.min(x, y), tvm.min(x, y)))
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) ck.verify(tvm.max(x, y), {}, (tvm.max(x, y), tvm.max(x, y)))
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
e2 = (tvm.max(5, a * 4) < 0) def test_select():
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) ck = IntSetChecker()
assert str(res2.max()) == "neg_inf" x, y = tvm.var("x"), tvm.var("y")
assert str(res2.min()) == "pos_inf" ck.verify(tvm.expr.Select(x > 0, x - 1, x + 1),
{x : tvm.arith.IntervalSet(0, 10)}, (-1, 11))
# expression containing variable a is on rhs
e2 = (zero < tvm.max(5, a * 4))
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max()) == "neg_inf"
assert str(res2.min()) == "pos_inf"
e3 = (-b)+a*c-d
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = 2/c+1
assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)
def test_check():
a = tvm.var('a')
b = tvm.var('b')
c = tvm.var('c')
d = tvm.var('d')
b_s = tvm.arith.intset_interval(2, 3)
c_s = tvm.arith.intset_interval(5, 7)
d_s = tvm.arith.intset_interval(-3, -1)
# no compare operator
res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
assert res1.is_nothing()
# multiple compare operators
res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
assert res2.is_nothing()
# multiple target variable
res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
assert res2.is_nothing()
def test_deduce_basic():
def test_basic(a1, a2, coff):
a = tvm.var('a')
b = tvm.var('b')
b_s = tvm.arith.intset_interval(a1, a2)
e0 = b + a*coff + 3
res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
test_basic(0, 4, 4)
test_basic(1, 5, 4)
test_basic(2, 6, 4)
test_basic(0, 4, -4)
test_basic(1, 5, -4)
test_basic(2, 6, -4)
def test_deduce_complex():
def test_complex(a1, a2, coff):
a = tvm.var('a')
b = tvm.var('b')
b_s = tvm.arith.intset_interval(a1, a2)
e0 = (b*3 + a* coff) * 4
res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
# expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
test_complex(0, 4, 4)
test_complex(0, 4, -4)
test_complex(2, 6, 4)
test_complex(0, 4, -4)
test_complex(1, 5, -4)
test_complex(2, 6, -4)
if __name__ == "__main__": if __name__ == "__main__":
test_basic() test_basic()
test_vector() test_vector()
test_deduce() test_add_sub()
test_check() test_mul_div()
test_deduce_basic() test_max_min()
test_deduce_complex() test_select()
test_mod()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment