buffer.h 4.65 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/*!
 *  Copyright (c) 2016 by Contributors
 * \file buffer.h
 * \brief Symbolic n-dimensional array, to represent a memory buffer.
 */
#ifndef TVM_BUFFER_H_
#define TVM_BUFFER_H_

#include <tvm/container.h>
#include <memory>
#include <string>
#include <utility>

#include "./base.h"
#include "./expr.h"

namespace tvm {

// Internal node container Buffer
class BufferNode;
/*!
 * \brief memory access kind
 *  Each kind occupies its own bit, so kinds can be OR-ed together into an
 *  integer mask (presumably what Buffer::access_ptr's `int access_mask`
 *  expects — TODO confirm against the implementation).
 */
enum class AccessMask : int {
  /*! \brief read access */
  kRead = 1,
  /*! \brief write access */
  kWrite = 2
};

26 27 28
/*!
 * \brief Buffer is a symbolic n-darray structure.
 *  It is a composition of primitive symbolic types,
29
 *  used to specify the memory layout of the Tensor used in program input.
30 31 32 33 34 35
 */
class Buffer : public NodeRef {
 public:
  Buffer() {}
  explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {}
  /*!
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
   * \brief Return a new buffer that is equivalent with current one
   *  but always add stride field.
   * \return The strided version of the buffer.
   */
  Buffer MakeStrideView() const;
  /*!
   * \brief Make a new symbolic buffer representing a slice of the buffer.
   * \param begins The beginning position of each dimension.
   * \param extents The extent of each dimension.
   * \note This function will make target buffer as compact as possible.
   *  If stride is not needed in the slice, it won't be presented
   * \return the result buffer.
   */
  Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const;
  /*!
51 52 53 54 55 56
   * \brief Get access ptr to the entire buffer.
   * \param access_mask The access mask
   * \param ptr_type The type of the pointer.
   */
  Expr access_ptr(int access_mask, Type ptr_type = Handle()) const;
  /*!
57 58 59 60 61 62 63 64 65 66 67 68
   * \brief Create an Expr that does a vector load at begin index.
   * \param begin The beginning index
   * \param dtype The data type to be loaded.
   */
  Expr vload(Array<Expr> begin, Type dtype) const;
  /*!
   * \brief Create a Stmt that does a vector store at begin index.
   * \param begin The beginning index
   * \param value The value to be stored.
   */
  Stmt vstore(Array<Expr> begin, Expr value) const;
  /*!
69 70 71 72
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const BufferNode* operator->() const;
73 74 75

  /*! \brief specify container node */
  using ContainerType = BufferNode;
76 77 78 79 80
};

/*! \brief Node to represent a buffer */
class BufferNode : public Node {
 public:
81
  // Data fields.
82 83 84 85
  /*!
   * \brief The pointer to the head of the data
   * \sa data_alignment The alignment of data in bytes.
   */
86
  Var data;
87 88
  /*! \brief data type in the content of the tensor */
  Type dtype;
89 90 91 92 93 94 95
  /*! \brief The shape of the buffer */
  Array<Expr> shape;
  /*!
   * \brief The strides of each dimension
   *  This can be an empty array, indicating array is contiguous
   */
  Array<Expr> strides;
96 97
  /*! \brief The offset in terms of number of dtype elements (including lanes) */
  Expr elem_offset;
98 99 100 101 102
  // Meta data
  /*! \brief optional name of the buffer */
  std::string name;
  /*! \brief storage scope of the buffer, if other than global */
  std::string scope;
103 104 105 106 107 108 109
  /*! \brief Alignment requirement of data pointer in bytes. */
  int data_alignment;
  /*!
   * \brief Factor of elem_offset field,
   *  elem_offset is guaranteed to be multiple of offset_factor.
   */
  int offset_factor;
110 111 112 113
  /*! \brief constructor */
  BufferNode() {}

  void VisitAttrs(AttrVisitor* v) final {
114
    v->Visit("data", &data);
115
    v->Visit("dtype", &dtype);
116 117
    v->Visit("shape", &shape);
    v->Visit("strides", &strides);
118
    v->Visit("elem_offset", &elem_offset);
119 120
    v->Visit("name", &name);
    v->Visit("scope", &scope);
121 122
    v->Visit("data_alignment", &data_alignment);
    v->Visit("offset_factor", &offset_factor);
123 124
  }

125 126
  // User can specify data_alignment and offset_factor to be 0
  // A default value will be picked.
127 128
  static Buffer make(Var ptr,
                     Type dtype,
129 130
                     Array<Expr> shape,
                     Array<Expr> strides,
131
                     Expr byte_offset,
132 133
                     std::string name,
                     std::string scope,
134 135
                     int data_alignment,
                     int offset_factor);
136 137

  static constexpr const char* _type_key = "Buffer";
138
  TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
139 140 141 142 143 144
};

inline const BufferNode* Buffer::operator->() const {
  // Downcast the generic node held by NodeRef to the declared ContainerType.
  // No runtime type check is performed; the held node is assumed to be a
  // BufferNode.
  const Node* raw = node_.get();
  return static_cast<const BufferNode*>(raw);
}

145 146 147 148 149 150 151 152 153 154 155
/*!
 * \brief Construct a new buffer given shape, and dtype.
 * \param shape The shape of the buffer,
 * \param dtype The content data type.
 * \param name The name of the buffer
 * \return The created buffer.
 * \sa BufferNode::make for complete constructor.
 */
Buffer decl_buffer(Array<Expr> shape,
                   Type dtype = Float(32),
                   std::string name = "buffer");
156 157
}  // namespace tvm
#endif  // TVM_BUFFER_H_