Commit 489ec872 by Yuwei HU Committed by Tianqi Chen

fix fuse reduce_axis error in pooling schedule (#482)

parent acd9db84
...@@ -32,9 +32,6 @@ def schedule_global_pool(outs): ...@@ -32,9 +32,6 @@ def schedule_global_pool(outs):
Out = outs[0].op.output(0) Out = outs[0].op.output(0)
s[Pool].set_scope("local") s[Pool].set_scope("local")
i, c, h, w = s[Out].op.axis i, c, h, w = s[Out].op.axis
dh, dw = s[Pool].op.reduce_axis
fuse_index = s[Pool].fuse(dw, dh)
s[Pool].unroll(fuse_index)
by, ty = s[Out].split(i, factor=num_thread) by, ty = s[Out].split(i, factor=num_thread)
bx, tx = s[Out].split(c, factor=num_thread) bx, tx = s[Out].split(c, factor=num_thread)
s[Out].reorder(by, bx, ty, tx) s[Out].reorder(by, bx, ty, tx)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment