/*!
 *  Copyright (c) 2016 by Contributors
 * \file tvm/arithmetic.h
 * \brief Algebra and set operations and simplifications.
 */
6 7
#ifndef TVM_ARITHMETIC_H_
#define TVM_ARITHMETIC_H_
8

9
#include <vector>
10 11
#include <unordered_map>
#include <memory>
12
#include "expr.h"
13 14

namespace tvm {
15 16 17

class Tensor;

18
/*! \brief namespace of arithmetic */
19
namespace arith {
/*!
 * \brief Sign of an expression or set.
 *
 *  Returned by IntSet::sign_type(). kUnknown is used when the sign
 *  cannot be determined.
 */
enum SignType {
  kPositive,
  kNegative,
  kZero,
  kUnknown
};

30
// internal node container of int set.
31
struct IntSetNode;
32

33
/*!
34
 * \brief Integer set class, represent a set of integers in one dimension.
35
 */
36
class IntSet : public NodeRef {
37
 public:
38 39
  /*! \brief constructor */
  IntSet() {}
40
  // constructor from not container.
41
  explicit IntSet(NodePtr<Node> n) : NodeRef(n) {}
42
  /*!
43 44 45 46 47
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IntSetNode* operator->() const;
  /*!
48 49 50 51 52 53 54 55
   * \brief Find a range that covers the region.
   * \param max_range The range to be covered.
   * \return The covering range.
   */
  Range cover_range(Range max_range) const;
  /*!
   * \brief find an interval that covers the set.
   * \return The covering interval set.
56
   */
57
  IntSet cover_interval() const;
58 59 60 61 62 63
  /*! \return Lower bound of the set */
  Expr min() const;
  /*! \return upper bound of the set */
  Expr max() const;
  /*! \return Whether the set represent nothing  */
  bool is_nothing() const;
64 65 66 67
  /*! \return Whether the set represent everything  */
  bool is_everything() const;
  /*! \return Whether the set is a single point */
  bool is_single_point() const;
68 69
  /*! \return Whether the set is proved to be bigger than 0 */
  bool can_prove_positive() const;
70 71
  /*! \return Whether the set is proved to be smaller than 0 */
  bool can_prove_negative() const;
72 73 74 75
  /*! \return Whether the set is proved to be smaller than or equal to 0 */
  bool can_prove_non_positive() const;
  /*! \return Whether the set is proved to be larger than or equal to 0 */
  bool can_prove_non_negative() const;
76 77
  /*! \return The sign of the elements in the integer set */
  SignType sign_type() const;
78 79 80 81 82 83 84 85 86 87 88 89
  /*!
   * \brief The single point value, call only if is_single_point is true
   * \return The point value.
   */
  Expr point_value() const;
  /*!
   * \brief Try to match IntSet with range r.
   *
   * \note It is guanrateed that IntSet::range(r).match_range(r) == true
   * \return true if we can prove they are the same.
   */
  bool match_range(const Range& r) const;
90 91 92
  /*! \return The set contains nothing */
  static IntSet nothing();
  /*! \return The set contains everything */
93
  static IntSet everything();
94
  /*!
95 96 97
   * \brief construct a point set.
   * \param point The point in the set.
   * \return construct a single point set
98
   */
99
  static IntSet single_point(Expr point);
100
  /*!
101 102 103 104 105 106
   * \brief construct a integer set from vector expression.
   * \param vec The vector expression, can also be single point.
   * \return The result set containing the indices in the vector.
   */
  static IntSet vector(Expr vec);
  /*!
107 108 109
   * \brief Construct a set representing a range.
   * \param r The range
   * \return constructed set.
110
   */
111
  static IntSet range(Range r);
112 113 114 115 116 117 118
  /*!
   * \brief Construct a set representing a interval.
   * \param min The minimum value of the interval.
   * \param max The maximum value of the interval.
   * \return constructed set.
   */
  static IntSet interval(Expr min, Expr max);
119 120 121
};

/*!
 * \brief Range of a linear integer function.
 *  Used to specify the possible index values.
 *
 *  set = { coeff * x + base | x in Z }
 *
 *  When coeff != 0, it can also be written as
 *  set = { n | n % coeff == base }
 *
 *  This is useful to decide if the index is divisible by a certain value.
 *  For example, if index = 0 + 4 x, then we know it can be divided by 4.
 */
struct ModularEntry {
  /*! \brief Linear coefficient. */
  int coeff{1};
  /*! \brief The base. */
  int base{0};

  /*! \return An entry that represents everything. */
  static ModularEntry everything() {
    // Always safe to use 0 + 1 * x: it covers every integer.
    ModularEntry e;
    e.coeff = 1;
    e.base = 0;
    return e;
  }
  /*!
   * \brief Add two modular entries together to get a new modular entry.
   * \param a The left operand.
   * \param b The right operand.
   * \return The combined modular entry.
   */
  static ModularEntry Add(const ModularEntry& a,
                          const ModularEntry& b);
};

/*!
 * \brief Base class of all IntSet containers.
 */
struct IntSetNode : public Node {
  static constexpr const char* _type_key = "IntSet";
  TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
};

165
/*!
166 167
 * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
 *  Where coeff[i] and base are invariant of var[j] for all i and j.
168
 *
169 170 171 172 173 174 175 176 177 178 179 180 181
 * \param e The expression to be detected.
 * \param vars List of variables to be used in detection.
 * \return [coeff[i]] if it is possible, empty array if it is not.
 */
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars);

/*!
 * \brief Detect if expression corresponds to clip bound of the vars
 *
 * \param e The expression to be detected.
 * \param vars List of variables to be used in detection.
 * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
 *          return empty if the e does not match the pattern.
182
 */
183
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars);
184

