/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/arith/analyzer.h
 * \brief Algebra expression simplifications.
 */

#ifndef TVM_ARITH_ANALYZER_H_
#define TVM_ARITH_ANALYZER_H_

#include <tvm/support/with.h>
#include <tvm/ir/expr.h>
#include <tvm/arith/int_set.h>

#include <functional>
#include <limits>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
37
/*! \brief namespace of arithmetic analysis. */
38
namespace arith {
39 40 41 42 43 44 45 46 47 48 49 50 51
//-------------------------------------------------------
// Base integer analysis API.
//
// We have multiple type of analyzers to do relaxed
// integer set analysis(bound analysis, modulo) and
// equivalence checking and simplification.
//
// Importantly, each analyzer may need result from
// another analyzer.
//-------------------------------------------------------

// Forward declare Analyzer
class Analyzer;
52

53 54
using tir::Var;

55 56 57 58 59 60
/*!
 * \brief Constant integer up and lower bound(inclusive).
 *  Useful for value bound analysis.
 *
 *  set = [min_value, max_value]
 */
61
class ConstIntBoundNode : public Object {
62 63 64 65
 public:
  int64_t min_value;
  int64_t max_value;

66
  void VisitAttrs(tvm::AttrVisitor* v) {
67 68 69 70
    v->Visit("min_value", &min_value);
    v->Visit("max_value", &max_value);
  }

71 72 73 74
  bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
    return equal(min_value, other->min_value) && equal(max_value, other->max_value);
  }

75 76 77 78 79 80 81 82 83
  /*! \brief Number to represent +inf */
  static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
  /*!
   * \brief Number to represent -inf
   * \note We can make use the of fact that -kPosInf == kNegInf in the project.
   */
  static const constexpr int64_t kNegInf = -kPosInf;

  static constexpr const char* _type_key = "arith.ConstIntBound";
84
  TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object);
85 86
};

87 88 89 90
/*!
 * \brief reference class to ConstIntBoundNode
 * \sa ConstIntBoundNode
 */
91
class ConstIntBound : public ObjectRef {
92 93 94 95 96 97 98 99 100 101
 public:
  /*!
   * \brief constructor by fields.
   * \param min_value The mininum value.
   * \param max_value The maximum value.
   */
  TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);

  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
  static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
102
  TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode);
103
};
104 105 106 107 108 109 110 111 112 113 114

/*!
 * \brief Analyzer to get constant integer bound over expression.
 */
class ConstIntBoundAnalyzer {
 public:
  /*!
   * \brief analyze the expr
   * \param expr The expression of interest.
   * \return the result of the analysis.
   */
115
  ConstIntBound operator()(const PrimExpr& expr);
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145

  /*!
   * \brief Update constant int bound information of var.
   *
   * \param var The variable of interest.
   * \param info The bound information.
   * \param override Whether do we allow override of existing information.
   */
  void Update(const Var& var,
              const ConstIntBound& info,
              bool override = false);
  /*!
   * \brief Bind variable to a range.
   *
   * \param var The variable.
   * \param range The range we bind to.
   */
  void Bind(const Var& var, const Range& range);

 private:
  friend class Analyzer;
  friend class ConstraintContext;
  explicit ConstIntBoundAnalyzer(Analyzer* parent);
  ~ConstIntBoundAnalyzer();
  /*!
   * \brief Update the internal state to enter constraint.
   * \param constraint A constraint expression.
   *
   * \return an exit function that must be called to cleanup the constraint can be nullptr.
   */
146
  std::function<void()> EnterConstraint(const PrimExpr& constraint);
147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
  struct Entry;
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};

/*!
 * \brief Range of a linear integer function.
 *  Use to do specify the possible index values.
 *
 *  set = { coeff * x + base | x in Z }
 *
 *  When coeff != 0, it can also be written as
 *  set = { n | n % coeff == base }
 *
 *  This is useful to decide if the index is dividable by certain value.
 *  For example, if index = 0 + 4 x, then we know it can be divided by 4.
 */
165
class ModularSetNode : public Object {
166 167 168 169 170 171
 public:
  /*! \brief linear co-efficient */
  int64_t coeff;
  /*! \brief The base */
  int64_t base;

172
  void VisitAttrs(tvm::AttrVisitor* v) {
173 174 175 176
    v->Visit("coeff", &coeff);
    v->Visit("base", &base);
  }

177 178 179 180
  bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
    return equal(coeff, other->coeff) && equal(base, other->base);
  }

181
  static constexpr const char* _type_key = "arith.ModularSet";
182
  TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
183 184
};

185 186 187 188
/*!
 * \brief reference of ModularSetNode
 * \sa ModularSetNode
 */
189
class ModularSet : public ObjectRef {
190 191 192
 public:
  TVM_DLL ModularSet(int64_t coeff, int64_t base);

193
  TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode);
194
};
195 196 197 198 199 200 201 202 203 204 205

/*!
 * \brief Analyzer to get modular information over expression.
 */
class ModularSetAnalyzer {
 public:
  /*!
   * \brief analyze the expr
   * \param expr The expression of interest.
   * \return the result of the analysis.
   */
206
  ModularSet operator()(const PrimExpr& expr);
207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
  /*!
   * \brief Update constant int bound information of var.
   *
   * \param var The variable of interest.
   * \param info The bound information.
   * \param override Whether do we allow override of existing information.
   */
  void Update(const Var& var,
              const ModularSet& info,
              bool override = false);

 private:
  friend class Analyzer;
  friend class ConstraintContext;
  explicit ModularSetAnalyzer(Analyzer* parent);
  ~ModularSetAnalyzer();
  /*!
   * \brief Update the internal state to enter constraint.
   * \param constraint A constraint expression.
   *
   * \return an exit function that must be called to cleanup the constraint can be nullptr.
   */
229
  std::function<void()> EnterConstraint(const PrimExpr& constraint);
230 231 232 233 234 235 236
  struct Entry;
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};

/*!
237 238 239 240 241 242 243 244 245
 * \brief Rewrite-rule based simplifier.
 */
class RewriteSimplifier {
 public:
  /*!
   * \brief analyze the expr
   * \param expr The expression of interest.
   * \return the result of the analysis.
   */
246
  PrimExpr operator()(const PrimExpr& expr);
247 248 249 250 251 252 253 254 255

  /*!
   * \brief Update binding of var to a new expression.
   *
   * \param var The variable of interest.
   * \param new_expr
   * \param override Whether do we allow override of existing information.
   */
  void Update(const Var& var,
256
              const PrimExpr& new_expr,
257 258
              bool override = false);

259
  std::function<void()> EnterConstraint(const PrimExpr& constraint);
260

261 262 263
 private:
  friend class Analyzer;
  friend class ConstraintContext;
264
  friend class CanonicalSimplifier;
265 266 267 268 269 270 271 272
  explicit RewriteSimplifier(Analyzer* parent);
  ~RewriteSimplifier();
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};

/*!
273 274 275 276 277 278 279 280 281
 * \brief Canonical-form based simplifier.
 */
class CanonicalSimplifier {
 public:
  /*!
   * \brief analyze the expr
   * \param expr The expression of interest.
   * \return the result of the analysis.
   */
282
  PrimExpr operator()(const PrimExpr& expr);
283 284 285 286 287 288 289 290 291

  /*!
   * \brief Update binding of var to a new expression.
   *
   * \param var The variable of interest.
   * \param new_expr
   * \param override Whether do we allow override of existing information.
   */
  void Update(const Var& var,
292
              const PrimExpr& new_expr,
293 294 295 296 297 298 299 300 301 302 303 304 305
              bool override = false);

 private:
  friend class Analyzer;
  friend class ConstraintContext;
  explicit CanonicalSimplifier(Analyzer* parent);
  ~CanonicalSimplifier();
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};

/*!
306
 * \brief Constraint context.
307 308 309 310 311 312
 *
 * \code
 *
 *  Var("x");
 *  arith::Analyzer analyzer;
 *  {
313
 *    With<arith::ConstraintContext> scope(&analyzer, x % 3 == 0);
314 315 316 317 318 319 320 321
 *    CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
 *  }
 *  // constraint no longer in effect.
 *  CHECK_NE(analyzer.modular_set(x)->coeff, 3);
 *
 * \endcode
 */
class ConstraintContext {
322 323 324
 private:
  // declare friend to enable with.
  friend class With<ConstraintContext>;
325 326 327 328 329
  /*!
   * \brief Construct a constraint context.
   * \param analyzer The analyzer.
   * \param constraint The constraint to be applied.
   */
330
  ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
331 332 333 334 335 336 337 338
      : analyzer_(analyzer), constraint_(constraint) {}
  // enter the scope.
  void EnterWithScope();
  // exit the scope.
  void ExitWithScope();
  /*! \brief The analyzer */
  Analyzer* analyzer_;
  /*! \brief The constraint */
339
  PrimExpr constraint_;
340 341 342 343
  /*! \brief function to be called in recovery */
  std::function<void()> exit_;
};

344
/*!
345
 * \brief Integer set analyzer.
346
 */
347 348 349 350 351 352 353 354 355 356
class IntSetAnalyzer {
 public:
  /*!
   * \brief Find a symbolic integer set that contains all possible values of
   *        expr given the domain of each variables.
   *
   * \param expr The expression of interest.
   * \param dom_map The domain map to indicate which variable to relax.
   * \return the result of the analysis.
   */
357
  IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
358 359 360 361 362 363 364 365

 private:
  friend class Analyzer;
  explicit IntSetAnalyzer(Analyzer* parent);
  ~IntSetAnalyzer();
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
366 367
};

368
/*!
369
 * \brief Analyzer that contains bunch of sub-analyzers.
370
 *
371 372
 * Each sub-analyzer can make use of another sub-analyzer
 * by weak reference of this.
373
 *
374 375 376
 * NOTE for sub-analyzer developers:
 * If the analyzer uses memoization, we need to clear the internal
 * cache when information about a Var has been overridden.
377
 */
378 379
class Analyzer {
 public:
380 381 382 383 384
  /*
   * Disable copy constructor.
   */
  Analyzer(const Analyzer&) = delete;
  Analyzer& operator=(const Analyzer&) = delete;
385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405
  /*! \brief sub-analyzer: const integer bound */
  ConstIntBoundAnalyzer const_int_bound;
  /*! \brief sub-analyzer: modular set */
  ModularSetAnalyzer modular_set;
  /*! \brief sub-analyzer rewrite simplify */
  RewriteSimplifier rewrite_simplify;
  /*! \brief sub-analyzer canonical simplify */
  CanonicalSimplifier canonical_simplify;
  /*! \brief sub-analyzer: int set */
  IntSetAnalyzer int_set;
  /*! \brief constructor */
  Analyzer();
  /*!
   * \brief Notify all the sub-analyzers that var
   *        is created and binded to expr.
   *
   *  Each var can only be binded once.
   *
   * \param var The variable.
   * \param expr The expression we bind to.
   */
406
  void Bind(const Var& var, const PrimExpr& expr);
407 408 409 410 411 412 413 414 415
  /*!
   * \brief Notify all the sub-analyzers that var
   *        is created and binded to a range.
   *
   *  Each var can only be binded once.
   *
   * \param var The variable.
   * \param range The range we bind to.
   */
416
  void Bind(const Var& var, const Range& range);
417 418 419 420 421 422 423 424 425 426 427 428
  /*!
   * \brief Whether can we prove expr >= val.

   *  Non-negative proof is very useful in integer analysis
   *  to lower divisions and mods given difference in trunc and ceil mode.
   *
   * \param expr The expression.
   * \param lower_bound The lower bound.
   * \return Whether we can prove it.
   *
   * \note Analyzer will call into sub-analyzers to get the result.
   */
429
  bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
430 431 432 433 434 435 436 437
  /*!
   * \brief Whether can we prove condition.
   *
   * \param cond The expression to be proved.
   * \return The result.
   *
   * \note Analyzer will call into sub-analyzers to get the result.
   */
438
  bool CanProve(const PrimExpr& cond);
439 440 441 442 443 444 445 446
  /*!
   * \brief Simplify expr.
   *
   * \param expr The expression to be simplified.
   * \return The result.
   *
   * \note Analyzer will call into sub-analyzers to get the result.
   */
447
  PrimExpr Simplify(const PrimExpr& expr);
448
};

}  // namespace arith
}  // namespace tvm
#endif  // TVM_ARITH_ANALYZER_H_