/*!
 *  Copyright (c) 2017 by Contributors
 * \file modular.cc
 * \brief Modular analysis
 */
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h>
#include <tvm/arithmetic.h>
#include <limits>
#include "./int_set_internal.h"

namespace tvm {
namespace arith {

using namespace ir;

/*!
 * \brief Computes a ModularEntry for an integer expression by structural
 *  recursion.
 *
 *  An entry (coeff, base) describes values of the form coeff * x + base
 *  with x in Z; coeff == 0 denotes the exact constant base. Expression
 *  kinds without a dedicated rule map to ModularEntry::everything().
 *
 *  Intermediate arithmetic is carried out in 64 bits and the result is
 *  normalized by MakeEntry. Whenever a result cannot be represented in
 *  int the evaluator falls back to ModularEntry::everything() instead of
 *  overflowing.
 */
class ModularEvaluator
    : public ExprFunctor<ModularEntry(const Expr&)> {
 public:
  /*!
   * \brief Construct the evaluator.
   * \param mod_map Known entries of variables. It is stored by reference,
   *  so it must stay alive for as long as the evaluator is used.
   */
  explicit ModularEvaluator(
      const std::unordered_map<
      const Variable*, ModularEntry>& mod_map)
      : mod_map_(mod_map) {
  }
  /*!
   * \brief Evaluate the modular entry of an expression.
   * \param e The expression to analyze.
   * \return The entry computed by dispatching on the node type of e.
   */
  ModularEntry Eval(const Expr& e) {
    return VisitExpr(e);
  }
  // default: no rule for this node type.
  ModularEntry VisitExprDefault_(const Node*) final {
    return ModularEntry::everything();
  }
  // override combination rules.
  // Signed constant: exact value when it can be stored in an int.
  ModularEntry VisitExpr_(const IntImm* op) final {
    // The lower bound matters as well: a value below INT_MIN would be
    // silently truncated by the cast below.
    if (op->value >= std::numeric_limits<int>::min() &&
        op->value < std::numeric_limits<int>::max()) {
      ModularEntry ret;
      ret.base = static_cast<int>(op->value);
      ret.coeff = 0;
      return ret;
    } else {
      return ModularEntry::everything();
    }
  }
  // Unsigned constant: exact value when it can be stored in an int.
  ModularEntry VisitExpr_(const UIntImm* op) final {
    if (op->value < static_cast<uint64_t>(
            std::numeric_limits<int>::max())) {
      ModularEntry ret;
      ret.base = static_cast<int>(op->value);
      ret.coeff = 0;
      return ret;
    } else {
      return ModularEntry::everything();
    }
  }
  // Variable: use the caller supplied entry when there is one.
  ModularEntry VisitExpr_(const Variable* op) final {
    auto it = mod_map_.find(op);
    if (it != mod_map_.end()) {
      return it->second;
    } else {
      return ModularEntry::everything();
    }
  }
  // (p x + n) + (q y + m) -> gcd(p, q) z + (n + m)
  ModularEntry VisitExpr_(const Add* op) final {
    ModularEntry a = Eval(op->a);
    ModularEntry b = Eval(op->b);
    // The bases are summed in 64 bits so two large constants cannot overflow.
    return MakeEntry(ZeroAwareGCD(a.coeff, b.coeff),
                     static_cast<int64_t>(a.base) + b.base);
  }
  // (p x + n) - (q y + m) -> gcd(p, q) z + (n - m)
  ModularEntry VisitExpr_(const Sub* op) final {
    ModularEntry a = Eval(op->a);
    ModularEntry b = Eval(op->b);
    return MakeEntry(ZeroAwareGCD(a.coeff, b.coeff),
                     static_cast<int64_t>(a.base) - b.base);
  }
  ModularEntry VisitExpr_(const Mul* op) final {
    ModularEntry a = Eval(op->a);
    ModularEntry b = Eval(op->b);
    // Simplification rule, x, y, z are in Z
    // (p x + n) (q y + m)
    // -> pq xy + pm x + qn y + mn
    // -> pq z + pm x + qn y + mn
    //
    // The product of two int values always fits in 64 bits, so none of
    // the products below can overflow.
    const int64_t p = a.coeff;
    const int64_t n = a.base;
    const int64_t q = b.coeff;
    const int64_t m = b.base;
    // pm and qn are negative when a constant operand is negative
    // (e.g. x * -2); GCD64 works on magnitudes, so the resulting
    // coefficient stays non-negative.
    const int64_t coeff = GCD64(p * q, GCD64(p * m, q * n));
    return MakeEntry(coeff, n * m);
  }
  ModularEntry VisitExpr_(const Div* op) final {
    // a c x  / c -> a x
    // We cannot do cases where offset is non-zero
    // because of different integer rounding in pos/neg
    ModularEntry a = Eval(op->a);
    ModularEntry b = Eval(op->b);
    if (b.coeff == 0 &&
        a.base == 0) {
      CHECK_NE(b.base, 0);
      const int64_t num = a.coeff;
      const int64_t den = b.base;
      if (num % den == 0) {
        // The division is exact, so the quotient is a multiple of
        // |num / den| whatever the sign of the divisor; keep the
        // coefficient non-negative.
        int64_t quot = num / den;
        if (quot < 0) quot = -quot;
        return MakeEntry(quot, 0);
      }
    }
    return ModularEntry::everything();
  }

 private:
  // Known entries of variables; owned by the caller.
  const std::unordered_map<
    const Variable*, ModularEntry>& mod_map_;
  // ModularEntry::Add uses BaseSimplify and ZeroAwareGCD.
  friend struct ModularEntry;
  // whether v can be stored in an int without loss.
  static bool FitsInt(int64_t v) {
    return v >= std::numeric_limits<int>::min() &&
           v <= std::numeric_limits<int>::max();
  }
  // 64-bit GCD of |a| and |b|; GCD64(x, 0) == |x| and GCD64(0, 0) == 0.
  // Callers pass values whose magnitude is at most 2^62, so the
  // negations cannot overflow.
  static int64_t GCD64(int64_t a, int64_t b) {
    if (a < 0) a = -a;
    if (b < 0) b = -b;
    while (b != 0) {
      const int64_t r = a % b;
      a = b;
      b = r;
    }
    return a;
  }
  // Build an entry from 64-bit intermediates: the base is reduced into
  // [0, coeff) when coeff is non-zero; the result degrades to
  // everything() when it is not representable in int.
  // Expects coeff >= 0.
  static ModularEntry MakeEntry(int64_t coeff, int64_t base) {
    if (coeff != 0) {
      base %= coeff;
      if (base < 0) base += coeff;
    }
    if (!FitsInt(coeff) || !FitsInt(base)) {
      return ModularEntry::everything();
    }
    ModularEntry ret;
    ret.coeff = static_cast<int>(coeff);
    ret.base = static_cast<int>(base);
    return ret;
  }
  // simplify the base by putting it in range.
  static int BaseSimplify(int base, int coeff) {
    if (coeff == 0) return base;
    base = base % coeff;
    if (base < 0) base += coeff;
    return base;
  }
  // GCD of two non-negative values where ZeroAwareGCD(x, 0) == x.
  // ax + by = gcd(a, b) z if a != 0, b != 0
  static int ZeroAwareGCD(int a, int b) {
    CHECK_GE(a, 0);
    CHECK_GE(b, 0);
    // The GCD never exceeds the larger operand, so it fits in an int.
    return static_cast<int>(GCD64(a, b));
  }
};

/*!
 * \brief Combine two modular entries under addition.
 * \param a The first entry.
 * \param b The second entry.
 * \return An entry whose coeff is the ZeroAwareGCD of both coeffs and
 *  whose base is the sum of both bases passed through BaseSimplify.
 */
ModularEntry ModularEntry::Add(const ModularEntry& a,
                               const ModularEntry& b) {
  const int merged_coeff =
      ModularEvaluator::ZeroAwareGCD(a.coeff, b.coeff);
  ModularEntry sum;
  sum.coeff = merged_coeff;
  sum.base = ModularEvaluator::BaseSimplify(a.base + b.base, merged_coeff);
  return sum;
}


/*!
 * \brief Evaluate the modular entry of an expression.
 * \param e The expression to analyze.
 * \param mod_map Known modular entries of variables.
 * \return The entry produced by ModularEvaluator for e.
 */
ModularEntry EvalModular(
    const Expr& e,
    const std::unordered_map<const Variable*, ModularEntry>& mod_map) {
  ModularEvaluator evaluator(mod_map);
  return evaluator(e);
}

/*!
 * \brief Evaluate the modular set of an expression, with the variable
 *  information supplied as IntSet values.
 * \param e The expression to analyze.
 * \param mod_map Map from variable to its set; every value must be a
 *  ModularSet, otherwise the CHECK below fails.
 * \return An IntSet wrapping a new ModularSet that holds the result.
 */
IntSet EvalModular(const Expr& e,
                   const Map<Var, IntSet>& mod_map) {
  // Unpack the node-based map into the raw-pointer keyed form that
  // ModularEvaluator consumes.
  std::unordered_map<const Variable*, ModularEntry> var_entries;
  for (const auto& item : mod_map) {
    const ModularSet* mod_set = item.second.as<ModularSet>();
    CHECK(mod_set) << "Need to pass ModularSet for Modular Analysis";
    var_entries[item.first.get()] = mod_set->e;
  }
  auto result = std::make_shared<ModularSet>();
  ModularEvaluator evaluator(var_entries);
  result->e = evaluator(e);
  return IntSet(result);
}

}  // namespace arith
}  // namespace tvm