/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/data_layout.h * \brief Layout expression to describe the data organization of a tensor. * And BijectiveLayout to mapping two data layouts between each other. */ #ifndef TVM_DATA_LAYOUT_H_ #define TVM_DATA_LAYOUT_H_ #include <tvm/base.h> #include <tvm/expr.h> #include <string> #include <sstream> #include <vector> #include <utility> #include <algorithm> #include "expr_operator.h" namespace tvm { class LayoutAxis { public: static const LayoutAxis& Get(const char name); // Get the singleton LayoutAxis using itvar->var->name_hint static const LayoutAxis& Get(const IterVar& itvar); // Get the singleton LayoutAxis using name[0] (size of name must be 1). static const LayoutAxis& make(const std::string& name); inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; } inline std::string name() const { return std::string(1, name_); } // if current axis is primal, switch the axis to its subordinate one, // else switch to the primal. inline const LayoutAxis& ToDual() const { if (name_ >= 'A' && name_ <= 'Z') { return LayoutAxis::Get(name_ - 'A' + 'a'); } else { return LayoutAxis::Get(name_ - 'a' + 'A'); } } // return the primal axis. If it is already primal, return itself. const LayoutAxis& ToPrimal() const { return IsPrimal() ? *this : ToDual(); } // return the subordinate axis. If it is already subordinate, return itself. const LayoutAxis& ToSubordinate() const { return IsPrimal() ? ToDual() : *this; } inline bool operator==(const LayoutAxis& rhs) const { return name_ == rhs.name_; } friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) { os << l.name(); return os; } private: static const LayoutAxis UPPER_CASE[]; static const LayoutAxis LOWER_CASE[]; LayoutAxis(const LayoutAxis&); LayoutAxis& operator=(const LayoutAxis&); explicit LayoutAxis(const char name) : name_(name) {} const char name_; }; class Layout; // Internal node container Buffer class LayoutNode : public Node { public: /*! \brief string representation of layout, "" for scalar. */ std::string name; /*! \brief specify each axis of the layout, * in which the variable name is the name of the axis. * The IterVar's extent indicates the size of the axis, * it is a variable for a primal axis, but a constant for a subordinate axis. * Empty for scalar's layout. */ Array<IterVar> axes; void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); v->Visit("axes", &axes); } TVM_DLL static Layout make(const std::string& layout); static constexpr const char* _type_key = "Layout"; TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node); }; /*! * \brief Layout is to describe how data is organized within an N-dimention tensor. * It is composed of upper cases, lower cases and numbers, * where upper case indicates a primal axis and * the corresponding lower case with factor size indicates the subordinate axis. * For example, NCHW16c can describe a 5-D tensor of * [batch_size, channel, height, width, channel_block]. * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). * Layout for scalar is defined, while both its name and axes have size 0. */ class Layout : public NodeRef { public: explicit Layout(NodePtr<Node> n) : NodeRef(n) {} /*! \brief default constructor */ Layout() = default; explicit Layout(const Array<IterVar>& axes); /*! \brief construct from a string */ Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) /*! * \brief construct from a string. * \param name input in layout convention: * upper case indicates a dimension and * the corresponding lower case with factor size * indicates the split dimension. * return undefined layout if "__undef__" is passed. */ Layout(const std::string& name); // NOLINT(*) /*! * \brief access the internal node container * \return the pointer to the internal node container */ const LayoutNode* operator->() const { return static_cast<const LayoutNode*>(node_.get()); } /*! * \brief access the internal node container * \return the pointer to the internal node container */ LayoutNode* operator->() { return static_cast<LayoutNode*>(node_.get()); } /*! * \brief Return an undefined layout. * \return a (global) undefined layout. */ static const Layout& Undef() { static Layout undef; return undef; } /*! * \brief Returns a sub-layout which is the portion of the object * that starts at dimension \p pos and spans \p len dimensions * (or until the end of the layout, whichever comes first). * \param pos The start position. * \param len The length of the sub-layout. if 0, return layout of scalar * \return A newly constructed Layout object. */ Layout SubLayout(size_t pos, size_t len) const; /*! * \brief Split \p axis by \p size and put the sub-axis to position \p target_pos. * \param axis The source axis to be split. It must be a primal-axis; * \param target_pos The target position of the newly split subordinate-axis. * \param factor size of the sub-dimension. * \return A newly constructed Layout object. */ Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const; /*! \return number of dimensions */ inline size_t ndim() const { if (!defined()) return 0; return operator->()->axes.size(); } /*! \return number of super dimensions */ inline size_t ndim_primal() const { if (!defined()) return 0; size_t ct = 0; for (auto x : operator->()->axes) { if (LayoutAxis::Get(x).IsPrimal()) { ct++; } } return ct; } /*! * \brief return the index of the input axis. * If it is not found in the layout or the layout is undefined, * return -1. * \param axis the input axis. * \return the index or -1 if not found. */ inline int32_t IndexOf(const LayoutAxis& axis) const { if (!this->defined()) return -1; const auto axes = operator->()->axes; for (size_t i = 0; i < axes.size(); ++i) { if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i); } return -1; } /*! * \brief Get the factor size of the subordinate axis. * \param axis the input primal-axis or subordinate-axis. * \return the size of the subordinate-axis of \p axis (if \p axis is a primal-axis), * or the size of \p axis itself (if \p axis is a subordinate-axis). * Return -1 if \p axis is not in the layout the layout is undefined. */ int32_t FactorOf(const LayoutAxis& axis) const; /*! * \brief Whether the layout contains an axis. * \param axis axis to be checked. * \return Whether the layout contains the axis. */ bool Contains(const LayoutAxis& axis) const { if (!defined()) return false; for (const IterVar var : operator->()->axes) { if (var->var->name_hint == axis.name()) { return true; } } return false; } const LayoutAxis& operator[](int32_t i) const { CHECK(defined()) << "Try to access axis from an undefined layout."; int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i; CHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i; const IterVar axis = operator->()->axes[index]; return LayoutAxis::Get(axis); } /*! \return the string description of the layout */ inline std::string name() const { if (!defined()) return "__undef__"; return operator->()->name; } /*! * \brief Whether the two layouts are equal. * \param rhs Another layout. * \return whether the two layouts are equal. */ inline bool Equals(const Layout &rhs) const { return name() == rhs.name(); } /*! * \brief allow output string of layout to ostream * \param os the output stream * \param l the layout * \return the ostream */ friend std::ostream& operator<<(std::ostream& os, const Layout& l) { os << l.name(); return os; } using ContainerType = LayoutNode; }; class BijectiveLayout; // Internal node container BijectiveLayout class BijectiveLayoutNode : public Node { public: /*! \brief Describes how source axes can be mapped to the destination axes, * e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n */ Array<Expr> forward_rule; /*! \brief Describes how destination axes can be mapped to the source axes */ Array<Expr> backward_rule; /*! \brief The source layout */ Layout src_layout; /*! \brief The destination layout */ Layout dst_layout; void VisitAttrs(AttrVisitor* v) final { v->Visit("src_layout", &src_layout); v->Visit("dst_layout", &dst_layout); v->Visit("forward_rule", &forward_rule); v->Visit("backward_rule", &backward_rule); } static constexpr const char* _type_key = "BijectiveLayout"; TVM_DECLARE_NODE_TYPE_INFO(BijectiveLayoutNode, Node); TVM_DLL static BijectiveLayout make(const Layout& src_layout, const Layout& dst_layout); }; /*! \brief Bijective function mapping for data layout transformation. * Given two Layout, BijectiveLayout build and store the mapping rules, * provides API to transform N-dimention tensor from the source indices (i0, i1, …, im) * to the destination indices (j0, j1, … jm). */ class BijectiveLayout : public NodeRef { public: BijectiveLayout() = default; explicit BijectiveLayout(NodePtr<Node> n) : NodeRef(n) {} // Given the source shape, infer the destination shape. TVM_DLL Array<Expr> ForwardShape(const Array<Expr>& shape) const; // Given the destination shape, recover the source shape. TVM_DLL Array<Expr> BackwardShape(const Array<Expr>& dst_shape) const; // Given the destination indices, infer the destination indices. TVM_DLL Array<Expr> ForwardIndex(const Array<Expr>& index) const; // Given the destination indices, recover the source indices. TVM_DLL Array<Expr> BackwardIndex(const Array<Expr>& dst_index) const; /*! * \brief access the internal node container * \return the pointer to the internal node container */ inline const BijectiveLayoutNode* operator->() const; /*! \brief specify container node */ using ContainerType = BijectiveLayoutNode; }; inline const BijectiveLayoutNode* BijectiveLayout::operator->() const { return static_cast<const BijectiveLayoutNode*>(node_.get()); } } // namespace tvm #endif // TVM_DATA_LAYOUT_H_