/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file const_fold.h * \brief Centralized location for constant folding. */ #ifndef TVM_ARITHMETIC_CONST_FOLD_H_ #define TVM_ARITHMETIC_CONST_FOLD_H_ #include <tvm/ir.h> #include <tvm/ir_mutator.h> #include <tvm/expr_operator.h> #include <algorithm> #include <cmath> #include "int_operator.h" namespace tvm { namespace arith { /*! * \brief Try to run binary compute with constant folding. * * \param a The left operand. * \param b The right operand. * \tparam Op The operator type. * * \note a and b Must already matched data types with each other. * \return nullptr if constant fold fails, otherwise return folded result. */ template<typename Op> inline Expr TryConstFold(Expr a, Expr b) { return Expr(); } /*! * \brief Try to run unary compute with constant folding. * * \param a The left operand. * \tparam Op The operator type. * * \note a and b Must already matched data types with each other. * \return nullptr if constant fold fails, otherwise return folded result. */ template<typename Op> inline Expr TryConstFold(Expr a); /*! * \brief Check whether type is used to represent index. * * Index types are frequently used in shape computation * and need to be aggressively constant-folded. * * \param type The type to represent index. * \return the checked result. */ inline bool IsIndexType(const Type& type) { return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64); } #define TVM_ARITH_CONST_PROPAGATION(BODY) \ using ir::IntImm; \ using ir::UIntImm; \ using ir::FloatImm; \ const IntImm* pa = a.as<IntImm>(); \ const IntImm* pb = b.as<IntImm>(); \ const FloatImm* fa = a.as<FloatImm>(); \ const FloatImm* fb = b.as<FloatImm>(); \ BODY; #define TVM_INDEX_CONST_PROPAGATION(BODY) \ using ir::IntImm; \ using ir::UIntImm; \ const IntImm* pa = a.as<IntImm>(); \ const IntImm* pb = b.as<IntImm>(); \ const Type& ta = a.type(); \ const Type& tb = b.type(); \ if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \ BODY; \ } \ // specialization of constant folders. template<> inline Expr TryConstFold<ir::Add>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const Type& rtype = a.type(); if (pa && pb) return IntImm::make(rtype, pa->value + pb->value); if (pa && pa->value == 0) return b; if (pb && pb->value == 0) return a; if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value); if (fa && fa->value == 0) return b; if (fb && fb->value == 0) return a; }); return Expr(); } template<> inline Expr TryConstFold<ir::Sub>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const Type& rtype = a.type(); if (pa && pb) return IntImm::make(rtype, pa->value - pb->value); if (pb && pb->value == 0) return a; if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value); if (fb && fb->value == 0) return a; }); return Expr(); } template<> inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const Type& rtype = a.type(); if (pa && pb) return IntImm::make(rtype, pa->value * pb->value); if (pa) { if (pa->value == 1) return b; if (pa->value == 0) return a; } if (pb) { if (pb->value == 1) return a; if (pb->value == 0) return b; } if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value); if (fa) { if (fa->value == 1) return b; if (fa->value == 0) return a; } if (fb) { if (fb->value == 1) return a; if (fb->value == 0) return b; } }); return Expr(); } template<> inline Expr TryConstFold<ir::Div>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const Type& rtype = a.type(); if (pa && pb) { // due to division and mod can have different modes // NOTE: this will assumes truc div. CHECK_NE(pb->value, 0) << "Divide by zero"; return IntImm::make(rtype, pa->value / pb->value); } if (pa) { if (pa->value == 0) return a; } if (pb) { if (pb->value == 1) return a; CHECK_NE(pb->value, 0) << "Divide by zero"; } if (fa && fb && fb->value != 0) { return FloatImm::make(rtype, fa->value / fb->value); } if (fa && fa->value == 0) return a; if (fb) { if (fb->value == 1) return a; CHECK_NE(fb->value, 0) << "Divide by zero"; } }); return Expr(); } template<> inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) { TVM_INDEX_CONST_PROPAGATION({ const Type& rtype = a.type(); if (pa && pb) { return IntImm::make(rtype, pa->value % pb->value); } if (pa) { if (pa->value == 0) return a; } if (pb) { if (pb->value == 1) return make_zero(rtype); CHECK_NE(pb->value, 0) << "Divide by zero"; } }); return Expr(); } template<> inline Expr TryConstFold<ir::FloorDiv>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const Type& rtype = a.type(); if (pa && pb) { CHECK_NE(pb->value, 0) << "Divide by zero"; return IntImm::make(rtype, arith::floordiv(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; } if (pb) { if (pb->value == 1) return a; CHECK_NE(pb->value, 0) << "Divide by zero"; } if (fa && fb && fb->value != 0) { return FloatImm::make(rtype, std::floor(fa->value / fb->value)); } if (fa && fa->value == 0) return a; if (fb) { if (fb->value == 1) return a; CHECK_NE(fb->value, 0) << "Divide by zero"; } }); return Expr(); } template<> inline Expr TryConstFold<ir::FloorMod>(Expr a, Expr b) { TVM_INDEX_CONST_PROPAGATION({ const Type& rtype = a.type(); if (pa && pb) { return IntImm::make(rtype, arith::floormod(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; } if (pb) { if (pb->value == 1) return make_zero(rtype); CHECK_NE(pb->value, 0) << "Divide by zero"; } }); return Expr(); } template<> inline Expr TryConstFold<ir::Min>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const Type& rtype = a.type(); if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value)); }); if (a.same_as(b)) return a; return Expr(); } template<> inline Expr TryConstFold<ir::Max>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const Type& rtype = a.type(); if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value)); }); if (a.same_as(b)) return a; return Expr(); } template<> inline Expr TryConstFold<ir::GT>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value); if (fa && fb) return UIntImm::make(UInt(1), fa->value > fb->value); }); return Expr(); } template<> inline Expr TryConstFold<ir::GE>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value); if (fa && fb) return UIntImm::make(UInt(1), fa->value >= fb->value); }); return Expr(); } template<> inline Expr TryConstFold<ir::LT>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value); if (fa && fb) return UIntImm::make(UInt(1), fa->value < fb->value); }); return Expr(); } template<> inline Expr TryConstFold<ir::LE>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value); if (fa && fb) return UIntImm::make(UInt(1), fa->value <= fb->value); }); return Expr(); } template<> inline Expr TryConstFold<ir::EQ>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value); if (fa && fb) return UIntImm::make(UInt(1), fa->value == fb->value); }); return Expr(); } template<> inline Expr TryConstFold<ir::NE>(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value); if (fa && fb) return UIntImm::make(UInt(1), fa->value != fb->value); }); return Expr(); } template<> inline Expr TryConstFold<ir::And>(Expr a, Expr b) { using ir::UIntImm; const UIntImm* pa = a.as<UIntImm>(); const UIntImm* pb = b.as<UIntImm>(); if (pa && pa->value) return b; if (pa && !pa->value) return a; if (pb && pb->value) return a; if (pb && !pb->value) return b; return Expr(); } template<> inline Expr TryConstFold<ir::Or>(Expr a, Expr b) { using ir::UIntImm; const UIntImm* pa = a.as<UIntImm>(); const UIntImm* pb = b.as<UIntImm>(); if (pa && pa->value) return a; if (pa && !pa->value) return b; if (pb && pb->value) return b; if (pb && !pb->value) return a; return Expr(); } template<> inline Expr TryConstFold<ir::Not>(Expr a) { using ir::UIntImm; const UIntImm* pa = a.as<UIntImm>(); if (pa) { return UIntImm::make(UInt(1), !(pa->value)); } return Expr(); } /*! \brief Helper namespace for symbolic value limits */ struct SymbolicLimits { /*! \brief positive infinity */ static Expr pos_inf_; /*! \brief negative infinity */ static Expr neg_inf_; }; /*! * \brief Opaque expression representing positive infinity. * * It can can only be used as parameter of by min/max * for integer analysis and cannot be used in normal expressions. * * \return positive infinity. */ inline Expr pos_inf() { return SymbolicLimits::pos_inf_; } /*! * \brief Check if value is positive infinity. * \param value The value to be checked. * * \return The check result. */ inline bool is_pos_inf(const Expr& value) { return value.same_as(SymbolicLimits::pos_inf_); } /*! * \brief Opaque expression representing negative infinity. * * It can can only be used as parameter of by min/max * for integer analysis and cannot be used in normal expressions. * * \return negative infinity. */ inline Expr neg_inf() { return SymbolicLimits::neg_inf_; } /*! * \brief Check if value is negative infinity. * \param value The value to be checked. * * \return The check result. */ inline bool is_neg_inf(const Expr& value) { return value.same_as(SymbolicLimits::neg_inf_); } } // namespace arith } // namespace tvm #endif // TVM_ARITHMETIC_CONST_FOLD_H_