185 186 187 188 189 190 191 192
/*!
 * \brief Find an symbolic integer set that contains all possible values of
 *  e given the domain of each iteration variables.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values of e.
 */
193 194
IntSet EvalSet(Expr e,
               const Map<IterVar, IntSet>& dom_map);
195 196 197 198 199 200 201
/*!
 * \brief Same as EvalSet, but takes unordered_map
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values of e.
 */
202 203
IntSet EvalSet(Expr e,
               const std::unordered_map<const Variable*, IntSet>& dom_map);
204 205 206 207 208 209 210 211 212 213 214

/*!
 * \brief Find an symbolic integer set that contains is union over
 *  all the possible conditional values in dom_map.
 *
 * \param r The initial range.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values.
 */
IntSet EvalSet(Range r,
               const Map<IterVar, IntSet>& dom_map);
215 216 217 218 219 220 221 222 223 224 225

/*!
 * \brief Find an symbolic integer set that contains is union over
 *  all the possible conditional values in dom_map.
 *
 * \param s The initial set.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values.
 */
IntSet EvalSet(IntSet s,
               const std::unordered_map<const Variable*, IntSet>& dom_map);
226 227 228 229 230 231 232
/*!
 * \brief Same as EvalSet, but takes unordered_map
 *
 * \param r The range to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values of e.
 */
233 234 235
IntSet EvalSet(Range r,
               const std::unordered_map<const Variable*, IntSet>& dom_map);

236 237
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ExprHash, ExprEqual>;
238 239 240 241 242 243 244 245
/*!
 * \brief Find the integer set of every sub-expression, given the
 *  domain of each iteration variables.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each variable.
 * \return the map from the expression to its possible value.
 */
246 247
ExprIntSetMap EvalSetForEachSubExpr(
    Expr e,
248 249
    const std::unordered_map<const Variable*, IntSet>& dom_map);

250 251 252 253 254 255 256
/*!
 * \brief Create an union set of all sets
 * \param sets The sets to be unioned
 * \return the set after union
 */
IntSet Union(const Array<IntSet>& sets);

257 258 259 260 261 262 263
/*!
 * \brief Create an union set of all sets
 * \param sets The sets to be intersected
 * \return the set after intersected
 */
IntSet Intersect(const Array<IntSet>& sets);

264 265 266 267 268 269 270
/*!
 * \brief Deduce the bound of the target variable in a expression,
 *  give the domain of each variables. Return undefined IntSet to
 *  represent failure.
 *
 * \param v The target variable to be deduced.
 * \param cond The conditional expression.
271
 * \param hint_map The domain of variable, used to help deduce.
272 273
 * \param relax_map The domain of each variable, used to relax the domain,
 *        The deduce bound mush implies e for all value in relax_map
274 275
 * \return An integer set that can cover all the possible values.
 */
276 277 278
IntSet DeduceBound(Expr v, Expr cond,
                   const Map<Var, IntSet>& hint_map,
                   const Map<Var, IntSet>& relax_map);
279 280 281 282 283 284 285 286 287 288 289 290 291
/*!
 * \brief Same as DeduceBound with  unordered_map signature.
 *
 * \param v The target variable to be deduced.
 * \param cond The conditional expression.
 * \param hint_map The domain of variable, used to help deduce.
 * \param relax_map The domain of each variable, used to relax the domain,
 *        The deduce bound mush implies e for all value in relax_map
 * \return An integer set that can cover all the possible values.
 */
IntSet DeduceBound(Expr v, Expr cond,
                   const std::unordered_map<const Variable*, IntSet>& hint_map,
                   const std::unordered_map<const Variable*, IntSet>& relax_map);
292

293
/*!
294 295 296 297 298 299 300 301 302 303
 * \brief Infer a regular domain that covers all the calls or provides within the given statement.
 * \param body The given statement.
 * \param tensor The name of the calls or provides.
 * \param consider_calls If calls (read) are considered.
 * \param consider_provides If provides (write) are considered.
 * \return The domain that covers all the calls or provides within the given statement.
 */
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);

/*!
 * \brief Evaluate the expression with modular analysis.
 * \param e The expression to be evaluated.
 * \param mod_map Map of modular statistics of known variables.
 * \return The ModularEntry covering all possible values of e.
 */
ModularEntry EvalModular(
    const Expr& e,
    const std::unordered_map<const Variable*, ModularEntry>& mod_map);
312

313 314 315 316 317 318 319 320 321 322 323 324
/*!
 * \brief Same as EvalModular, used by front-end.
 * \param e The expression to be evaluated.
 * \param mod_map Map of modular statistics of known variables.
 * \return A ModularSet covering all possible value of e.
 */
IntSet EvalModular(const Expr& e,
                   const Map<Var, IntSet>& mod_map);
// Out-of-line definition of IntSet::operator->; it lives here because the
// downcast requires IntSetNode to be a complete type.
inline const IntSetNode* IntSet::operator->() const {
  auto* raw_node = node_.get();
  return static_cast<const IntSetNode*>(raw_node);
}
325
}  // namespace arith
326
}  // namespace tvm
327
#endif  // TVM_ARITHMETIC_H_