Commit 27205e36 by Tianqi Chen Committed by GitHub

[BUGFIX] Thread related bound (#86)

parent 9ec40edd
......@@ -148,6 +148,9 @@ void InferRootBound(const Stage& stage,
<< "call schedule.normalize to achieve this.";
if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
relax_set[iv->var.get()] = IntSet::range(vrange);
if (ctx.bind_map.count(iv)) {
relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
}
}
}
CHECK(found_attach || stage_attach.size() == 0)
......
......@@ -175,6 +175,63 @@ def test_bound_nest_thread():
assert(bounds[A2.op.axis[0]].extent.value==32)
assert(bounds[A3.op.axis[0]].extent == m)
def test_gemm_bound():
nn = 1024
n = tvm.convert(nn)
A = tvm.placeholder((n, n), name='A')
B = tvm.placeholder((n, n), name='B')
k = tvm.reduce_axis((0, n), name='k')
C = tvm.compute(
(n, n),
lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k),
name='CC')
# schedule
s = tvm.Schedule(C.op)
xtile, ytile = 32, 32
scale = 8
num_thread = 8
block_factor = scale * num_thread
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
thread_y = tvm.thread_axis("threadIdx.y")
CC = s.cache_write(C, "local")
AA = s.cache_read(A, "shared", [CC])
BB = s.cache_read(B, "shared", [CC])
by, yi = s[C].split(C.op.axis[0], factor=block_factor)
bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
s[C].reorder(by, bx, yi, xi)
s[C].bind(by, block_y)
s[C].bind(bx, block_x)
ty, yi = s[C].split(yi, nparts=num_thread)
tx, xi = s[C].split(xi, nparts=num_thread)
s[C].reorder(ty, tx, yi, xi)
s[C].bind(ty, thread_y)
s[C].bind(tx, thread_x)
yo, xo = CC.op.axis
s[CC].reorder(k, yo, xo)
s[CC].compute_at(s[C], tx)
s[AA].compute_at(s[CC], k)
s[BB].compute_at(s[CC], k)
ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
tx, xi = s[AA].split(xi, nparts=num_thread)
s[AA].bind(ty, thread_y)
s[AA].bind(tx, thread_x)
ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
tx, xi = s[BB].split(xi, nparts=num_thread)
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[BB.op.axis[0]].extent.value==64)
assert(bounds[AA.op.axis[0]].extent.value==64)
assert(bounds[CC.op.axis[0]].extent.value == 8)
assert(bounds[CC.op.axis[1]].extent.value == 8)
if __name__ == "__main__":
test_bound_nest_thread()
......@@ -187,3 +244,4 @@ if __name__ == "__main__":
test_bound_blur()
test_bound_conv1d()
test_bound2()
test_gemm_bound()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment