Commit 5a7a056c by Tianqi Chen Committed by GitHub

[LANG/BUFFER] Change buffer arguments to match DLPack order, add scope (#203)

parent 7b821851
Subproject commit 41fe60a76fe6e5669540acf1ef3595bc38025157 Subproject commit e42653d7c3a604eb9f6ee1b5f989ddadd1cea69c
...@@ -26,12 +26,6 @@ class Buffer : public NodeRef { ...@@ -26,12 +26,6 @@ class Buffer : public NodeRef {
Buffer() {} Buffer() {}
explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief construct a new buffer based on shape and strides.
*/
explicit Buffer(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "buffer");
/*!
* \brief Generate a load expression loading the index location of buffer. * \brief Generate a load expression loading the index location of buffer.
* \param index The index to the buffer. * \param index The index to the buffer.
* \return The load expression. * \return The load expression.
...@@ -57,10 +51,11 @@ class Buffer : public NodeRef { ...@@ -57,10 +51,11 @@ class Buffer : public NodeRef {
/*! \brief Node to represent a buffer */ /*! \brief Node to represent a buffer */
class BufferNode : public Node { class BufferNode : public Node {
public: public:
/*! \brief optional name of the buffer */ // Data fields.
std::string name;
/*! \brief The pointer to the head of the data */ /*! \brief The pointer to the head of the data */
Var data; Var data;
/*! \brief data type in the content of the tensor */
Type dtype;
/*! \brief The shape of the buffer */ /*! \brief The shape of the buffer */
Array<Expr> shape; Array<Expr> shape;
/*! /*!
...@@ -68,34 +63,39 @@ class BufferNode : public Node { ...@@ -68,34 +63,39 @@ class BufferNode : public Node {
* This can be an empty array, indicating array is contiguous * This can be an empty array, indicating array is contiguous
*/ */
Array<Expr> strides; Array<Expr> strides;
/*! \brief data type in the content of the tensor */
Type dtype;
/*! /*!
* \brief The offset in bytes to the beginning pointer to data * \brief The offset in bytes to the beginning pointer to data
* Can be undefined, indicating this must be zero. * Can be undefined, indicating this must be zero.
*/ */
Expr byte_offset; Expr byte_offset;
// Meta data
/*! \brief optional name of the buffer */
std::string name;
/*! \brief storage scope of the buffer, if other than global */
std::string scope;
/*! \brief Alignment bytes size of byte_offset */ /*! \brief Alignment bytes size of byte_offset */
int offset_alignment; int offset_alignment;
/*! \brief constructor */ /*! \brief constructor */
BufferNode() {} BufferNode() {}
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("data", &data); v->Visit("data", &data);
v->Visit("dtype", &dtype);
v->Visit("shape", &shape); v->Visit("shape", &shape);
v->Visit("strides", &strides); v->Visit("strides", &strides);
v->Visit("dtype", &dtype);
v->Visit("byte_offset", &byte_offset); v->Visit("byte_offset", &byte_offset);
v->Visit("name", &name);
v->Visit("scope", &scope);
v->Visit("offset_alignment", &offset_alignment); v->Visit("offset_alignment", &offset_alignment);
} }
static Buffer make(std::string name, static Buffer make(Var ptr,
Var ptr, Type dtype,
Array<Expr> shape, Array<Expr> shape,
Array<Expr> strides, Array<Expr> strides,
Type dtype,
Expr byte_offset, Expr byte_offset,
std::string name,
std::string scope,
int offset_alignment); int offset_alignment);
static constexpr const char* _type_key = "Buffer"; static constexpr const char* _type_key = "Buffer";
...@@ -106,5 +106,16 @@ inline const BufferNode* Buffer::operator->() const { ...@@ -106,5 +106,16 @@ inline const BufferNode* Buffer::operator->() const {
return static_cast<const BufferNode*>(node_.get()); return static_cast<const BufferNode*>(node_.get());
} }
/*!
* \brief Construct a new buffer given shape, and dtype.
* \param shape The shape of the buffer,
* \param dtype The content data type.
* \param name The name of the buffer
* \return The created buffer.
* \sa BufferNode::make for complete constructor.
*/
Buffer decl_buffer(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "buffer");
} // namespace tvm } // namespace tvm
#endif // TVM_BUFFER_H_ #endif // TVM_BUFFER_H_
...@@ -369,11 +369,13 @@ def extern(shape, inputs, fcompute, ...@@ -369,11 +369,13 @@ def extern(shape, inputs, fcompute,
return res[0] if len(res) == 1 else res return res[0] if len(res) == 1 else res
def decl_buffer(shape, dtype=None, def decl_buffer(shape,
dtype=None,
name="buffer", name="buffer",
data=None, data=None,
strides=None, strides=None,
byte_offset=None, byte_offset=None,
scope="",
offset_alignment=0): offset_alignment=0):
"""Decleare a new symbolic buffer. """Decleare a new symbolic buffer.
...@@ -402,6 +404,10 @@ def decl_buffer(shape, dtype=None, ...@@ -402,6 +404,10 @@ def decl_buffer(shape, dtype=None,
byte_offset: Expr, optional byte_offset: Expr, optional
The offset in bytes to data pointer. The offset in bytes to data pointer.
scope: str, optional
The storage scope of the buffer, if not global.
If scope equals empty string, it means it is global memory.
offset_alignment: int, optional offset_alignment: int, optional
The alignment of offset The alignment of offset
...@@ -430,7 +436,7 @@ def decl_buffer(shape, dtype=None, ...@@ -430,7 +436,7 @@ def decl_buffer(shape, dtype=None,
data = var(name, "handle") data = var(name, "handle")
return _api_internal._Buffer( return _api_internal._Buffer(
name, data, shape, strides, dtype, byte_offset, offset_alignment) data, dtype, shape, strides, byte_offset, name, scope, offset_alignment)
def _IterVar(dom, name, iter_type, thread_tag=''): def _IterVar(dom, name, iter_type, thread_tag=''):
......
...@@ -151,7 +151,8 @@ TVM_REGISTER_API("_Buffer") ...@@ -151,7 +151,8 @@ TVM_REGISTER_API("_Buffer")
args[3], args[3],
args[4], args[4],
args[5], args[5],
args[6]); args[6],
args[7]);
}); });
TVM_REGISTER_API("_Tensor") TVM_REGISTER_API("_Tensor")
......
...@@ -16,14 +16,16 @@ Array<Expr> GetStrides(Array<Expr> shape) { ...@@ -16,14 +16,16 @@ Array<Expr> GetStrides(Array<Expr> shape) {
return Array<Expr>(vec.rbegin(), vec.rend()); return Array<Expr>(vec.rbegin(), vec.rend());
} }
Buffer::Buffer(Array<Expr> shape, Buffer decl_buffer(Array<Expr> shape,
Type dtype, Type dtype,
std::string name) std::string name) {
: Buffer(BufferNode::make( return BufferNode::make(
name, Var(name, Handle()),
Var(name, Type(Type::Handle, 0, 0)), dtype,
shape, Array<Expr>(), dtype, shape,
Expr(), 0)) { Array<Expr>(),
Expr(),
name, "", 0);
} }
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) { inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
...@@ -61,22 +63,23 @@ Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const { ...@@ -61,22 +63,23 @@ Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
const_true(n->dtype.lanes())); const_true(n->dtype.lanes()));
} }
Buffer BufferNode::make(std::string name, Buffer BufferNode::make(Var data,
Var data, Type dtype,
Array<Expr> shape, Array<Expr> shape,
Array<Expr> strides, Array<Expr> strides,
Type dtype,
Expr byte_offset, Expr byte_offset,
std::string name,
std::string scope,
int offset_alignment) { int offset_alignment) {
auto n = std::make_shared<BufferNode>(); auto n = std::make_shared<BufferNode>();
n->name = name; n->data = std::move(data);
n->data = data;
n->shape = shape;
n->strides = strides;
n->dtype = dtype; n->dtype = dtype;
n->shape = std::move(shape);
n->strides = std::move(strides);
n->name = std::move(name);
n->scope = std::move(scope);
if (!byte_offset.defined()) { if (!byte_offset.defined()) {
byte_offset = make_const(shape[0].type(), 0); byte_offset = make_const(n->shape[0].type(), 0);
} }
if (offset_alignment != 0) { if (offset_alignment != 0) {
CHECK_EQ(offset_alignment % dtype.bytes(), 0) CHECK_EQ(offset_alignment % dtype.bytes(), 0)
......
...@@ -83,7 +83,7 @@ class StorageFlattener : public IRMutator { ...@@ -83,7 +83,7 @@ class StorageFlattener : public IRMutator {
for (auto r : e.bounds) { for (auto r : e.bounds) {
shape.push_back(r->extent); shape.push_back(r->extent);
} }
e.buffer = Buffer(shape, op->type, key.GetName()); e.buffer = decl_buffer(shape, op->type, key.GetName());
buf_map_[key] = e; buf_map_[key] = e;
Stmt body = this->Mutate(op->body); Stmt body = this->Mutate(op->body);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment