Commit 5a7a056c by Tianqi Chen Committed by GitHub

[LANG/BUFFER] Change buffer arguments to match DLPack order, add scope (#203)

parent 7b821851
Subproject commit 41fe60a76fe6e5669540acf1ef3595bc38025157
Subproject commit e42653d7c3a604eb9f6ee1b5f989ddadd1cea69c
......@@ -26,12 +26,6 @@ class Buffer : public NodeRef {
Buffer() {}
explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief construct a new buffer based on shape and strides.
*/
explicit Buffer(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "buffer");
/*!
* \brief Generate a load expression loading the index location of buffer.
* \param index The index to the buffer.
* \return The load expression.
......@@ -57,10 +51,11 @@ class Buffer : public NodeRef {
/*! \brief Node to represent a buffer */
class BufferNode : public Node {
public:
/*! \brief optional name of the buffer */
std::string name;
// Data fields.
/*! \brief The pointer to the head of the data */
Var data;
/*! \brief data type in the content of the tensor */
Type dtype;
/*! \brief The shape of the buffer */
Array<Expr> shape;
/*!
......@@ -68,34 +63,39 @@ class BufferNode : public Node {
* This can be an empty array, indicating array is contiguous
*/
Array<Expr> strides;
/*! \brief data type in the content of the tensor */
Type dtype;
/*!
* \brief The offset in bytes to the beginning pointer to data
* Can be undefined, indicating this must be zero.
*/
Expr byte_offset;
// Meta data
/*! \brief optional name of the buffer */
std::string name;
/*! \brief storage scope of the buffer, if other than global */
std::string scope;
/*! \brief Alignment bytes size of byte_offset */
int offset_alignment;
/*! \brief constructor */
BufferNode() {}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("data", &data);
v->Visit("dtype", &dtype);
v->Visit("shape", &shape);
v->Visit("strides", &strides);
v->Visit("dtype", &dtype);
v->Visit("byte_offset", &byte_offset);
v->Visit("name", &name);
v->Visit("scope", &scope);
v->Visit("offset_alignment", &offset_alignment);
}
static Buffer make(std::string name,
Var ptr,
static Buffer make(Var ptr,
Type dtype,
Array<Expr> shape,
Array<Expr> strides,
Type dtype,
Expr byte_offset,
std::string name,
std::string scope,
int offset_alignment);
static constexpr const char* _type_key = "Buffer";
......@@ -106,5 +106,16 @@ inline const BufferNode* Buffer::operator->() const {
return static_cast<const BufferNode*>(node_.get());
}
/*!
* \brief Construct a new buffer given shape, and dtype.
* \param shape The shape of the buffer,
* \param dtype The content data type.
* \param name The name of the buffer
* \return The created buffer.
* \sa BufferNode::make for complete constructor.
*/
Buffer decl_buffer(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "buffer");
} // namespace tvm
#endif // TVM_BUFFER_H_
......@@ -369,11 +369,13 @@ def extern(shape, inputs, fcompute,
return res[0] if len(res) == 1 else res
def decl_buffer(shape, dtype=None,
def decl_buffer(shape,
dtype=None,
name="buffer",
data=None,
strides=None,
byte_offset=None,
scope="",
offset_alignment=0):
"""Decleare a new symbolic buffer.
......@@ -402,6 +404,10 @@ def decl_buffer(shape, dtype=None,
byte_offset: Expr, optional
The offset in bytes to data pointer.
scope: str, optional
The storage scope of the buffer, if not global.
If scope equals empty string, it means it is global memory.
offset_alignment: int, optional
The alignment of offset
......@@ -430,7 +436,7 @@ def decl_buffer(shape, dtype=None,
data = var(name, "handle")
return _api_internal._Buffer(
name, data, shape, strides, dtype, byte_offset, offset_alignment)
data, dtype, shape, strides, byte_offset, name, scope, offset_alignment)
def _IterVar(dom, name, iter_type, thread_tag=''):
......
......@@ -151,7 +151,8 @@ TVM_REGISTER_API("_Buffer")
args[3],
args[4],
args[5],
args[6]);
args[6],
args[7]);
});
TVM_REGISTER_API("_Tensor")
......
......@@ -16,14 +16,16 @@ Array<Expr> GetStrides(Array<Expr> shape) {
return Array<Expr>(vec.rbegin(), vec.rend());
}
Buffer::Buffer(Array<Expr> shape,
Buffer decl_buffer(Array<Expr> shape,
Type dtype,
std::string name)
: Buffer(BufferNode::make(
name,
Var(name, Type(Type::Handle, 0, 0)),
shape, Array<Expr>(), dtype,
Expr(), 0)) {
std::string name) {
return BufferNode::make(
Var(name, Handle()),
dtype,
shape,
Array<Expr>(),
Expr(),
name, "", 0);
}
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
......@@ -61,22 +63,23 @@ Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
const_true(n->dtype.lanes()));
}
Buffer BufferNode::make(std::string name,
Var data,
Buffer BufferNode::make(Var data,
Type dtype,
Array<Expr> shape,
Array<Expr> strides,
Type dtype,
Expr byte_offset,
std::string name,
std::string scope,
int offset_alignment) {
auto n = std::make_shared<BufferNode>();
n->name = name;
n->data = data;
n->shape = shape;
n->strides = strides;
n->data = std::move(data);
n->dtype = dtype;
n->shape = std::move(shape);
n->strides = std::move(strides);
n->name = std::move(name);
n->scope = std::move(scope);
if (!byte_offset.defined()) {
byte_offset = make_const(shape[0].type(), 0);
byte_offset = make_const(n->shape[0].type(), 0);
}
if (offset_alignment != 0) {
CHECK_EQ(offset_alignment % dtype.bytes(), 0)
......
......@@ -83,7 +83,7 @@ class StorageFlattener : public IRMutator {
for (auto r : e.bounds) {
shape.push_back(r->extent);
}
e.buffer = Buffer(shape, op->type, key.GetName());
e.buffer = decl_buffer(shape, op->type, key.GetName());
buf_map_[key] = e;
Stmt body = this->Mutate(op->body);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment