Commit 2ff74317 by Tianqi Chen Committed by GitHub

[PASS] StorageRewrite Fold Inplace op storage when possible (#759)

* [PASS] StorageRewrite Fold Inplace op storage when possible

* update comment to fix typos
parent 9d6dbe34
......@@ -153,6 +153,12 @@ constexpr const char* coproc_uop_scope = "coproc_uop_scope";
/*! \brief Mark the scope as volatile access for certain handle. */
constexpr const char* volatile_scope = "volatile_scope";
/*!
* \brief Mark the scope as generated by an extern primitive.
* Such a scope can contain an arbitrary IR program, so we need to be careful
* when making assumptions about the structure of the program.
*/
constexpr const char* extern_scope = "extern_scope";
/*!
* \brief Mark the scope where computation starts to happen.
* This can hint some code generators to create a new function for compute.
*/
......
......@@ -130,7 +130,7 @@ Stmt ExternOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt ret = this->body;
Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<NodeRef> bind_spec;
Array<Expr> tuple;
......
......@@ -15,6 +15,7 @@ def test_add_pipeline():
C = tvm.extern(A.shape, [A], extern_generator, name='C')
s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, C], simple_mode=True))
def check_llvm():
if not tvm.module.enabled("llvm"):
......
......@@ -19,14 +19,39 @@ def test_storage_share():
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.StorageRewrite(stmt)
# verify only have two allocations.
# verify that the data is folded.
# verify there is only one allocation.
# verify inplace folding works
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.stmt.Allocate):
num_alloc[0] += 1
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 1
def test_inplace_rule():
    """Build a small elementwise pipeline, run StorageRewrite, and check
    that inplace storage folding leaves exactly two allocations and that
    no store writes into the buffer read by its value's left operand."""
    m = 10
    A = tvm.placeholder((m,), name='A')
    A0 = tvm.compute((m,), lambda i: A[i], name='A0')
    A1 = tvm.compute((m,), lambda i: A[i] + 1, name='A1')
    AA = tvm.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name='AA')
    B = tvm.compute((m,), lambda i: AA[i] + 1, name='B')

    # Lower the schedule down to a flattened statement.
    s = tvm.create_schedule(B.op)
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
    stmt = tvm.ir_pass.Simplify(stmt)
    stmt = tvm.ir_pass.StorageRewrite(stmt)

    # Count allocations and inspect every store after the rewrite.
    alloc_count = [0]

    def inspect(node):
        if isinstance(node, tvm.stmt.Allocate):
            alloc_count[0] += 1
        elif isinstance(node, tvm.stmt.Store):
            # After folding, a store must not target the same buffer that
            # the left operand of its value expression reads from.
            assert node.buffer_var != node.value.a.buffer_var

    tvm.ir_pass.PostOrderVisit(stmt, inspect)
    assert alloc_count[0] == 2
......@@ -38,7 +63,7 @@ def test_storage_combine():
B = A
stages = []
for t in range(num_stage):
B = tvm.compute((n, ), lambda i: B[i] + (t+1), name='A%d' % t)
B = tvm.compute((n, ), lambda i: B[i] + B[0] + (t+1), name='A%d' % t)
stages.append(B)
s = tvm.create_schedule(B.op)
......@@ -121,12 +146,14 @@ def test_parallel_alloc():
A[j] = A[j] + 2
body = ib.get()
body = tvm.ir_pass.StorageRewrite(body)
assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate))
if __name__ == "__main__":
    # Run each test exactly once; the original listing invoked
    # test_storage_share() twice (a leftover from the diff merge).
    test_inplace_rule()
    test_storage_share()
    test_parallel_alloc()
    test_storage_combine()
    test_storage_share_gpu()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment