/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/tir/data_layout.h
 * \brief Layout expression to describe the data organization of a tensor.
 *  And BijectiveLayout to map two data layouts to each other.
 */
#ifndef TVM_TIR_DATA_LAYOUT_H_
#define TVM_TIR_DATA_LAYOUT_H_
27

28

29 30
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
31 32 33 34 35 36 37 38 39

#include <string>
#include <sstream>
#include <vector>
#include <utility>
#include <algorithm>


namespace tvm {
40
namespace tir {
41 42 43 44 45 46

class LayoutAxis {
 public:
  static const LayoutAxis& Get(const char name);

  // Get the singleton LayoutAxis using itvar->var->name_hint
47
  static const LayoutAxis& Get(const tir::IterVar& itvar);
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95

  // Get the singleton LayoutAxis using name[0] (size of name must be 1).
  static const LayoutAxis& make(const std::string& name);

  inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; }
  inline std::string name() const { return std::string(1, name_); }

  // if current axis is primal, switch the axis to its subordinate one,
  // else switch to the primal.
  inline const LayoutAxis& ToDual() const {
    if (name_ >= 'A' && name_ <= 'Z') {
      return LayoutAxis::Get(name_ - 'A' + 'a');
    } else {
      return LayoutAxis::Get(name_ - 'a' + 'A');
    }
  }

  // return the primal axis. If it is already primal, return itself.
  const LayoutAxis& ToPrimal() const {
    return IsPrimal() ? *this : ToDual();
  }

  // return the subordinate axis. If it is already subordinate, return itself.
  const LayoutAxis& ToSubordinate() const {
    return IsPrimal() ? ToDual() : *this;
  }

  inline bool operator==(const LayoutAxis& rhs) const {
    return name_ == rhs.name_;
  }

  friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) {
    os << l.name();
    return os;
  }

 private:
  static const LayoutAxis UPPER_CASE[];
  static const LayoutAxis LOWER_CASE[];
  LayoutAxis(const LayoutAxis&);
  LayoutAxis& operator=(const LayoutAxis&);
  explicit LayoutAxis(const char name) : name_(name) {}

  const char name_;
};

class Layout;
// Internal node container Buffer
96
class LayoutNode : public Object {
97
 public:
98
  /*! \brief string representation of layout, "" for scalar. */
99 100 101 102 103
  std::string name;
  /*! \brief specify each axis of the layout,
   *   in which the variable name is the name of the axis.
   *   The IterVar's extent indicates the size of the axis,
   *   it is a variable for a primal axis, but a constant for a subordinate axis.
104
   *   Empty for scalar's layout.
105
   */
106
  Array<tir::IterVar> axes;
107

108
  void VisitAttrs(AttrVisitor* v) {
109 110 111 112 113 114 115
    v->Visit("name", &name);
    v->Visit("axes", &axes);
  }

  TVM_DLL static Layout make(const std::string& layout);

  static constexpr const char* _type_key = "Layout";
116
  TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object);
117 118 119 120 121 122 123 124 125 126
};

/*!
 * \brief Layout is to describe how data is organized within an N-dimention tensor.
 *  It is composed of upper cases, lower cases and numbers,
 *  where upper case indicates a primal axis and
 *  the corresponding lower case with factor size indicates the subordinate axis.
 *  For example, NCHW16c can describe a 5-D tensor of
 *  [batch_size, channel, height, width, channel_block].
 *  Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
127
 *  Layout for scalar is defined, while both its name and axes have size 0.
128
 */
129
class Layout : public ObjectRef {
130
 public:
131
  explicit Layout(ObjectPtr<Object> n) : ObjectRef(n) {}
132 133 134 135

  /*! \brief default constructor */
  Layout() = default;

136
  explicit Layout(const Array<tir::IterVar>& axes);
137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155

