/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/buffer.h
 * \brief Symbolic n-dimensional array, to represent a memory buffer.
 */
#ifndef TVM_BUFFER_H_
#define TVM_BUFFER_H_

#include <string>
#include <utility>

#include "base.h"
#include "expr.h"
#include "expr_operator.h"
#include "tvm/node/container.h"

namespace tvm {

// Forward declaration of the internal node container used by Buffer.
class BufferNode;

/*!
 * \brief buffer type
 * \note The underlying type is fixed to int; BufferNode::VisitAttrs hands a
 *  pointer to a field of this type to the attribute visitor.
 */
enum BufferType : int {
  /*! \brief The default buffer type. */
  kDefault = 1,
  /*!
   * \brief Maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape
   *  equals 1.
   */
  kAutoBroadcast = 2,
};

46 47 48
/*!
 * \brief Buffer is a symbolic n-darray structure.
 *  It is a composition of primitive symbolic types,
49
 *  used to specify the memory layout of the Tensor used in program input.
50 51 52 53
 */
class Buffer : public NodeRef {
 public:
  Buffer() {}
54
  explicit Buffer(NodePtr<Node> n) : NodeRef(n) {}
55
  /*!
56 57 58 59
   * \brief Return a new buffer that is equivalent with current one
   *  but always add stride field.
   * \return The strided version of the buffer.
   */
60
  TVM_DLL Buffer MakeStrideView() const;
61 62 63 64 65 66 67 68
  /*!
   * \brief Make a new symbolic buffer representing a slice of the buffer.
   * \param begins The beginning position of each dimension.
   * \param extents The extent of each dimension.
   * \note This function will make target buffer as compact as possible.
   *  If stride is not needed in the slice, it won't be presented
   * \return the result buffer.
   */
69
  TVM_DLL Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const;
70
  /*!
71 72 73
   * \brief Get access ptr to the entire buffer.
   * \param access_mask The access mask
   * \param ptr_type The type of the pointer.
74
   * \param content_lanes The number of lanes for the (data) type.
75
   * \param offset The offset of ptr.
76
   */
77
  TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
78
                          int content_lanes = 1, Expr offset = make_const(Int(32), 0)) const;
79
  /*!
80 81 82 83
   * \brief Create an Expr that does a vector load at begin index.
   * \param begin The beginning index
   * \param dtype The data type to be loaded.
   */
84
  TVM_DLL Expr vload(Array<Expr> begin, Type dtype) const;
85 86 87 88 89
  /*!
   * \brief Create a Stmt that does a vector store at begin index.
   * \param begin The beginning index
   * \param value The value to be stored.
   */
90
  TVM_DLL Stmt vstore(Array<Expr> begin, Expr value) const;
91
  /*!
92 93 94 95
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const BufferNode* operator->() const;
96 97 98

  /*! \brief specify container node */
  using ContainerType = BufferNode;
99 100 101 102 103
};

/*! \brief Node to represent a buffer */
class BufferNode : public Node {
 public:
104
  // Data fields.
105 106 107 108
  /*!
   * \brief The pointer to the head of the data
   * \sa data_alignment The alignment of data in bytes.
   */
109
  Var data;
110 111
  /*! \brief data type in the content of the tensor */
  Type dtype;
112 113 114 115 116 117 118
  /*! \brief The shape of the buffer */
  Array<Expr> shape;
  /*!
   * \brief The strides of each dimension
   *  This can be an empty array, indicating array is contiguous
   */
  Array<Expr> strides;
119 120
  /*! \brief The offset in terms of number of dtype elements (including lanes) */
  Expr elem_offset;
121 122 123 124 125
  // Meta data
  /*! \brief optional name of the buffer */
  std::string name;
  /*! \brief storage scope of the buffer, if other than global */
  std::string scope;
126 127 128 129 130 131 132
  /*! \brief Alignment requirement of data pointer in bytes. */
  int data_alignment;
  /*!
   * \brief Factor of elem_offset field,
   *  elem_offset is guaranteed to be multiple of offset_factor.
   */
  int offset_factor;
133 134
  /*! \brief buffer type */
  BufferType buffer_type;
135 136 137 138
  /*! \brief constructor */
  BufferNode() {}

  void VisitAttrs(AttrVisitor* v) final {
139
    v->Visit("data", &data);
140
    v->Visit("dtype", &dtype);
141 142
    v->Visit("shape", &shape);
    v->Visit("strides", &strides);
143
    v->Visit("elem_offset", &elem_offset);
144 145
    v->Visit("name", &name);
    v->Visit("scope", &scope);
146 147
    v->Visit("data_alignment", &data_alignment);
    v->Visit("offset_factor", &offset_factor);
148
    v->Visit("buffer_type", &buffer_type);
149 150
  }

151 152 153 154 155
  /*! \return preferred index type for this buffer node */
  Type DefaultIndexType() const {
    return shape.size() != 0 ? shape[0].type() : Int(32);
  }

156 157
  // User can specify data_alignment and offset_factor to be 0
  // A default value will be picked.
158 159 160 161
  TVM_DLL static Buffer make(Var ptr,
                             Type dtype,
                             Array<Expr> shape,
                             Array<Expr> strides,
162
                             Expr elem_offset,
163 164 165
                             std::string name,
                             std::string scope,
                             int data_alignment,
166 167
                             int offset_factor,
                             BufferType buffer_type);
168 169

  static constexpr const char* _type_key = "Buffer";
170
  TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
171 172 173 174 175 176
};

/*!
 * \brief Access the internal node container.
 * \return The node pointer held by this reference, statically cast to
 *  const BufferNode*.
 */
inline const BufferNode* Buffer::operator->() const {
  const auto* raw_node = node_.get();
  return static_cast<const BufferNode*>(raw_node);
}

177 178 179 180 181 182 183 184
/*!
 * \brief Construct a new buffer given shape, and dtype.
 * \param shape The shape of the buffer,
 * \param dtype The content data type.
 * \param name The name of the buffer
 * \return The created buffer.
 * \sa BufferNode::make for complete constructor.
 */
185 186 187
TVM_DLL Buffer decl_buffer(Array<Expr> shape,
                           Type dtype = Float(32),
                           std::string name = "buffer");
188 189
}  // namespace tvm
#endif  // TVM_BUFFER_H_