/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/arith/analyzer.h
 * \brief Algebra expression simplifications.
 */
#ifndef TVM_ARITH_ANALYZER_H_
#define TVM_ARITH_ANALYZER_H_

#include <tvm/support/with.h>
#include <tvm/ir/expr.h>
#include <tvm/arith/int_set.h>

#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
/*! \brief namespace of arithmetic analysis. */
namespace arith {
//-------------------------------------------------------
// Base integer analysis API.
//
// We have multiple types of analyzers to do relaxed
// integer set analysis (bound analysis, modulo) and
// equivalence checking and simplification.
//
// Importantly, each analyzer may need results from
// another analyzer.
//-------------------------------------------------------

// Forward declare Analyzer
class Analyzer;
52

53 54
using tir::Var;

55 56 57 58 59 60
/*!
 * \brief Constant integer up and lower bound(inclusive).
 *  Useful for value bound analysis.
 *
 *  set = [min_value, max_value]
 */
61
class ConstIntBoundNode : public Object {
62 63 64 65
 public:
  int64_t min_value;
  int64_t max_value;

66
  void VisitAttrs(tvm::AttrVisitor* v) {
67 68 69 70
    v->Visit("min_value", &min_value);
    v->Visit("max_value", &max_value);
  }

71 72 73 74
  bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
    return equal(min_value, other->min_value) && equal(max_value, other->max_value);
  }

75 76 77 78 79 80 81 82 83
  /*! \brief Number to represent +inf */
  static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
  /*!
   * \brief Number to represent -inf
   * \note We can make use the of fact that -kPosInf == kNegInf in the project.
   */
  static const constexpr int64_t kNegInf = -kPosInf;

  static constexpr const char* _type_key = "arith.ConstIntBound";
84
  TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object);
85 86
};

87 88 89 90
/*!
 * \brief reference class to ConstIntBoundNode
 * \sa ConstIntBoundNode
 */
91
class ConstIntBound : public ObjectRef {
92 93 94 95 96 97 98 99 100 101
 public:
  /*!
   * \brief constructor by fields.
   * \param min_value The mininum value.
   * \param max_value The maximum value.
   */
  TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);

  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
  static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
102
  TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode);
103
};
104 105 106 107 108 109 110 111 112 113 114

/*!
 * \brief Analyzer to get constant integer bound over expression.
 */
class ConstIntBoundAnalyzer {
 public:
  /*!
   * \brief analyze the expr
   * \param expr The expression of interest.
   * \return the result of the analysis.
   */
115
  ConstIntBound operator()(const PrimExpr& expr);
116 117

  /*!
118 119 120 121 122 123 124 125 126
   * \brief analyze the expr with the intermediate memorized to avoid redundant computation
   * \param expr The expression of interest.
   * \param bound The lookup table to store the intermediate results
   * \return the result of the analysis.
   */
  ConstIntBound operator()(const PrimExpr& expr,
                           std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);

  /*!
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
   * \brief Update constant int bound information of var.
   *
   * \param var The variable of interest.
   * \param info The bound information.
   * \param override Whether do we allow override of existing information.
   */
  void Update(const Var& var,
              const ConstIntBound& info,
              bool override = false);
  /*!
   * \brief Bind variable to a range.
   *
   * \param var The variable.
   * \param range The range we bind to.
   */
  void Bind(const Var& var, const Range& range);

 private:
  friend class Analyzer;
  friend class ConstraintContext;
  explicit ConstIntBoundAnalyzer(Analyzer* parent);
  ~ConstIntBoundAnalyzer();
  /*!
   * \brief Update the internal state to enter constraint.
   * \param constraint A constraint expression.
   *
   * \return an exit function that must be called to cleanup the constraint can be nullptr.
   */
155
  std::function<void()> EnterConstraint(const PrimExpr& constraint);
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
  struct Entry;
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};

/*!
 * \brief Range of a linear integer function.
 *  Use to do specify the possible index values.
 *
 *  set = { coeff * x + base | x in Z }
 *
 *  When coeff != 0, it can also be written as
 *  set = { n | n % coeff == base }
 *
 *  This is useful to decide if the index is dividable by certain value.
 *  For example, if index = 0 + 4 x, then we know it can be divided by 4.
 */
174
class ModularSetNode : public Object {
175 176 177 178 179 180
 public:
  /*! \brief linear co-efficient */
  int64_t coeff;
  /*! \brief The base */
  int64_t base;

181
  void VisitAttrs(tvm::AttrVisitor* v) {
182 183 184 185
    v->Visit("coeff", &coeff);
    v->Visit("base", &base);
  }

186 187 188 189
  bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
    return equal(coeff, other->coeff) && equal(base, other->base);
  }

190
  static constexpr const char* _type_key = "arith.ModularSet";
191
  TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
192 193
};

194 195 196 197
/*!
 * \brief reference of ModularSetNode
 * \sa ModularSetNode
 */
198
class ModularSet : public ObjectRef {
199 200 201
 public:
  TVM_DLL ModularSet(int64_t coeff, int64_t base);

202
  TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode);
203
};
204 205 206 207 208 209 210 211 212 213 214

/*!
 * \brief Analyzer to get modular information over expression.
 */
class ModularSetAnalyzer {
 public:
  /*!
   * \brief analyze the expr
   * \param expr The expression of interest.
   * \return the result of the analysis.
   */
215
  ModularSet operator()(const PrimExpr& expr);
216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
  /*!
   * \brief Update constant int bound information of var.
   *
   * \param var The variable of interest.
   * \param info The bound information.
   * \param override Whether do we allow override of existing information.
   */
  void Update(const Var& var,
              const ModularSet& info,
              bool override = false);

 private:
  friend class Analyzer;
  friend class ConstraintContext;
  explicit ModularSetAnalyzer(Analyzer* parent);
  ~ModularSetAnalyzer();
  /*!
   * \brief Update the internal state to enter constraint.
   * \param constraint A constraint expression.
   *
   * \return an exit function that must be called to cleanup the constraint can be nullptr.
   */
238
  std::function<void()> EnterConstraint(const PrimExpr& constraint);
239 240 241 242 243 244 245
  struct Entry;
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};

/*!
246 247 248 249 250 251 252 253 254
 * \brief Rewrite-rule based simplifier.
 */
class RewriteSimplifier {
 public:
  /*!
   * \brief analyze the expr
   * \param expr The expression of interest.
   * \return the result of the analysis.
   */
255
  PrimExpr operator()(const PrimExpr& expr);
256 257 258 259 260 261 262 263 264

  /*!
   * \brief Update binding of var to a new expression.
   *
   * \param var The variable of interest.
   * \param new_expr
   * \param override Whether do we allow override of existing information.
   */
  void Update(const Var& var,
265
              const PrimExpr& new_expr,
266 267
              bool override = false);

268
  std::function<void()> EnterConstraint(const PrimExpr& constraint);
269

270 271 272
 private:
  friend class Analyzer;
  friend class ConstraintContext;
273
  friend class CanonicalSimplifier;
274 275 276 277 278 279 280 281
  explicit RewriteSimplifier(Analyzer* parent);
  ~RewriteSimplifier();
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};

/*!
282 283 284 285 286 287 288 289 290
 * \brief Canonical-form based simplifier.
 */
class CanonicalSimplifier {
 public:
  /*!
   * \brief analyze the expr
   * \param expr The expression of interest.
   * \return the result of the analysis.
   */
291
  PrimExpr operator()(const PrimExpr& expr);
292 293 294 295 296 297 298 299 300

  /*!
   * \brief Update binding of var to a new expression.
   *
   * \param var The variable of interest.
   * \param new_expr
   * \param override Whether do we allow override of existing information.
   */
  void Update(const Var& var,
301
              const PrimExpr& new_expr,
302 303 304 305 306 307 308 309 310 311 312 313 314
              bool override = false);

 private:
  friend class Analyzer;
  friend class ConstraintContext;
  explicit CanonicalSimplifier(Analyzer* parent);
  ~CanonicalSimplifier();
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};

/*!
315
 * \brief Constraint context.
316 317 318 319 320 321
 *
 * \code
 *
 *  Var("x");
 *  arith::Analyzer analyzer;
 *  {
322
 *    With<arith::ConstraintContext> scope(&analyzer, x % 3 == 0);
323 324 325 326 327 328 329 330
 *    CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
 *  }
 *  // constraint no longer in effect.
 *  CHECK_NE(analyzer.modular_set(x)->coeff, 3);
 *
 * \endcode
 */
class ConstraintContext {
331 332 333
 private:
  // declare friend to enable with.
  friend class With<ConstraintContext>;
334 335 336 337 338
  /*!
   * \brief Construct a constraint context.
   * \param analyzer The analyzer.
   * \param constraint The constraint to be applied.
   */
339
  ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
340 341 342 343 344 345 346 347
      : analyzer_(analyzer), constraint_(constraint) {}
  // enter the scope.
  void EnterWithScope();
  // exit the scope.
  void ExitWithScope();
  /*! \brief The analyzer */
  Analyzer* analyzer_;
  /*! \brief The constraint */
348
  PrimExpr constraint_;
349 350 351 352
  /*! \brief function to be called in recovery */
  std::function<void()> exit_;
};

353
/*!
354
 * \brief Integer set analyzer.
355
 */
356 357 358 359 360 361 362 363 364 365
class IntSetAnalyzer {
 public:
  /*!
   * \brief Find a symbolic integer set that contains all possible values of
   *        expr given the domain of each variables.
   *
   * \param expr The expression of interest.
   * \param dom_map The domain map to indicate which variable to relax.
   * \return the result of the analysis.
   */
366
  IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
367 368 369 370 371 372 373 374

 private:
  friend class Analyzer;
  explicit IntSetAnalyzer(Analyzer* parent);
  ~IntSetAnalyzer();
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
375 376
};

377
/*!
378
 * \brief Analyzer that contains bunch of sub-analyzers.
379
 *
380 381
 * Each sub-analyzer can make use of another sub-analyzer
 * by weak reference of this.
382
 *
383 384 385
 * NOTE for sub-analyzer developers:
 * If the analyzer uses memoization, we need to clear the internal
 * cache when information about a Var has been overridden.
386
 */
387 388
class Analyzer {
 public:
389 390 391 392 393
  /*
   * Disable copy constructor.
   */
  Analyzer(const Analyzer&) = delete;
  Analyzer& operator=(const Analyzer&) = delete;
394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414
  /*! \brief sub-analyzer: const integer bound */
  ConstIntBoundAnalyzer const_int_bound;
  /*! \brief sub-analyzer: modular set */
  ModularSetAnalyzer modular_set;
  /*! \brief sub-analyzer rewrite simplify */
  RewriteSimplifier rewrite_simplify;
  /*! \brief sub-analyzer canonical simplify */
  CanonicalSimplifier canonical_simplify;
  /*! \brief sub-analyzer: int set */
  IntSetAnalyzer int_set;
  /*! \brief constructor */
  Analyzer();
  /*!
   * \brief Notify all the sub-analyzers that var
   *        is created and binded to expr.
   *
   *  Each var can only be binded once.
   *
   * \param var The variable.
   * \param expr The expression we bind to.
   */
415
  void Bind(const Var& var, const PrimExpr& expr);
416 417 418 419 420 421 422 423 424
  /*!
   * \brief Notify all the sub-analyzers that var
   *        is created and binded to a range.
   *
   *  Each var can only be binded once.
   *
   * \param var The variable.
   * \param range The range we bind to.
   */
425
  void Bind(const Var& var, const Range& range);
426
  /*!
427 428 429 430 431 432
   * \brief Bind all the vars in the Map
   *
   * \param variables The {variable -> range} map.
   */
  void Bind(const Map<Var, Range>& variables);
  /*!
433 434 435 436 437 438 439 440 441 442 443
   * \brief Whether can we prove expr >= val.

   *  Non-negative proof is very useful in integer analysis
   *  to lower divisions and mods given difference in trunc and ceil mode.
   *
   * \param expr The expression.
   * \param lower_bound The lower bound.
   * \return Whether we can prove it.
   *
   * \note Analyzer will call into sub-analyzers to get the result.
   */
444
  bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
445 446 447 448 449 450 451 452
  /*!
   * \brief Whether can we prove condition.
   *
   * \param cond The expression to be proved.
   * \return The result.
   *
   * \note Analyzer will call into sub-analyzers to get the result.
   */
453
  bool CanProve(const PrimExpr& cond);
454 455 456 457 458 459 460 461
  /*!
   * \brief Simplify expr.
   *
   * \param expr The expression to be simplified.
   * \return The result.
   *
   * \note Analyzer will call into sub-analyzers to get the result.
   */
462
  PrimExpr Simplify(const PrimExpr& expr);
463
};
464

465
}  // namespace arith
}  // namespace tvm
#endif  // TVM_ARITH_ANALYZER_H_