/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file modular_set.cc
 * \brief Modular set analysis
 */
#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/expr_functor.h>

#include <cstdlib>
#include <functional>
#include <limits>
#include <utility>
#include <unordered_map>

#include "pattern_match.h"

namespace tvm {
namespace arith {

using namespace tir;

TVM_REGISTER_NODE_TYPE(ModularSetNode);

ModularSet::ModularSet(int64_t coeff, int64_t base) {
  auto node = make_object<ModularSetNode>();
  node->coeff = coeff;
  node->base = base;
  // finish construction.
  data_ = std::move(node);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const ModularSetNode*>(node.get());
    p->stream << "ModularSet("
              << "coeff=" << op->coeff << ", base="
              << op->base << ')';
  });

ModularSet MakeModularSet(int64_t coeff, int64_t base) {
  return ModularSet(coeff, base);
}

TVM_REGISTER_GLOBAL("arith.ModularSet")
.set_body_typed(MakeModularSet);

/*!
 * \brief Internal entry of the modular set analysis.
 *
 *  An entry describes the set { coeff * x + base | x in Z }.
 *  coeff == 0 denotes a single constant value (see is_const()).
 */
struct ModularSetAnalyzer::Entry {
  int64_t coeff{1};
  int64_t base{0};

  Entry() = default;

  /*!
   * \brief Construct a normalized entry.
   * \param coeff The coefficient, must be non-negative (checked).
   * \param base The base; reduced into [0, coeff) when coeff != 0.
   */
  Entry(int64_t coeff, int64_t base) {
    CHECK_GE(coeff, 0);
    this->coeff = coeff;
    if (coeff != 0) {
      // C++ % keeps the sign of the dividend, shift negatives back into range.
      base = base % coeff;
      if (base < 0) base += coeff;
    }
    this->base = base;
  }

  /*! \return Whether the entry represents a single constant value. */
  bool is_const() const {
    return coeff == 0;
  }

  bool operator==(const Entry& other) const {
    return coeff == other.coeff && base == other.base;
  }

