/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/relay/op_strategy.h * \brief The Relay operator Strategy and related data structure. */ #ifndef TVM_RELAY_OP_STRATEGY_H_ #define TVM_RELAY_OP_STRATEGY_H_ #include <tvm/te/tensor.h> #include <tvm/te/schedule.h> #include <tvm/relay/expr.h> #include <tvm/relay/op_attr_types.h> #include <tvm/target/target.h> #include <string> namespace tvm { namespace relay { /*! * \brief Operator implementation that includes compute and schedule function. */ class OpImplementationNode : public Object { public: /*! \brief Compute function */ FTVMCompute fcompute; /*! \brief Schedule function */ FTVMSchedule fschedule; /*! \brief Name of the implementation */ std::string name; /*! \brief Priority level */ int plevel; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); v->Visit("plevel", &plevel); } static constexpr const char* _type_key = "relay.OpImplementation"; TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object); }; /*! * \brief Operator implementation class. */ class OpImplementation : public ObjectRef { public: /*! * \brief Invoke the operator compute function. * \param attrs The attribute of the primitive * \param inputs The input tensors. * \param out_type The output type information. * \return The output compute description of the operator. */ TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs, const Array<te::Tensor>& inputs, const Type& out_type); /*! * \brief Build the computation schedule. * \param attrs The attribute of the node. * \param outs The output tensors. * \param target The build target. * \return The computation schedule. */ TVM_DLL te::Schedule Schedule(const Attrs& attrs, const Array<te::Tensor>& outs, const Target& target); TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode); }; /*! * \brief Specialized implementations for operators under certain conditions. */ class OpSpecializationNode : public Object { public: /*! \brief List of implementations. */ Array<OpImplementation> implementations; /*! \brief Condition to enable the specialization. * Could be undefined to represent generic case. */ te::SpecializedCondition condition; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("condition", &condition); v->Visit("implementations", &implementations); } static constexpr const char* _type_key = "relay.OpSpecialization"; TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode); }; /*! * \brief Operator specialization class. */ class OpSpecialization : public ObjectRef { public: /*! * \brief Add an implementation. * \param fcompute Compute function * \param fschedule Schedule function * \param name Name of the implementation * \param plevel Priority level of the implementation */ TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name, int plevel); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode); }; /*! * \brief Operator strategy to choose implementation. */ class OpStrategyNode : public Object { public: /*! \brief List of operator specializations. */ Array<OpSpecialization> specializations; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("specializations", &specializations); } static constexpr const char* _type_key = "relay.OpStrategy"; TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode); }; /*! * \brief Operator strategy class. */ class OpStrategy : public ObjectRef { public: /*! * \brief Add an implementation. * \param fcompute Compute function * \param fschedule Schedule function * \param name Name of the implementation * \param plevel Priority level of the implementation */ TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name, int plevel); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode); }; } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_STRATEGY_H_