  /*! \brief construct from a string */
  Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)

  /*!
   * \brief construct from a string.
   * \param name input in layout convention:
   *        upper case indicates a dimension and
   *        the corresponding lower case with factor size
   *        indicates the split dimension.
   *        return undefined layout if "__undef__" is passed.
   */
  Layout(const std::string& name); // NOLINT(*)

  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  const LayoutNode* operator->() const {
156
    return static_cast<const LayoutNode*>(get());
157 158 159 160 161 162 163
  }

  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  LayoutNode* operator->() {
164
    return static_cast<LayoutNode*>(get_mutable());
165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
  }

  /*!
   * \brief Return an undefined layout.
   * \return a (global) undefined layout.
   */
  static const Layout& Undef() {
    static Layout undef;
    return undef;
  }

  /*!
   * \brief Returns a sub-layout which is the portion of the object
   *        that starts at dimension \p pos and spans \p len dimensions
   *        (or until the end of the layout, whichever comes first).
   * \param pos The start position.
181
   * \param len The length of the sub-layout. if 0, return layout of scalar
182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
   * \return A newly constructed Layout object.
   */
  Layout SubLayout(size_t pos, size_t len) const;

  /*!
   * \brief Split \p axis by \p size and put the sub-axis to position \p target_pos.
   * \param axis The source axis to be split. It must be a primal-axis;
   * \param target_pos The target position of the newly split subordinate-axis.
   * \param factor size of the sub-dimension.
   * \return A newly constructed Layout object.
   */
  Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const;


  /*! \return number of dimensions */
  inline size_t ndim() const {
    if (!defined()) return 0;
    return operator->()->axes.size();
  }

  /*! \return number of super dimensions */
  inline size_t ndim_primal() const {
    if (!defined()) return 0;
    size_t ct = 0;
    for (auto x : operator->()->axes) {
      if (LayoutAxis::Get(x).IsPrimal()) {
        ct++;
      }
    }
    return ct;
  }

  /*!
215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236
   * \brief Returns a new layout where the dims have been expanded to match the primal dimensions.
   * \param dst_layout The dst layout to which current layout has to be expanded.
   * \return The expanded Layout.
   */
  inline Layout ExpandPrimal(const Layout& dst_layout) {
    Layout new_src_layout;
    // 1) Find the axis which are missing in the current layout. Make them the prefix.
    std::string new_src_layout_str = "";
    for (auto dst_axis : dst_layout->axes) {
      if (LayoutAxis::Get(dst_axis).IsPrimal()) {
        if (!this->Contains(LayoutAxis::Get(dst_axis))) {
          new_src_layout_str += dst_axis->var->name_hint;
        }
      }
    }
    // 2) Now, add the primal axis of the current layout.
    new_src_layout_str += this->name();
    new_src_layout = Layout(new_src_layout_str);
    return new_src_layout;
  }

  /*!
237 238 239 240 241 242 243 244 245 246
   * \brief return the index of the input axis.
   *        If it is not found in the layout or the layout is undefined,
   *        return -1.
   * \param axis the input axis.
   * \return the index or -1 if not found.
   */
  inline int32_t IndexOf(const LayoutAxis& axis) const {
    if (!this->defined()) return -1;
    const auto axes = operator->()->axes;
    for (size_t i = 0; i < axes.size(); ++i) {
247
      if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267
    }
    return -1;
  }

  /*!
   * \brief Get the factor size of the subordinate axis.
   * \param axis the input primal-axis or subordinate-axis.
   * \return the size of the subordinate-axis of \p axis (if \p axis is a primal-axis),
   *         or the size of \p axis itself (if \p axis is a subordinate-axis).
   *         Return -1 if \p axis is not in the layout the layout is undefined.
   */
  int32_t FactorOf(const LayoutAxis& axis) const;

  /*!
   * \brief Whether the layout contains an axis.
   * \param axis axis to be checked.
   * \return Whether the layout contains the axis.
   */
  bool Contains(const LayoutAxis& axis) const {
    if (!defined()) return false;
268
    for (const tir::IterVar var : operator->()->axes) {
269
      if (var->var->name_hint == axis.name()) {
270 271 272 273 274 275 276 277 278 279
        return true;
      }
    }
    return false;
  }

  const LayoutAxis& operator[](int32_t i) const {
    CHECK(defined()) << "Try to access axis from an undefined layout.";
    int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
    CHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
280
    const tir::IterVar axis = operator->()->axes[index];
281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314
    return LayoutAxis::Get(axis);
  }

  /*! \return the string description of the layout */
  inline std::string name() const {
    if (!defined()) return "__undef__";
    return operator->()->name;
  }

  /*!
   * \brief Whether the two layouts are equal.
   * \param rhs Another layout.
   * \return whether the two layouts are equal.
   */
  inline bool Equals(const Layout &rhs) const {
    return name() == rhs.name();
  }

  /*!
   * \brief allow output string of layout to ostream
   * \param os the output stream
   * \param l the layout
   * \return the ostream
   */
  friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
    os << l.name();
    return os;
  }

  using ContainerType = LayoutNode;
};

