/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file op_common.h * \brief A set of utilities and common functionality * for relay ops. */ #ifndef TVM_RELAY_OP_OP_COMMON_H_ #define TVM_RELAY_OP_OP_COMMON_H_ #include <tvm/relay/expr.h> #include <tvm/relay/op.h> #include <tvm/relay/op_attr_types.h> #include <vector> #include <string> #include <unordered_map> #include "type_relations.h" #include "../transforms/infer_layout_util.h" namespace tvm { namespace relay { /*! Quick helper macro * - Expose a positional make function to construct the node. * - Register op to the registry. * * We make the decision to always only expose positional argument. * We will do rewrapping in the frontend to support language * sugars such as keyword arguments and default value. * \param OpName the name of registry. */ #define RELAY_REGISTER_UNARY_OP(OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr data) { \ static const Op& op = Op::Get(OpName); \ return Call(op, {data}, Attrs(), {}); \ }); \ RELAY_REGISTER_OP(OpName) \ .set_num_inputs(1) \ .add_argument("data", "Tensor", "The input tensor.") \ .add_type_rel("Identity", IdentityRel) \ .set_attr<TOpPattern>("TOpPattern", kElemWise) \ .set_attr<TOpIsStateful>("TOpIsStateful", false) \ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \ ElemwiseArbitraryLayout) \ /*! Quick helper macro * - Expose a positional make function to construct the node. * - Register op to the registry. * * We make the decision to always only expose positional argument. * We will do rewrapping in the frontend to support language * sugars such as keyword arguments and default value. * * \param OpName the name of registry. */ #define RELAY_REGISTER_BINARY_OP(OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr lhs, Expr rhs) { \ static const Op& op = Op::Get(OpName); \ return Call(op, {lhs, rhs}, Attrs(), {}); \ }); \ RELAY_REGISTER_OP(OpName) \ .set_num_inputs(2) \ .add_argument("lhs", "Tensor", "The left hand side tensor.") \ .add_argument("rhs", "Tensor", "The right hand side tensor.") \ .add_type_rel("Broadcast", BroadcastRel) \ .set_attr<TOpPattern>("TOpPattern", kBroadcast) \ .set_attr<TOpIsStateful>("TOpIsStateful", false) \ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \ BinaryBroadcastLayout) // Comparisons #define RELAY_REGISTER_CMP_OP(OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr lhs, Expr rhs) { \ static const Op& op = Op::Get(OpName); \ return Call(op, {lhs, rhs}, Attrs(), {}); \ }); \ RELAY_REGISTER_OP(OpName) \ .set_num_inputs(2) \ .add_argument("lhs", "Tensor", "The left hand side tensor.") \ .add_argument("rhs", "Tensor", "The right hand side tensor.") \ .add_type_rel("BroadcastComp", BroadcastCompRel) \ .set_attr<TOpPattern>("TOpPattern", kBroadcast) \ .set_attr<TOpIsStateful>("TOpIsStateful", false) \ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \ BinaryBroadcastLayout) /*! \brief A helper class for matching and rewriting operators. */ template<typename R> class OpMatch { public: using MatchFunc = std::function<R(const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_args)>; /*! \brief Match an operator with the given name. * \param op_name The name of the operator to match. * \param func The function to execute when it matches. * \return A self-reference for builder style API. */ inline OpMatch& Match(const std::string& op_name, MatchFunc func) { auto op = Op::Get(op_name); match_map_.insert({op, func}); return *this; } /*! \brief Rewrite a call operation based on the operator and the registered * match functions. * \param call The call to rewrite. * \return The result of rewriting. */ inline R operator()(const Call& call) { auto it = match_map_.find(Downcast<Op>(call->op)); if (it != match_map_.end()) { return it->second(call->args, call->attrs, call->type_args); } else { if (default_ != nullptr) { return default_(call->args, call->attrs, call->type_args); } else { LOG(FATAL) << "unexpected operation " << call->op; } } } private: /*! \brief The match function map. */ std::unordered_map<Op, MatchFunc, ObjectHash, ObjectEqual> match_map_; /*! \brief An optional default case. */ MatchFunc default_; }; /*! \brief A utility function to get padding width from a 1 or 2 ints tuple. */ inline void GetPaddingWidth(const Array<IndexExpr>& padding, IndexExpr* pad_w) { if (padding.size() == 1) { *pad_w = padding[0] * 2; } else if (padding.size() == 2) { *pad_w = padding[0] + padding[1]; } else { CHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found " << padding.size(); } } /*! \brief A utility function to get padding height and width from a 1, 2, 4 ints tuple. */ inline void GetPaddingHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_h, IndexExpr* pad_w) { if (padding.size() == 1) { *pad_h = padding[0] * 2; *pad_w = padding[0] * 2; } else if (padding.size() == 2) { *pad_h = padding[0] * 2; *pad_w = padding[1] * 2; } else if (padding.size() == 4) { *pad_h = padding[0] + padding[2]; *pad_w = padding[1] + padding[3]; } else { CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got " << padding.size(); } } /*! \brief A utility function to get padding depth, height and width from a 1, 3, 6 ints tuple. */ inline void GetPaddingDepthHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_d, IndexExpr* pad_h, IndexExpr* pad_w) { if (padding.size() == 1) { *pad_d = padding[0] * 2; *pad_h = padding[0] * 2; *pad_w = padding[0] * 2; } else if (padding.size() == 3) { *pad_d = padding[0] * 2; *pad_h = padding[1] * 2; *pad_w = padding[2] * 2; } else if (padding.size() == 6) { *pad_d = padding[0] + padding[3]; *pad_h = padding[1] + padding[4]; *pad_w = padding[2] + padding[5]; } else { CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got " << padding.size(); } } } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_OP_COMMON_H_