Unverified commit 0af5c216 by Yizhi Liu, committed by GitHub

[Codegen] Support broadcast op with symbolic shape (#3389)

* [Codegen] Support broadcast op with symbolic shape

* fix case where last dim = 1

* use enum; simplify stride calculation; improve doc

* fix lint

* improve py doc
parent 26466047
@@ -36,10 +36,11 @@ namespace tvm {
 // Internal node container Buffer
 class BufferNode;
-/*! \brief memory access kind */
-enum class AccessMask : int {
-  kRead = 1,
-  kWrite = 2
+/*! \brief buffer type */
+enum BufferType : int {
+  kDefault = 1,
+  // Maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1.
+  kAutoBroadcast = 2,
 };
 /*!
@@ -129,6 +130,8 @@ class BufferNode : public Node {
   *  elem_offset is guaranteed to be multiple of offset_factor.
   */
  int offset_factor;
+ /*! \brief buffer type */
+ BufferType buffer_type;
  /*! \brief constructor */
  BufferNode() {}
@@ -142,6 +145,7 @@ class BufferNode : public Node {
     v->Visit("scope", &scope);
     v->Visit("data_alignment", &data_alignment);
     v->Visit("offset_factor", &offset_factor);
+    v->Visit("buffer_type", &buffer_type);
   }
   /*! \return preferred index type for this buffer node */
@@ -159,7 +163,8 @@ class BufferNode : public Node {
                      std::string name,
                      std::string scope,
                      int data_alignment,
-                     int offset_factor);
+                     int offset_factor,
+                     BufferType buffer_type);
   static constexpr const char* _type_key = "Buffer";
   TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
......
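An aside on the kAutoBroadcast semantics documented above: a dimension of extent 1 is read with an effective stride of 0, so every index along it resolves to position 0, exactly as in NumPy broadcasting. A minimal, self-contained NumPy sketch of that index mapping (illustration only, not TVM code):

```python
import numpy as np

# Emulate kAutoBroadcast's buffer[i][j][k] -> buffer[i][0][k] mapping by
# giving the extent-1 dimension a zero stride.
b = np.random.rand(2, 1, 3)
view = np.lib.stride_tricks.as_strided(
    b, shape=(2, 4, 3), strides=(b.strides[0], 0, b.strides[2]))
# Every j reads the same underlying row, exactly as a broadcast buffer would.
for j in range(4):
    np.testing.assert_array_equal(view[:, j, :], b[:, 0, :])
```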
@@ -531,7 +531,8 @@ def decl_buffer(shape,
                 elem_offset=None,
                 scope="",
                 data_alignment=-1,
-                offset_factor=0):
+                offset_factor=0,
+                buffer_type=""):
     """Declare a new symbolic buffer.

     Normally buffer is created automatically during lower and build.
@@ -574,11 +575,39 @@ def decl_buffer(shape,
         If 0 is passed, the alignment will be set to 1.
         If a non-zero value is passed, we will create a Var for elem_offset if elem_offset is None.

+    buffer_type: str, optional, {"", "auto_broadcast"}
+        An auto_broadcast buffer allows one to implement a broadcast computation
+        without considering whether a dimension's size equals one.
+        TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1.
+
     Returns
     -------
     buffer : Buffer
         The created buffer

+    Example
+    -------
+    Here is an example of how an auto_broadcast buffer can be used to define
+    a symbolic broadcast operation:
+
+    .. code-block:: python
+
+        import numpy as np
+        m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
+        n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
+        o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")
+        A = tvm.placeholder((m0, m1, m2), name='A')
+        B = tvm.placeholder((n0, n1, n2), name='B')
+        C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
+        Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
+        Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
+        s = tvm.create_schedule(C.op)
+        fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A: Ab, B: Bb})
+        ctx = tvm.cpu(0)
+        a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
+        fadd(a, b, c)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
     Note
     ----
     Buffer data structure reflects the DLTensor structure in dlpack.
@@ -602,7 +631,7 @@ def decl_buffer(shape,
         data = var(name, "handle")
     return _api_internal._Buffer(
         data, dtype, shape, strides, elem_offset, name, scope,
-        data_alignment, offset_factor)
+        data_alignment, offset_factor, buffer_type)

 def layout(layout_str):
     """Create a layout node from a string.
......
@@ -207,7 +207,13 @@ TVM_REGISTER_API("Range")
   });

 TVM_REGISTER_API("_Buffer")
-.set_body_typed(BufferNode::make);
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    CHECK_EQ(args.size(), 10);
+    auto buffer_type = args[9].operator std::string();
+    BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
+    *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
+                            args[5], args[6], args[7], args[8], type);
+  });

 TVM_REGISTER_API("_BufferAccessPtr")
 .set_body_method(&Buffer::access_ptr);
......
@@ -342,7 +342,7 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape,
   }
   return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
-                          data_alignment, offset_factor);
+                          data_alignment, offset_factor, kDefault);
 }

 void GetBinds(const Array<Tensor>& args,
......
@@ -49,7 +49,8 @@ Buffer decl_buffer(Array<Expr> shape,
                      Expr(),
                      name,
                      "",
-                     0, 0);
+                     0, 0,
+                     kDefault);
 }

 // Split the given expression w.r.t the add operator
@@ -365,7 +366,8 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
                          n->name + "_slice",
                          n->scope,
                          n->data_alignment,
-                         0);
+                         0,
+                         n->buffer_type);
 }

 Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const {
@@ -405,7 +407,8 @@ Buffer BufferNode::make(Var data,
                         std::string name,
                         std::string scope,
                         int data_alignment,
-                        int offset_factor) {
+                        int offset_factor,
+                        BufferType buffer_type) {
   auto n = make_node<BufferNode>();
   n->data = std::move(data);
   n->dtype = dtype;
@@ -428,6 +431,12 @@ Buffer BufferNode::make(Var data,
   n->elem_offset = std::move(elem_offset);
   n->data_alignment = data_alignment;
   n->offset_factor = offset_factor;
+  n->buffer_type = buffer_type;
+  if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) {
+    for (size_t i = 0; i < n->shape.size(); ++i) {
+      n->strides.push_back(tvm::var("stride"));
+    }
+  }
   return Buffer(n);
 }
......
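A consequence of the BufferNode::make change above: an auto-broadcast buffer declared without explicit strides carries one fresh symbolic stride Var per dimension, which the argument binder fills in on every call. A minimal sketch, assuming the 0.6-era Python API shown in this diff:

```python
import tvm

# Sketch under the API in this diff: each dimension of an auto_broadcast
# buffer should receive its own symbolic stride Var (all named "stride"),
# to be bound against the incoming DLTensor at call time.
m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
Ab = tvm.decl_buffer((m0, m1, m2), "float32", name="Ab",
                     buffer_type="auto_broadcast")
print(Ab.strides)  # expected: three distinct stride Vars
```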
@@ -242,6 +242,21 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       check = IfThenElse::make(Not::make(is_null), check, Stmt());
       init_nest_.emplace_back(Block::make(check, Evaluate::make(0)));
     }
+  } else if (buffer->buffer_type == kAutoBroadcast) {
+    Type stype = buffer->DefaultIndexType();
+    Expr stride = make_const(stype, 1);
+    for (size_t i = buffer->shape.size(); i != 0; --i) {
+      size_t k = i - 1;
+      std::ostringstream field_name;
+      field_name << v_strides->name_hint << '[' << k << ']';
+      Expr value = cast(buffer->shape[k].type(),
+                        Load::make(tvm_shape_type, v_strides,
+                                   IntImm::make(Int(32), k), const_true(1)));
+      value = tvm::if_then_else(is_null, stride, value);
+      value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
+      Bind_(buffer->strides[k], value, field_name.str(), true);
+      stride = Simplify(stride * buffer->shape[k]);
+    }
   } else {
     std::ostringstream stride_null_err_msg;
     stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
......
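The binding loop above walks dimensions innermost-first: each symbolic stride is bound to the DLTensor's stride when strides are supplied, to the running compact row-major stride when the strides pointer is null, and forced to 0 wherever the extent is 1, which is what collapses buffer[i][j][k] to buffer[i][0][k]. A pure-Python model of that rule (illustration only, not TVM code):

```python
def auto_broadcast_strides(shape, strides=None):
    """Model of the stride-binding rule above (illustration, not TVM code)."""
    out = [0] * len(shape)
    stride = 1  # running compact row-major stride, innermost dimension first
    for k in reversed(range(len(shape))):
        value = stride if strides is None else strides[k]
        out[k] = 0 if shape[k] == 1 else value  # extent-1 dims always read index 0
        stride *= shape[k]
    return out

assert auto_broadcast_strides((2, 1, 3)) == [3, 0, 1]   # broadcast middle dim
assert auto_broadcast_strides((2, 4, 3)) == [12, 3, 1]  # plain compact layout
```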
@@ -160,7 +160,7 @@ class CopyIntrinInjector : public IRMutator {
         store_strides[loop_var_size],
         store->buffer_var->name_hint,
         GetStorageScope(store->buffer_var.get()),
-        0, 0);
+        0, 0, kDefault);
     Buffer src = BufferNode::make(
         Var(load->buffer_var.node_),
         load->type,
@@ -169,7 +169,7 @@ class CopyIntrinInjector : public IRMutator {
         src_elem_offset,
         load->buffer_var->name_hint,
         GetStorageScope(load->buffer_var.get()),
-        0, 0);
+        0, 0, kDefault);
     *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
     CHECK(out->defined()) << "flower function did not return correct stmt";
     return true;
......
@@ -220,7 +220,7 @@ class StorageFlattener : public IRMutator {
         Var(key.GetName(), Handle()),
         op->type, shape, strides, Expr(),
         key.GetName(), skey.to_string(),
-        align, 0);
+        align, 0, kDefault);
     buf_map_[key] = e;
     Stmt body = this->Mutate(op->body);
......
@@ -16,6 +16,7 @@
 # under the License.
 import tvm
 from tvm.schedule import Buffer
+import numpy as np

 def test_buffer():
     m = tvm.var('m')
@@ -108,6 +109,34 @@ def test_buffer_index_merge_mult_mod():
     index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1))))
     assert_simplified_equal(index_simplified, index_direct)

+def test_buffer_broadcast():
+    m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
+    n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
+    o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")
+    A = tvm.placeholder((m0, m1, m2), name='A')
+    B = tvm.placeholder((n0, n1, n2), name='B')
+    C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
+    Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
+    Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
+    s = tvm.create_schedule(C.op)
+
+    def check():
+        if not tvm.module.enabled("llvm"):
+            return
+        fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A: Ab, B: Bb})
+        ctx = tvm.cpu(0)
+        a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(2, 1, 1)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
+        fadd(a, b, c)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+    check()
+
 if __name__ == "__main__":
     test_buffer()
     test_buffer_access_ptr()
@@ -115,3 +144,4 @@ if __name__ == "__main__":
     test_buffer_access_ptr_extent()
     test_buffer_vload()
     test_buffer_index_merge_mult_mod()
+    test_buffer_broadcast()
@@ -49,7 +49,7 @@ inline Buffer DeclExternBuffer(Array<Expr> shape,
   auto data = var(name, Handle());
   auto elem_offset = Expr();
   return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
-                          -1, 0);
+                          -1, 0, kDefault);
 }

 /*!
......