/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file tvm/arithmetic/analyzer.cc
 */
#include <tvm/ir.h>
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>

namespace tvm {
namespace arith {

/*!
 * \brief Construct an analyzer together with all of its sub-analyzers.
 *
 * Every sub-analyzer (const_int_bound, modular_set, rewrite_simplify,
 * canonical_simplify, int_set) is constructed with a pointer to this
 * Analyzer.
 */
Analyzer::Analyzer()
    : const_int_bound(this),
      modular_set(this),
      rewrite_simplify(this),
      canonical_simplify(this),
      int_set(this) {
}

void Analyzer::Bind(const VarExpr& var, const Expr& expr) {
40 41 42 43 44 45 46 47
  Expr new_expr = expr;
  new_expr = this->canonical_simplify(new_expr);
  new_expr = this->rewrite_simplify(new_expr);

  this->const_int_bound.Update(var, this->const_int_bound(new_expr));
  this->modular_set.Update(var, this->modular_set(new_expr));
  this->rewrite_simplify.Update(var, new_expr);
  this->canonical_simplify.Update(var, new_expr);
48 49
}

50
void Analyzer::Bind(const VarExpr& var, const Range& range) {
51 52
  CHECK(range.defined());
  if (is_one(range->extent)) {
53 54 55
    this->Bind(var, range->min);
  } else {
    this->const_int_bound.Bind(var, range);
56
  }
57
  // skip modular_set
58
  // skip rewrite simplify
59 60
}

61 62 63

void ConstraintContext::EnterWithScope() {
  CHECK(exit_ == nullptr);
64
  // entering the scope.
65 66
  auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
  auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
67
  auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_);
68
  // recovery function.
69 70
  exit_ = [f0, f1, f2]() {
    if (f2 != nullptr) f2();
71 72 73 74 75
    if (f1 != nullptr) f1();
    if (f0 != nullptr) f0();
  };
}

76 77 78 79 80
/*!
 * \brief Undo the constraint installed by EnterWithScope().
 *
 * exit_ is cleared before the recovery callback runs. A repeated exit then
 * fails the CHECK instead of silently replaying the recovery, and the
 * context may be entered again afterwards.
 */
void ConstraintContext::ExitWithScope() {
  CHECK(exit_ != nullptr);
  auto recover = exit_;
  exit_ = nullptr;
  recover();
}

bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
82
  if (const auto* ptr = expr.as<ir::IntImm>()) {
83
    return ptr->value >= lower_bound;
84 85
  }
  auto bd = this->const_int_bound(this->rewrite_simplify(expr));
86 87 88
  if (bd->min_value >= lower_bound) return true;
  return false;
}
89

90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
/*!
 * \brief Try to prove that the boolean expression \p expr is true.
 *
 * Each stage is checked for a UIntImm constant:
 *   1. \p expr itself;
 *   2. its rewrite-simplified form;
 *   3. the canonical simplification of that rewritten form.
 * This is the same rewrite -> canonical pipeline that Simplify() uses.
 * A false result means the proof failed, not that \p expr is false.
 *
 * \param expr The condition to prove.
 * \return Whether \p expr was proven true.
 */
bool Analyzer::CanProve(const Expr& expr) {
  if (const auto* ptr = expr.as<ir::UIntImm>()) {
    return ptr->value != 0;
  }
  auto res = this->rewrite_simplify(expr);
  if (const auto* ptr = res.as<ir::UIntImm>()) {
    return ptr->value != 0;
  }
  // Canonicalize the rewritten form rather than the original expr. This
  // keeps the work already done and matches Simplify().
  res = this->canonical_simplify(res);
  if (const auto* ptr = res.as<ir::UIntImm>()) {
    return ptr->value != 0;
  }
  return false;
}

/*!
 * \brief Simplify \p expr with the rewrite simplifier, then the canonical
 *        simplifier.
 *
 * Constant inputs are returned unchanged. The canonical pass is skipped
 * when the rewrite pass already produced a constant.
 *
 * \param expr The expression to simplify.
 * \return The simplified expression.
 */
Expr Analyzer::Simplify(const Expr& expr) {
  if (is_const(expr)) return expr;
  auto res = this->rewrite_simplify(expr);
  if (is_const(res)) return res;
  res = this->canonical_simplify(res);
  return res;
}

}  // namespace arith
}  // namespace tvm