Commit 5d533ec9 by Yao Wang Committed by Yizhi Liu

Improve x86 Inception (#1506)

* Improve x86 pooling and concat

* Fix

* Fix test concatenate correct layout

* Add conditional vectorize

* Fix lint

* Modify schedule for global pooling

* Fix

* Fix warning

* Fix alter layout test

* Remove vectorization for pooling when using 4D layout

* Remove vectorization for 4D concat

* Fix concatenate layout

* Fix concatenate schedule

* Fix concat

* Fix lint

* Fix concat

* Simplify pooling logic

* Update docstring

* Fix test topi pooling

* Small changes
parent 4dc21bdb
@@ -280,20 +280,22 @@ reg.register_pattern("conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 # max_pool2d
 @reg.register_schedule("max_pool2d")
-def schedule_max_pool2d(_, outs, target):
+def schedule_max_pool2d(attrs, outs, target):
     """Schedule definition of max_pool2d"""
+    layout = attrs["layout"]
     with tvm.target.create(target):
-        return topi.generic.schedule_pool(outs)
+        return topi.generic.schedule_pool(outs, layout)

 reg.register_pattern("max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

 # avg_pool2d
 @reg.register_schedule("avg_pool2d")
-def schedule_avg_pool2d(_, outs, target):
+def schedule_avg_pool2d(attrs, outs, target):
     """Schedule definition of avg_pool2d"""
+    layout = attrs["layout"]
     with tvm.target.create(target):
-        return topi.generic.schedule_pool(outs)
+        return topi.generic.schedule_pool(outs, layout)

 reg.register_pattern("avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
......
@@ -2,6 +2,7 @@
 """Tensor transformation ops"""
 from __future__ import absolute_import
 import tvm
+import topi
 from .tensor import _fschedule_broadcast, _fschedule_injective
 from . import registry as reg
@@ -58,8 +59,13 @@ reg.register_pattern("squeeze", OpPattern.INJECTIVE)
 reg.register_schedule("squeeze", _fschedule_injective)

 # concatenate
+@reg.register_schedule("concatenate")
+def schedule_concatenate(_, outs, target):
+    """Schedule definition of concatenate"""
+    with tvm.target.create(target):
+        return topi.generic.schedule_concatenate(outs)
+
 reg.register_pattern("concatenate", OpPattern.INJECTIVE)
-reg.register_schedule("concatenate", _fschedule_injective)

 # split
 reg.register_pattern("split", OpPattern.INJECTIVE)
......
@@ -129,15 +129,31 @@ inline bool ConcatenateCorrectLayout(const NodeAttrs& attrs,
                                      std::vector<Layout> *ilayouts,
                                      const std::vector<Layout> *last_ilayouts,
                                      std::vector<Layout> *olayouts) {
+  const ConcatenateParam& param = nnvm::get<ConcatenateParam>(attrs.parsed);
   CHECK_EQ(ilayouts->size(), last_ilayouts->size());
   CHECK_EQ(olayouts->size(), 1U);

-  for (size_t i = 0; i < ilayouts->size(); ++i) {
-    const Layout& input = last_ilayouts->at(i).defined() ?
-                          last_ilayouts->at(i) : ilayouts->at(i);
-    NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
+  Layout layout;
+  if (!ilayouts->at(0).defined()) {
+    layout = last_ilayouts->at(0);
+  } else if (param.axis >= static_cast<int>(ilayouts->at(0).ndim())) {
+    CHECK(last_ilayouts->at(0).defined())
+      << "Current input layout " << ilayouts->at(0)
+      << " is invalid but last input layout is not "
+         "defined for the first input.";
+    layout = last_ilayouts->at(0);
+  } else if (last_ilayouts->at(0).defined()
+             && ilayouts->at(0)[param.axis]
+                != last_ilayouts->at(0)[param.axis]) {
+    layout = last_ilayouts->at(0);
+  } else {
+    layout = ilayouts->at(0);
+  }
+
+  for (size_t i = 0; i < ilayouts->size(); ++i) {
+    NNVM_ASSIGN_LAYOUT(*ilayouts, i, layout);
   }
+  NNVM_ASSIGN_LAYOUT(*olayouts, 0, layout);
   return true;
 }
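
A rough Python restatement (not the committed C++) of the decision rule above, with layouts modeled as plain axis lists and None standing for an undefined layout; the real code compares nnvm::Layout objects and only the axis character at the concat position:

def choose_concat_layout(cur, last, axis):
    """Pick the single layout used for all concat inputs and the output."""
    if cur is None:
        # current layout undefined: take the original (pre-transform) layout
        return last
    if axis >= len(cur):
        # the concat axis no longer exists in the current layout;
        # fall back to the original layout, which must be defined
        assert last is not None, "no original layout to fall back on"
        return last
    if last is not None and cur[axis] != last[axis]:
        # the concat axis itself was transformed (e.g. split into a sub-dim);
        # undo it by reverting to the original layout
        return last
    return cur

# mirrors the second pass of the 3-D test below: inputs arrive as HW (2 axes)
# but the concat axis is 2, so the rule falls back to the original H20wW layout
assert choose_concat_layout(["H", "W"], ["H", "20w", "W"], axis=2) == ["H", "20w", "W"]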
......
@@ -77,14 +77,25 @@ def test_concatenate():
     g, ldict = correct_layout(z, {"x": "HW", "y": "HW"})
     assert(ldict["x"][0] == "HW")
     assert(ldict["y"][0] == "HW")
-    assert(ldict["concat"][0] == "__undef__")
+    assert(ldict["concat"][0] == "HW")
     # second pass will insert layout transform
     _, ldict = correct_layout(g, {"x": "HW16w", "y": "HW16w"})
     assert(ldict["x"][0] == "HW16w")
     assert(ldict["y"][0] == "HW16w")
     assert(ldict["x_HW"][0] == "HW")
     assert(ldict["y_HW"][0] == "HW")
-    assert(ldict["concat"][0] == "__undef__")
+    assert(ldict["concat"][0] == "HW16w")
+
+    x1 = sym.Variable("x", shape=(10, 20, 60))
+    x2 = sym.Variable("y", shape=(10, 20, 40))
+    z = sym.concatenate(x1, x2, axis=2, name="concat")
+    g, ldict = correct_layout(z, {"x": "H20wW", "y": "H20wW"})
+    assert(ldict["x"][0] == "H20wW")
+    assert(ldict["y"][0] == "H20wW")
+    assert(ldict["concat"][0] == "H20wW")
+    # second pass will insert layout transform
+    _, ldict = correct_layout(g, {"x": "HW", "y": "HW"})
+    assert(ldict["x_H20wW"][0] == "H20wW")
+    assert(ldict["y_H20wW"][0] == "H20wW")
+    assert(ldict["concat"][0] == "H20wW")


 def test_expand_dims():
@@ -349,4 +360,4 @@ if __name__ == "__main__":
     test_transpose()
     test_broadcast_to()
     test_broadcast_binary()
-    test_reduce()
\ No newline at end of file
+    test_reduce()
@@ -112,18 +112,18 @@ inline Tensor pool_impl(const Tensor& x,
     }, "tensor", "pool_max");
   } else if (pool_type == kAvgPool) {
     auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
-    auto tsum = tvm::compute(out_shape, [&](const Array<Var>& output) {
+    auto tavg = [&](const Array<Var>& output, Expr divide_factor) {
       Array<Expr> indices;
       for (const Var& var : output) indices.push_back(var);
       indices.Set(height_axis, output[height_axis] * stride_height + dheight);
       indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
-      return tvm::sum(temp(indices), { dheight, dwidth });
-    }, "tensor", "pool_avg");
+      return tvm::sum(temp(indices) / divide_factor, { dheight, dwidth });
+    };

     return tvm::compute(out_shape,
     [&](const Array<Var>& output) {
       if (count_include_pad) {
-        return tsum(output) / (kernel_height * kernel_width);
+        return tavg(output, kernel_height * kernel_width);
       } else {
         Expr h_start = output[height_axis] * stride_height - pad_top;
         Expr w_start = output[width_axis] * stride_width - pad_left;
@@ -133,9 +133,9 @@ inline Tensor pool_impl(const Tensor& x,
         w_start = ir::Max::make(w_start, make_const(Int(32), 0));
         Expr divide_factor = ir::Max::make((h_end - h_start) * (w_end - w_start),
                                            make_const(Int(32), 1));
-        return tsum(output) / divide_factor;
+        return tavg(output, divide_factor);
       }
-    }, "tensor", kElementWise);
+    }, "tensor", "pool_avg");
   } else {
     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
     return x;
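
Why the rewrite above is safe (a note, not part of the commit): for a fixed output element the divide factor depends only on the output indices, so it is constant across the dheight/dwidth reduction, and

  sum_{i,j} x[i][j] / d  ==  (sum_{i,j} x[i][j]) / d.

Folding the division into the tvm::sum body therefore leaves average pooling mathematically unchanged, while turning it into a single compute stage tagged "pool_avg" instead of a pooling stage followed by an element-wise division stage.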
......
-# pylint: disable=invalid-name, unused-variable
+# pylint: disable=invalid-name, unused-variable, unused-argument
 """Schedule for pooling operators"""
 import tvm
 from .. import tag
@@ -70,7 +70,7 @@ def schedule_global_pool(outs):

 @generic.schedule_pool.register(["cuda", "gpu"])
-def schedule_pool(outs):
+def schedule_pool(outs, layout):
     """Schedule for pool.

     Parameters
@@ -79,6 +79,9 @@ def schedule_pool(outs):
         The computation graph description of pool
         in the format of an array of tensors.

+    layout: str
+        Data layout.
+
     Returns
     -------
     s: Schedule
......
@@ -29,5 +29,22 @@ def schedule_injective(outs):
     s[x].fuse(s[x].op.axis)
     return s

+@tvm.target.generic_func
+def schedule_concatenate(outs):
+    """Schedule for concatenate op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of concatenate in the format
+        of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return schedule_injective(outs)
+
 schedule_elemwise = schedule_injective
 schedule_broadcast = schedule_injective
@@ -282,7 +282,7 @@ def schedule_dense(outs):

 @tvm.target.override_native_generic_func("schedule_pool")
-def schedule_pool(outs):
+def schedule_pool(outs, layout):
     """Schedule for pool

     Parameters
@@ -291,6 +291,9 @@ def schedule_pool(outs):
         The computation graph description of pool
         in the format of an array of tensors.

+    layout: str
+        Data layout.
+
     Returns
     -------
     sch: Schedule
......
-# pylint: disable=invalid-name, unused-variable
+# pylint: disable=invalid-name, unused-variable, unused-argument
 """Schedule for pooling operators"""
 import tvm
 from .. import tag
@@ -54,7 +54,7 @@ def schedule_global_pool(outs):

 @generic.schedule_pool.register(["opengl"])
-def schedule_pool(outs):
+def schedule_pool(outs, layout):
     """Schedule for pool.

     Parameters
@@ -63,6 +63,9 @@ def schedule_pool(outs):
         The computation graph description of pool
         in the format of an array of tensors.

+    layout: str
+        Data layout.
+
     Returns
     -------
     s: Schedule
......
@@ -33,5 +33,51 @@ def schedule_injective(outs):
         s[x].parallel(s[x].op.axis[0])
     return s

+@generic.schedule_concatenate.register(["cpu"])
+def schedule_concatenate(outs):
+    """X86 schedule for concatenate op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of concatenate in the format
+        of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    def vectorize(sch, tensor, vectorize_limit):
+        """Internal vectorization function for concatenate."""
+        inner_axis = sch[tensor].op.axis[len(sch[tensor].op.axis) - 1]
+        inner_length = tensor.shape[len(tensor.shape) - 1].value
+        if inner_length <= vectorize_limit:
+            sch[tensor].vectorize(inner_axis)
+        else:
+            split_factor = 1
+            for i in range(vectorize_limit, 1, -1):
+                if inner_length % i == 0:
+                    split_factor = i
+                    break
+            if split_factor > 1:
+                _, inner_i = sch[tensor].split(inner_axis, split_factor)
+                sch[tensor].vectorize(inner_i)
+
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    x = outs[0]
+    s = tvm.create_schedule([x.op for x in outs])
+    tvm.schedule.AutoInlineInjective(s)
+    if len(s[x].op.axis) >= 5:
+        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
+        vectorize(s, x, 64)
+        s[x].parallel(fused)
+    elif len(s[x].op.axis) >= 3:
+        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
+        s[x].parallel(fused)
+    else:
+        s[x].parallel(s[x].op.axis[0])
+    return s
+
 schedule_elemwise = schedule_injective
 schedule_broadcast = schedule_injective
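
For illustration only (not part of the diff): reaching the schedule above through the generic entry point, and which shapes take the vectorized path. The tensors, shapes, and the llvm target are assumptions.

import tvm
import topi

# two packed NCHW16c branches of an Inception-style block (made-up shapes)
a = tvm.placeholder((1, 4, 35, 35, 16), name="a")
b = tvm.placeholder((1, 8, 35, 35, 16), name="b")
out = topi.concatenate([a, b], axis=1)

with tvm.target.create("llvm"):
    # dispatches to the "cpu"-registered schedule above: 5 output axes, so
    # N/C/H are fused and parallelized and the 16-wide inner axis (<= 64)
    # is vectorized; a plain 4-D NCHW concat would skip vectorization
    s = topi.generic.schedule_concatenate([out])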
@@ -4,19 +4,47 @@ import tvm
 from .. import generic
 from .. import tag

-def _parallel_sch(sch):
+def _parallel_sch(sch, oshape, do_vectorize=False):
+    def vectorize(fused_axis, num_parallel_axis, vectorize_limit=64):
+        """Internal vectorization utility function."""
+        reorder_axis = [fused_axis]
+        for i in range(num_parallel_axis, len(sch.op.axis) - 1):
+            reorder_axis.append(sch.op.axis[i])
+        kw, kh = sch.op.reduce_axis
+        fuse_k = sch.fuse(kw, kh)
+        c = sch.op.axis[len(sch.op.axis) - 1]
+        reorder_axis += [fuse_k, c]
+        sch.reorder(*reorder_axis)
+        inner_length = oshape[len(oshape) - 1].value
+        if inner_length <= vectorize_limit:
+            sch.vectorize(c)
+        else:
+            split_factor = 1
+            for i in range(vectorize_limit, 1, -1):
+                if inner_length % i == 0:
+                    split_factor = i
+                    break
+            if split_factor > 1:
+                _, c_i = sch.split(c, split_factor)
+                sch.vectorize(c_i)
+
     if len(sch.op.axis) >= 5:
         fused = sch.fuse(sch.op.axis[0], sch.op.axis[1], sch.op.axis[2])
-        sch.parallel(fused)
+        if do_vectorize:
+            vectorize(fused, 3)
     elif len(sch.op.axis) >= 3:
         fused = sch.fuse(sch.op.axis[0], sch.op.axis[1])
-        sch.parallel(fused)
+        if do_vectorize:
+            vectorize(fused, 2)
     else:
         sch.parallel(sch.op.axis[0])
+        return
+    sch.parallel(fused)

 @generic.schedule_pool.register(["cpu"])
-def schedule_pool(outs):
+def schedule_pool(outs, layout):
     """Schedule for pool

     Parameters
@@ -25,6 +53,9 @@ def schedule_pool(outs):
         The computation graph description of pool
         in the format of an array of tensors.

+    layout: str
+        Data layout.
+
     Returns
     -------
     sch: Schedule
@@ -37,7 +68,8 @@ def schedule_pool(outs):
     def _schedule(PaddedInput, Pool):
         if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
             s[PaddedInput].compute_inline()
-        _parallel_sch(s[Pool])
+        do_vectorize = layout[-1] not in "HWhw"
+        _parallel_sch(s[Pool], outs[0].shape, do_vectorize)

     def traverse(OP):
         """Internal traverse function"""
@@ -93,7 +125,7 @@ def schedule_global_pool(outs):
         # schedule pool
         elif OP.tag.startswith('global_pool'):
             Pool = OP.output(0)
-            _parallel_sch(s[Pool])
+            _parallel_sch(s[Pool], outs[0].shape)
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
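
A small illustration (not part of the commit) of what the do_vectorize condition above evaluates to for a few common layout strings; the layouts listed are assumptions:

for layout in ["NCHW", "NHWC", "NCHW16c"]:
    do_vectorize = layout[-1] not in "HWhw"   # the condition used above
    print("%-8s -> %s" % (layout, "vectorize inner axis" if do_vectorize else "parallel only"))
# NCHW     -> parallel only          (innermost axis is W)
# NHWC     -> vectorize inner axis   (innermost axis is the channel dim)
# NCHW16c  -> vectorize inner axis   (innermost axis is the 16-wide channel block)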
......
@@ -10,9 +10,11 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
     kw = kh
     sw = sh
     pt, pl, pb, pr = padding
+    layout = "NCHW"
     A = tvm.placeholder((n, ic, ih, iw), name='A')
     B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
-                     pool_type=pool_type, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
+                     pool_type=pool_type, ceil_mode=ceil_mode,
+                     layout="NCHW", count_include_pad=count_include_pad)
     B = topi.nn.relu(B)
     dtype = A.dtype
@@ -54,7 +56,7 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_pool(B)
+            s = topi.generic.schedule_pool(B, layout)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
......