class BijectiveLayout;
// Internal node container BijectiveLayout
315
class BijectiveLayoutNode : public Object {
316 317 318 319
 public:
  /*! \brief Describes how source axes can be mapped to the destination axes,
   *   e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n
   */
320
  Array<PrimExpr> forward_rule;
321
  /*! \brief Describes how destination axes can be mapped to the source axes */
322
  Array<PrimExpr> backward_rule;
323 324 325 326 327 328

  /*! \brief The source layout */
  Layout src_layout;
  /*! \brief The destination layout */
  Layout dst_layout;

329
  void VisitAttrs(AttrVisitor* v) {
330 331 332 333 334 335 336
    v->Visit("src_layout", &src_layout);
    v->Visit("dst_layout", &dst_layout);
    v->Visit("forward_rule", &forward_rule);
    v->Visit("backward_rule", &backward_rule);
  }

  static constexpr const char* _type_key = "BijectiveLayout";
337
  TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object);
338 339 340 341 342 343 344
};

/*! \brief Bijective function mapping for data layout transformation.
 *   Given two Layout, BijectiveLayout build and store the mapping rules,
 *   provides API to transform N-dimention tensor from the source indices (i0, i1, …, im)
 *   to the destination indices (j0, j1, … jm).
 */
345
class BijectiveLayout : public ObjectRef {
346 347
 public:
  BijectiveLayout() = default;
348
  explicit BijectiveLayout(ObjectPtr<Object> n) : ObjectRef(n) {}
349 350 351 352 353 354
  /*!
   * \brief The constructor
   * \param src_layout The source layout
   * \param dst_layout The destination layout
   */
  TVM_DLL BijectiveLayout(Layout src_layout, Layout dst_layout);
355 356

  // Given the source shape, infer the destination shape.
357
  TVM_DLL Array<PrimExpr> ForwardShape(const Array<PrimExpr>& shape) const;
358
  // Given the destination shape, recover the source shape.
359
  TVM_DLL Array<PrimExpr> BackwardShape(const Array<PrimExpr>& dst_shape) const;
360
  // Given the destination indices, infer the destination indices.
361
  TVM_DLL Array<PrimExpr> ForwardIndex(const Array<PrimExpr>& index) const;
362
  // Given the destination indices, recover the source indices.
363
  TVM_DLL Array<PrimExpr> BackwardIndex(const Array<PrimExpr>& dst_index) const;
364 365 366 367 368 369 370 371 372 373 374 375

  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const BijectiveLayoutNode* operator->() const;

  /*! \brief specify container node */
  using ContainerType = BijectiveLayoutNode;
};

/*! \brief Downcast the held object to the BijectiveLayoutNode container. */
inline const BijectiveLayoutNode* BijectiveLayout::operator->() const {
  return static_cast<const BijectiveLayoutNode*>(get());
}
378
}  // namespace tir
379 380
}  // namespace tvm

381
#endif  // TVM_TIR_DATA_LAYOUT_H_