/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/expr_operator.h * \brief Common operators defined for Expr. * * \note Most of the operator defined here perform simple constant folding * when the type is int32 or int64 for simplifying the index expressions. */ #ifndef TVM_EXPR_OPERATOR_H_ #define TVM_EXPR_OPERATOR_H_ #include <algorithm> #include <type_traits> #include "expr.h" #include "ir.h" namespace tvm { /*! * \brief Make a const value with certain data type. * \param t The target type. * \param value The input value * \return the result expression. * \tparam ValueType The constant value type */ template<typename ValueType, typename = typename std::enable_if<std::is_pod<ValueType>::value>::type> inline Expr make_const(Type t, ValueType value); /*! * \brief Make a const zero expr. * \param t The target type. * \return the result expression. */ inline Expr make_zero(Type t); /*! * \brief Make a constant true expression. * \param lanes The number of lanes in the bool * \return The result expression. */ inline Expr const_true(int lanes = 1) { return make_const(UInt(1, lanes), 1); } /*! * \brief Make a constant false expression. * \param lanes The number of lanes in the bool * \return The result expression. */ inline Expr const_false(int lanes = 1) { return make_const(UInt(1, lanes), 0); } /*! * \brief Get x as constant int expression. * \param x The expression * \return the address to the int expression, * return nullptr, if x is not IntImm. */ inline const int64_t* as_const_int(const Expr& x) { if (!x.defined()) return nullptr; if (const ir::IntImm* op = x.as<ir::IntImm>()) { return &(op->value); } else { return nullptr; } } /*! * \brief Get x as constant uint expression. * \param x The expression * \return the address to the int expression, * return nullptr, if x is not UIntImm. */ inline const uint64_t* as_const_uint(const Expr& x) { if (!x.defined()) return nullptr; if (const ir::UIntImm* op = x.as<ir::UIntImm>()) { return &(op->value); } else { return nullptr; } } /*! * \brief Check whether x is a constant integer expression. * \param x The input argument * \param value the value to be compared against. * \return whether x is constant expression. */ inline bool is_const_int(const Expr& x, int64_t value); /*! * \brief Check whether stmt is nop. * \param stmt The input statement * \return whether stmt is nop */ inline bool is_no_op(const Stmt& stmt); /*! * \brief Check whether x is a constant integer 1 * \param x The input argument. * \note This only return true for integer types. * \return whether x is constant 1 */ inline bool is_one(const Expr& x) { return is_const_int(x, 1); } /*! * \brief Check whether x is a constant integer 0 * \param x The input argument * \return whether x is constant 0 * \note This only return true for integer types. */ inline bool is_zero(const Expr& x) { return is_const_int(x, 0); } /*! * \brief Check whether x is a constant. * \note This only return true for integer types. * \return whether x is constant */ inline bool is_const(const Expr& x); /*! * \brief Check whether x is a constant power of two * If x is power of two, write the power to the shift. * * \param x The input expression. * \param shift The output shift if x is power of two. * \return whether x is constant power of two */ TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift); /*! * \brief cast value to type. * * \param t the target type. * \param value The value * \return The result expression. * \note This function may return value if the type is the same. */ TVM_DLL Expr cast(const Type& t, Expr value); /*! * \brief perform reinterpret cast value to type. * * \param t the target type. * \param value The value * \return The result expression. * \note This function may return value if the type is the same. */ TVM_DLL Expr reinterpret(const Type& t, Expr value); /*! * \brief add operator * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator+(Expr a, Expr b); /*! * \brief subtraction operator * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator-(Expr a, Expr b); /*! * \brief negation. * * \param a input. * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator-(Expr a); /*! * \brief multiplication operator * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator*(Expr a, Expr b); /*! * \brief division operator * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator/(Expr a, Expr b); /*! * \brief mod operator * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator%(Expr a, Expr b); /*! * \brief left shift operator * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator<<(Expr a, Expr b); /*! * \brief right shift operator * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator>>(Expr a, Expr b); /*! * \brief greater * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator>(Expr a, Expr b); /*! * \brief greater_equal * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator>=(Expr a, Expr b); /*! * \brief less * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator<(Expr a, Expr b); /*! * \brief less_equal * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator<=(Expr a, Expr b); /*! * \brief equal * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator==(Expr a, Expr b); /*! * \brief not_equal * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator!=(Expr a, Expr b); /*! * \brief and * * \param a left operand * \param b right operand * \return The result expression. * \note This operator does eager constant folding. */ TVM_DLL Expr operator&&(Expr a, Expr b); /*! * \brief or * * \param a left operand * \param b right operand * \return The result expression. * \note This operator does eager constant folding. */ TVM_DLL Expr operator||(Expr a, Expr b); /*! * \brief not * * \param a left operand * \return The result expression. * \note This operator does eager constant folding. */ TVM_DLL Expr operator!(Expr a); /*! * \brief take maximum of two values * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr max(Expr a, Expr b); /*! * \brief take minimum of two values * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr min(Expr a, Expr b); /*! * \brief take bitwise and of two values * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator&(Expr a, Expr b); /*! * \brief take bitwise or of two values * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator|(Expr a, Expr b); /*! * \brief take bitwise xor of two values * * \param a left operand * \param b right operand * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator^(Expr a, Expr b); /*! * \brief take bitwise negation of two values * * \param a the input expression. * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr operator~(Expr a); /*! * \brief Conditional expression. * * \param cond The condition * \param true_value The value when results are true. * \param false_value The value when results are false. * \return The result expression. * \note this function does eager constant folding for * index types(int32, int64) when possible. */ TVM_DLL Expr if_then_else(Expr cond, Expr true_value, Expr false_value); /*! * \brief Mark condition as likely. * \param cond The condition * \return The marked expression. */ TVM_DLL Expr likely(Expr cond); /*! * \brief Calculate power(x, y) * \param x The left operand. * \param y The right operand. */ TVM_DLL Expr pow(Expr x, Expr y); /*! * \brief Calculate absolute value of x. * \param x The input data * * \return The aboslute value of input data x */ TVM_DLL Expr abs(Expr x); /*! * \brief sum of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. */ TVM_DLL Expr sum(Expr source, Array<IterVar> axis); /*! * \brief max of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. */ TVM_DLL Expr max(Expr source, Array<IterVar> axis); /*! * \brief max of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. */ TVM_DLL Expr min(Expr source, Array<IterVar> axis); /*! * \brief product of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. */ TVM_DLL Expr prod(Expr source, Array<IterVar> axis); /*! * \brief Calculate floor(x) * \param x The input expression. * \return The result expression. */ TVM_DLL Expr floor(Expr x); /*! * \brief Calculate ceil(x) * \param x The input expression. * \return The result expression. */ TVM_DLL Expr ceil(Expr x); /*! * \brief Calculate round(x) * \param x The input expression. * \return The result expression. */ TVM_DLL Expr round(Expr x); /*! * \brief Calculate trunc(x) * \param x The input expression. * \return The result expression. */ TVM_DLL Expr trunc(Expr x); // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline Expr OpName(Expr x) { \ return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \ } \ TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(tanh); TVM_DECLARE_INTRIN_UNARY(sigmoid); TVM_DECLARE_INTRIN_UNARY(sqrt); TVM_DECLARE_INTRIN_UNARY(rsqrt); TVM_DECLARE_INTRIN_UNARY(log); TVM_DECLARE_INTRIN_UNARY(popcount); // Implementation details after this inline bool is_const(const Expr& x) { if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) { return true; } else if (const auto* op = x.as<ir::Broadcast>()) { const Expr& val = op->value; if (val.as<ir::IntImm>() || val.as<ir::UIntImm>()) { return true; } } return false; } inline bool is_positive_const(const Expr& a) { if (const ir::IntImm* op = a.as<ir::IntImm>()) { return op->value > 0; } else if (const ir::UIntImm* op = a.as<ir::UIntImm>()) { return op->value > 0; } else { return false; } } inline bool is_negative_const(const Expr& a) { if (const ir::IntImm* op = a.as<ir::IntImm>()) { return op->value < 0; } else { return false; } } inline bool is_const_int(const Expr& x, int64_t value) { if (const auto* op = x.as<ir::IntImm>()) { return op->value == value; } else if (const auto* op = x.as<ir::UIntImm>()) { return op->value == static_cast<uint64_t>(value); } else if (const auto* op = x.as<ir::Broadcast>()) { const Expr& val = op->value; if (const auto* opv = val.as<ir::IntImm>()) { return opv->value == value; } else if (const auto* opv = val.as<ir::UIntImm>()) { return opv->value == static_cast<uint64_t>(value); } } return false; } inline bool is_no_op(const Stmt& stmt) { if (!stmt.defined()) return true; if (const auto* op = stmt.as<ir::Evaluate>()) { return is_const(op->value); } return false; } template<typename ValueType> inline Expr MakeConstScalar(Type t, ValueType value) { if (t.is_int()) return ir::IntImm::make(t, static_cast<int64_t>(value)); if (t.is_uint()) return ir::UIntImm::make(t, static_cast<uint64_t>(value)); if (t.is_float()) return ir::FloatImm::make(t, static_cast<double>(value)); LOG(FATAL) << "cannot make const for type " << t; return Expr(); } template<typename ValueType, typename> inline Expr make_const(Type t, ValueType value) { if (t.lanes() == 1) { return MakeConstScalar(t, value); } else { return ir::Broadcast::make( MakeConstScalar(t.element_of(), value), t.lanes()); } } inline Expr make_zero(Type t) { if (t.is_handle()) { return reinterpret(t, make_const(UInt(64), 0)); } return make_const(t, 0); } // additional const expression overloading #define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ inline Expr Name(Expr& a, Expr b) { \ a = OpFunc(a, b); \ return a; \ } #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ inline Expr Name(const Expr& a, float b) { \ return Name(a, Expr(b)); \ } \ inline Expr Name(float a, const Expr& b) { \ return Name(Expr(a), b); \ } \ inline Expr Name(int a, const Expr& b) { \ return Name(make_const(b.type(), a), b); \ } \ inline Expr Name(const Expr& a, int b) { \ return Name(a, make_const(a.type(), b)); \ } #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ inline Expr Name(const Expr& a, bool b) { \ return Name(a, Expr(b)); \ } \ inline Expr Name(bool a, const Expr& b) { \ return Name(Expr(a), b); \ } #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ inline Expr Name(const Expr& a, int b) { \ return Name(a, make_const(a.type(), b)); \ } \ inline Expr Name(int a, const Expr& b) { \ return Name(make_const(b.type(), a), b); \ } TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+); TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-); TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*); TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator/=, operator/); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*) TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*) TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=); // integer related ops TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*) TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*) TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^); // logical ops TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&); TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||); } // namespace tvm #endif // TVM_EXPR_OPERATOR_H_