/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/relay/feature.h * \brief Detect features used in Expr/Module. */ #ifndef TVM_RELAY_FEATURE_H_ #define TVM_RELAY_FEATURE_H_ #include <tvm/node/container.h> #include <tvm/relay/expr.h> #include <tvm/ir/module.h> #include <bitset> namespace tvm { namespace relay { /*! \brief Different kinds of relay feature a program might use. */ enum Feature : int { fVar = 0, fGlobalVar = 1, fConstant = 2, fTuple = 3, fTupleGetItem = 4, fFunction = 5, fOp = 6, fCall = 7, fLet = 8, fIf = 9, fRefCreate = 10, fRefRead = 11, fRefWrite = 12, fConstructor = 13, fMatch = 14, /*! \brief Whether any non-atom fragment of the program is shared, making the program a graph. */ fGraph = 15, /*! \brief Whether there is local fixpoint in the program. */ fLetRec = 16 }; constexpr size_t feature_count = 17; /*! * \brief A finite set of Feature. */ class FeatureSet { public: FeatureSet(const FeatureSet&) = default; /*! \brief A singleton set containing a single Feature. */ explicit FeatureSet(Feature ft) { bs_.set(static_cast<size_t>(ft)); } explicit FeatureSet(const tvm::Array<tvm::Integer>& ft) { for (Integer i : ft) { (*this) += Feature(static_cast<int>(i)); } } explicit operator Array<Integer>() const { Array<Integer> ret; for (size_t i = 0; i < feature_count; ++i) { if (bs_[i]) { ret.push_back(Integer(i)); } } return ret; } /*! \brief A set that contain all the Feature. */ static FeatureSet All() { FeatureSet fs; fs.bs_.flip(); return fs; } /*! \brief The empty set. Contain no Feature. */ static FeatureSet No() { FeatureSet fs; return fs; } template<typename T> FeatureSet& operator+=(const T& rhs) { bs_ |= FeatureSet(rhs).bs_; return *this; } /*! \brief Set union. */ template<typename T> FeatureSet operator+(const T& rhs) const { FeatureSet fs(*this); fs += rhs; return fs; } template<typename T> FeatureSet& operator-=(const T& rhs) { bs_ &= ~(FeatureSet(rhs)).bs_; return *this; } /*! \brief Set difference. */ template<typename T> FeatureSet operator-(const T& rhs) const { FeatureSet fs(*this); fs -= rhs; return fs; } /*! * \brief Is this a subset of rhs? * * \param rhs another FeatureSet. * * \return true only if this is a subset of rhs. */ bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); } private: std::bitset<feature_count> bs_; FeatureSet() = default; explicit FeatureSet(const std::bitset<feature_count>& bs) : bs_(bs) { } }; /*! * \brief Calculate the feature of the program. * * \param expr The expression. * * \return The FeatureSet. */ FeatureSet DetectFeature(const RelayExpr& expr); /*! * \brief Calculate the feature of the program. * * \param mod The module. * * \return The FeatureSet. */ FeatureSet DetectFeature(const IRModule& mod); /*! * \brief Calculate the feature of the program. * * \param expr The expression. * \param mod The module. * * \return The FeatureSet. */ inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) { return DetectFeature(expr) + DetectFeature(mod); } } // namespace relay } // namespace tvm #endif // TVM_RELAY_FEATURE_H_