Commit 53428606 by Gaoxiong Committed by Tianqi Chen

Support using double buffering in the IR builder DSL (#1897) (#1898)

parent 2d9bc751
...@@ -59,7 +59,8 @@ class StorageFlattener : public IRMutator { ...@@ -59,7 +59,8 @@ class StorageFlattener : public IRMutator {
if (op->attr_key == attr::realize_scope) { if (op->attr_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value; storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
return this->Mutate(op->body); return this->Mutate(op->body);
} else if (op->attr_key == attr::double_buffer_scope) { } else if (op->attr_key == attr::double_buffer_scope &&
op->node.node_->derived_from<OperationNode>()) {
Operation func(op->node.node_); Operation func(op->node.node_);
Stmt body = Mutate(op->body); Stmt body = Mutate(op->body);
for (int i = 0; i < func->num_outputs(); ++i) { for (int i = 0; i < func->num_outputs(); ++i) {
......
...@@ -51,8 +51,41 @@ def test_flatten_storage_align(): ...@@ -51,8 +51,41 @@ def test_flatten_storage_align():
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
assert(stmt.body.extents[0].value == 17 * 8) assert(stmt.body.extents[0].value == 17 * 8)
def test_flatten_double_buffer():
    """Check double buffering of a buffer allocated through the IR builder.

    Builds a statement with ``tvm.ir_builder`` in which a "shared"-scope
    buffer ``B`` carries the ``double_buffer_scope`` attribute, runs
    StorageFlatten, InjectDoubleBuffer and Simplify on it, and checks the
    resulting allocation. It then lowers the statement with MakeAPI, applies
    ThreadSync for the "shared" scope, and checks how many
    ``tvm_storage_sync`` calls the result contains.
    """
    n = 100  # trip count of the outer loop
    m = 4    # number of elements in B, and the row stride used to index A
    tx = tvm.thread_axis("threadIdx.x")
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    C = ib.pointer("float32", name="C")
    ib.scope_attr(tx, "thread_extent", 1)
    with ib.for_range(0, n) as i:
        B = ib.allocate("float32", m, name="B", scope="shared")
        with ib.new_scope():
            # The attribute is attached to the buffer variable created by the
            # IR builder rather than to an Operation.
            ib.scope_attr(B.asnode(), "double_buffer_scope", 1)
            with ib.for_range(0, m) as j:
                B[j] = A[i * m + j]
        with ib.for_range(0, m) as j:
            C[j] = B[j] + 1

    stmt = ib.get()
    stmt = tvm.ir_pass.StorageFlatten(stmt, {}, 64)
    stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, 2)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert isinstance(stmt.body.body, tvm.stmt.Allocate)
    # NOTE(review): presumably the leading extent is the number of buffer
    # copies introduced by double buffering -- confirm against the pass.
    assert stmt.body.body.extents[0].value == 2

    f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True)
    f = tvm.ir_pass.ThreadSync(f, "shared")

    # A one-element list lets the nested visitor update the counter in place.
    count = [0]

    def count_sync(op):
        if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync":
            count[0] += 1

    tvm.ir_pass.PostOrderVisit(f.body, count_sync)
    assert count[0] == 4
if __name__ == "__main__":
    # Run the tests directly when this file is executed as a script.
    test_flatten_storage_align()
    test_flatten2()
    test_flatten_prefetch()
    test_flatten_double_buffer()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment