Commit c6b1020b by Yizhi Liu, committed by Tianqi Chen

Generalize pooling to support arbitrary layout (#1103)

* generalize pool2d to arbitrary layout

* explain the layout support for pool in more detail

* allow missing factor size for pooling

* explain what factor size is used for

* fix typo

* rename idx -> axis
parent 154104b3
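
The "arbitrary layout" in the title refers to TVM's layout-string convention: each upper-case letter names a primal axis (N, C, H, W), and a lower-case letter names a sub-axis obtained by splitting the matching primal axis; the number in front of a lower-case letter is the "factor size" the bullets mention, i.e. the split width (NCHW16c packs the channel dimension in blocks of 16). The "allow missing factor size" bullet presumably lets that number be omitted when it is not needed to build the pooling compute. Below is a minimal, illustrative parser for the convention; the function name and regex are mine, not part of the patch:

    import re

    def parse_layout(layout):
        """Split a layout string such as "NCHW" or "NCHW16c" into
        (axis, factor) pairs.  Upper-case letters are primal axes;
        a lower-case letter is a sub-axis whose optional leading
        number is the split factor (0 = factor not given)."""
        return [(name, int(factor) if factor else 0)
                for factor, name in re.findall(r"(\d*)([A-Za-z])", layout)]

    print(parse_layout("NCHW"))     # [('N', 0), ('C', 0), ('H', 0), ('W', 0)]
    print(parse_layout("NCHW16c"))  # ends with ('c', 16): channel split by 16
    print(parse_layout("NCHWc"))    # factor omitted -> ('c', 0)
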
@@ -33,7 +33,9 @@ inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
   auto s = create_schedule(out_ops);
   auto _schedule = [&](const Tensor& padded_input, const Tensor& pool) {
-    s[padded_input].compute_inline();
+    if (padded_input->op->is_type<ComputeOpNode>()) {
+      s[padded_input].compute_inline();
+    }
     auto num_thread = target->max_num_threads;
     Tensor out;
     Tensor OL;

@@ -84,7 +84,8 @@ def schedule_pool(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     def _schedule(PaddedInput, Pool):
-        s[PaddedInput].compute_inline()
+        if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
+            s[PaddedInput].compute_inline()
         num_thread = tvm.target.current_target(allow_none=False).max_num_threads
         if Pool.op in s.outputs:
             Out = Pool

@@ -67,7 +67,8 @@ def schedule_pool(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     def _schedule(PaddedInput, Pool):
-        s[PaddedInput].opengl()
+        if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
+            s[PaddedInput].opengl()
         if Pool.op in s.outputs:
             Out = Pool
         else:
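
All three hunks above add the same guard around the pad stage: when the pooling needs no padding, the tensor handed to _schedule as the padded input can simply be the original input placeholder rather than a separate pad compute, and compute_inline (or opengl) only makes sense for a stage backed by a ComputeOp. A minimal sketch of the pattern, using the old top-level tvm.placeholder/tvm.compute API and made-up stage names:

    import tvm

    data = tvm.placeholder((1, 64, 32, 32), name="data")  # backed by a PlaceholderOp
    # Stand-in for the pad stage that pooling inserts when padding > 0.
    padded = tvm.compute(data.shape, lambda *i: data(*i), name="padded")
    pool = tvm.compute(data.shape, lambda *i: padded(*i), name="pool")
    s = tvm.create_schedule(pool.op)

    def _schedule(PaddedInput, Pool):  # Pool kept only to mirror the signature
        # Inline the pad stage only when it really is a compute stage;
        # a raw placeholder (no padding inserted) must be left alone.
        if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
            s[PaddedInput].compute_inline()

    _schedule(padded, pool)  # pad stage present: it is inlined into pool
    _schedule(data, pool)    # placeholder input: the guard skips it
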
@@ -27,6 +27,10 @@ def schedule_injective(outs):
         n, c, _, _ = s[x].op.axis
         fused = s[x].fuse(n, c) # for nhwc layout, fuse n and h
         s[x].parallel(fused)
+    elif len(s[x].op.axis) == 5:
+        n, C, h, _, _ = s[x].op.axis
+        fused = s[x].fuse(n, C, h)
+        s[x].parallel(fused)
     else:
         s[x].parallel(s[x].op.axis[0])
     return s
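
The injective schedule gains a branch for 5-D outputs, which is what a split-channel layout such as NCHW16c produces (batch, outer channel, height, width, inner channel); fusing n, C and h keeps the parallel loop long enough. A small sketch of that branch on an illustrative 5-D tensor, again using the old top-level API:

    import tvm

    # A 5-D tensor as an NCHW16c-style layout would produce:
    # (batch, channel_outer, height, width, channel_inner).
    x = tvm.compute((1, 4, 32, 32, 16),
                    lambda n, C, h, w, c: n + C + h + w + c,
                    name="x")
    s = tvm.create_schedule(x.op)

    if len(s[x].op.axis) == 5:
        n, C, h, _, _ = s[x].op.axis
        fused = s[x].fuse(n, C, h)  # 1 * 4 * 32 = 128 parallel iterations
        s[x].parallel(fused)

    print(tvm.lower(s, [x], simple_mode=True))
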