  bool operator==(const ModularSet& other) const {
    return other.defined() &&
        coeff == other->coeff && base == other->base;
  }
};

class ModularSetAnalyzer::Impl :
      public ExprFunctor<ModularSetAnalyzer::Entry(const PrimExpr&)> {
 public:
  explicit Impl(Analyzer* parent)
      : parent_(parent) {}

  /*!
   * \brief Record the modular set of a variable.
   * \param var The variable.
   * \param info The modular set to associate with var.
   * \param override When false, an already recorded entry for var
   *        must be equal to info (checked).
   */
  void Update(const Var& var,
              const ModularSet& info,
              bool override) {
    if (!override) {
      auto it = var_map_.find(var);
      if (it != var_map_.end()) {
        CHECK(it->second == info)
            << "Trying to update var \'" << var << "\'"
            << " with a different const bound: "
            << "original=" << ModularSet(it->second.coeff, it->second.base)
            << ", new=" << info;
      }
    }
    var_map_[var] = Entry(info->coeff, info->base);
  }

  /*!
   * \brief Detect useful constraints and use them in the analysis scope.
   *
   *  Recognizes truncmod(var, c) == b and floormod(var, c) == b where
   *  c and b are integer immediates.
   *
   * \param constraint The constraint expression.
   * \return The recovery function that restores the previous state,
   *         or nullptr when the constraint is not used.
   */
  std::function<void()> EnterConstraint(const PrimExpr& constraint) {
    PVar<Var> var;
    PVar<IntImm> coeff, base;
    // pattern match interesting constraints
    if ((truncmod(var, coeff) == base).Match(constraint) ||
        (floormod(var, coeff) == base).Match(constraint)) {
      const int64_t modulus = coeff.Eval()->value;
      // A modulus that is zero or negative cannot form a valid Entry
      // (the constructor checks coeff >= 0, and x % 0 carries no information),
      // so such constraints are simply ignored.
      if (modulus <= 0) return nullptr;
      Entry entry(modulus, base.Eval()->value);
      return UpdateByIntersect(var.Eval(), entry);
    }
    return nullptr;
  }

  // Override visitor behaviors
  Entry VisitExprDefault_(const Object* op) final {
    return Everything();
  }

  Entry VisitExpr_(const CastNode* op) final {
    return VisitExpr(op->value);
  }

  Entry VisitExpr_(const IntImmNode* op) final {
    return Entry(0, op->value);
  }

  Entry VisitExpr_(const AddNode* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
    return Entry(coeff, a.base + b.base);
  }

  Entry VisitExpr_(const SubNode* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
    return Entry(coeff, a.base - b.base);
  }

  Entry VisitExpr_(const MulNode* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    // Simplification rule, x, y, z are in Z
    // (p x + n) (q y + m)
    // -> pq xy + pm x + qn y + mn
    // -> pq z + pm x + qn y + mn
    int64_t pq = a.coeff * b.coeff;
    int64_t pm = a.coeff * b.base;
    int64_t qn = a.base * b.coeff;
    int64_t coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
    return Entry(coeff, a.base * b.base);
  }

  /*!
   * \brief Analyze lhs / val for a non-zero constant divisor.
   * \param lhs The dividend expression.
   * \param val The constant divisor, must be non-zero (checked).
   * \param round_down Whether the division is known to round down.
   * \return The resulting entry, Everything() when nothing can be derived.
   */
  Entry DivByConst(const PrimExpr& lhs,
                   int64_t val,
                   bool round_down) {
    Entry a = VisitExpr(lhs);
    CHECK_NE(val, 0);
    if (a.coeff % val == 0) {
      if (a.base == 0) {
        // a c x  / c -> a x
        return Entry(std::abs(a.coeff / val), 0);
      }
      // positive division have a clear rounding mode.
      // Only handle case where we clearly know we need to round down.
      if (a.base > 0 && val > 0 &&
          (round_down || parent_->CanProveGreaterEqual(lhs, 0))) {
        return Entry(a.coeff / val, a.base / val);
      }
    }
    return Everything();
  }

  Entry VisitExpr_(const DivNode* op) final {
    Entry b = VisitExpr(op->b);
    if (b.is_const()) {
      return DivByConst(op->a, b.base, false);
    }
    return Everything();
  }

  Entry VisitExpr_(const FloorDivNode* op) final {
    Entry b = VisitExpr(op->b);
    if (b.is_const()) {
      return DivByConst(op->a, b.base, true);
    }
    return Everything();
  }

  Entry VisitExpr_(const MinNode* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    return Union(a, b);
  }

  Entry VisitExpr_(const MaxNode* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    return Union(a, b);
  }

  Entry VisitExpr_(const SelectNode* op) final {
    Entry a = VisitExpr(op->true_value);
    Entry b = VisitExpr(op->false_value);
    return Union(a, b);
  }

  Entry VisitExpr_(const CallNode* op) final {
    // only special handle >> which can be
    // used for index calculation.
    if (op->is_intrinsic(CallNode::shift_right)) {
      return VisitRightShift(op);
    } else {
      return Everything();
    }
  }

  Entry VisitExpr_(const VarNode* op) final {
    Var v = GetRef<Var>(op);
    auto it = var_map_.find(v);
    if (it != var_map_.end()) {
      return it->second;
    } else {
      return Everything();
    }
  }

  /*!
   * \brief Analyze a right shift by treating it as a floor division
   *        by a power of two.
   * \param op The shift_right call; args[0] is shifted by args[1].
   * \return The resulting entry, Everything() when the shift amount is
   *         not a constant in [0, 63).
   */
  Entry VisitRightShift(const CallNode* op) {
    Entry b = VisitExpr(op->args[1]);
    // a c x  / c -> a x
    // The divisor is built in 64 bit and only for shift amounts whose
    // power of two is representable as a positive int64_t; shifting an
    // int by a negative or too large amount is undefined behavior.
    if (b.is_const() && b.base >= 0 && b.base < 63) {
      return DivByConst(op->args[0], static_cast<int64_t>(1) << b.base, true);
    }
    return Everything();
  }

 private:
  /*! \brief pointer to parent. */
  Analyzer* parent_{nullptr};
  // internal variable map
  std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
  /*!
   * \brief Update var by intersecting entry with var's current set.
   * \param var The variable.
   * \param entry The entry to be updated.
   * \return The recovery function of the scope.
   */
  std::function<void()> UpdateByIntersect(const Var& var, Entry entry) {
    Entry old = Everything();
    auto it = var_map_.find(var);
    if (it != var_map_.end()) {
      old = it->second;
    }
    var_map_[var] = Intersect(old, entry);
    // recover function.
    return [this, old, var]() {
      var_map_[var] = old;
    };
  }
  /*!
   * \brief Create union of two sets.
   * \param a The left operand.
   * \param b the right operand.
   * \return An entry that contains both a and b.
   */
  static Entry Union(Entry a, Entry b) {
    // {ax + y} \cup {bz + h} => {gcd(a, b) x + {y or h}}
    int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
    if (coeff == 0) {
      // both operands are constants.
      if (a.base == b.base) return a;
      return Everything();
    }
    int64_t base0 = a.base % coeff;
    int64_t base1 = b.base % coeff;
    if (base0 == base1) {
      return Entry(coeff, base0);
    } else {
      return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0);
    }
  }

  /*!
   * \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
   * \param a The first coefficient.
   * \param b The second coefficient.
   * \param x The solution of x.
   * \param y The solution of y.
   * \return The GCD of a and b.
   */
  static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) {
    // Extended Euclidean algorithm
    // if a < 0, the problem can be convert into
    // |a|* (-x) + b * y = gcd(|a|, b)
    //
    // initial condition:
    // a * 0 + b * 1 = b
    // a * 1 + b * 0 = a
    int64_t s = 0, old_s = 1;
    int64_t r = b, old_r = a >= 0 ? a : -a;
    // Iteration (r2 < r1):
    // a * x1 + b * y1 = r1
    // a * x2 + b * y2 = r2
    // The above two eqs can derive the following eq (q = r1 / r2)
    // a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3
    // Because r3 < r2, the iteration can eventually terminate
    while (r != 0) {
      int64_t q = old_r / r;
      int64_t tmp = old_r;
      old_r = r;
      r = tmp - q * r;
      tmp = old_s;
      old_s = s;
      s = tmp - q * s;
    }

    *x = a >= 0 ? old_s : -old_s;
    if (b != 0) {
      *y = (old_r - (*x) * a) / b;
    } else {
      *y = 1;
    }

    return old_r;
  }

  /*!
   * \brief Create intersect of two sets.
   * \param a The left operand.
   * \param b the right operand.
   * \return The intersection, Nothing() when the sets are disjoint.
   */
  static Entry Intersect(Entry a, Entry b) {
    int64_t x, y;
    int64_t c1 = a.coeff, b1 = a.base, c2 = b.coeff, b2 = b.base;
    // z = c1 * p + b1
    // z = c2 * q + b2
    // c1 * x + c2 * y = gcd(c1, c2)
    // -> c1 * p - c2 * q = b2 - b1
    // -> p = (b2 - b1) / gcd * x
    // -> q = (b2 - b1) / gcd * (-y)
    // -> z = LCM(c1, c2) * k + (c1 * p + b1)
    int64_t gcd = ExtendedEuclidean(c1, c2, &x, &y);
    if (gcd == 0) {
      // Both operands are constants (c1 == c2 == 0); the modulo below
      // would divide by zero, so compare the two values directly.
      return b1 == b2 ? a : Nothing();
    }
    int64_t v = b2 - b1;
    if (v % gcd == 0) {
      x = v / gcd * x;
      y = v / gcd * (-y);
      int64_t coeff = c1 / gcd * c2;
      return Entry(coeff, x * c1 + b1);
    } else {
      return Nothing();
    }
  }
  /*!
   * \brief Take GCD of a and b.
   * \param a The first operand.
   * \param b The second operand.
   * \return The result; when one operand is zero the absolute value
   *         of the other operand is returned.
   */
  static int64_t ZeroAwareGCD(int64_t a, int64_t b) {
    if (a < 0) a = -a;
    if (b < 0) b = -b;
    if (a < b) std::swap(a, b);
    if (b == 0) return a;
    // perform GCD (greatest common divisor)
    // ax + by = gcd(a, b) z if a != 0, b != 0
    while (a % b != 0) {
      a = a % b;
      std::swap(a, b);
    }
    return b;
  }
  /*!
   * \brief return everything dtype can represent.
   * \return Entry with coeff = 1 and base = 0, i.e. every integer.
   */
  static Entry Everything() {
    return Entry(1, 0);
  }
  /*!
   * \brief return an empty set
   * \return The entry this analysis uses to denote the empty set
   *         (coeff = 0, base = 1).
   */
  static Entry Nothing() {
    return Entry(0, 1);
  }
};

ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) {
  Entry ret = impl_->VisitExpr(expr);
  return ModularSet(ret.coeff, ret.base);
}

void ModularSetAnalyzer::Update(const Var& var,
                                const ModularSet& info,
                                bool override) {
  impl_->Update(var, info, override);
}

std::function<void()> ModularSetAnalyzer::EnterConstraint(const PrimExpr& constraint) {
  return impl_->EnterConstraint(constraint);
}

ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent)
    : impl_(new Impl(parent)) {
}

ModularSetAnalyzer::~ModularSetAnalyzer() {
  delete impl_;
}

}  // namespace arith
}  // namespace tvm