Commit c8aa6f99 by masahi Committed by Yizhi Liu

[TOPI, x86] Adapt AVX schedules for SSE target (#1504)

* adapt avx schedule for sse

* fixed output type when in mixed precision mode
parent 325c8274
......@@ -117,11 +117,9 @@ def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
out_dtype = data.dtype if out_dtype is None else out_dtype
target = tvm.target.current_target(allow_none=False)
wkl = _get_workload(data, kernel, stride, padding, out_dtype)
if 'avx' in str(target) and layout == 'NCHW':
if layout == 'NCHW':
sch = _get_schedule(wkl)
return _AVX_SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, layout, out_dtype)
elif layout == 'NCHW':
return nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
elif layout == 'HWCN':
return nn.conv2d_hwcn(data, kernel, stride, padding, out_dtype)
elif layout == 'NHWC':
......@@ -191,77 +189,39 @@ def schedule_conv2d(outs):
s = tvm.create_schedule([x.op for x in outs])
target = tvm.target.current_target(allow_none=False)
def default_schedule(op):
"""NCHW conv2d schedule for non imagenet workloads"""
conv = op.output(0)
kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
n_pad, c_pad, h_pad, w_pad = data_pad.op.axis
pad_fused = s[data_pad].fuse(n_pad, c_pad)
s[data_pad].parallel(pad_fused)
C = conv
n, c, h, w = C.op.axis
rc, ry, rx = C.op.reduce_axis
fused = s[C].fuse(n, c)
s[C].parallel(fused)
wo, wi = s[C].split(w, factor=16)
s[C].reorder(fused, rc, h, wo, ry, rx, wi) # move rc to outer loop
s[C].unroll(rx)
s[C].unroll(ry)
s[C].vectorize(wi)
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
else: # inject custom schedule
if len(op.axis) == 4 and 'avx' not in str(target): # schedule bias + bn + relu
n, c, h, w = op.axis
fused = s[op].fuse(n, c)
s[op].parallel(fused)
s[op].vectorize(w)
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
if 'conv2d_nchw' in op.tag:
if 'avx' in str(target):
try:
output = op.output(0)
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel = kernel_vec.op.input_tensors[0]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
padding = infer_pad(data, data_pad)
if data_pad is None:
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding, output.dtype)
sch = _get_schedule(wkl)
_AVX_SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec,
kernel, kernel_vec, conv_out, output, outs[0])
except IndexError:
default_schedule(op)
output = op.output(0)
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel = kernel_vec.op.input_tensors[0]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
padding = infer_pad(data, data_pad)
if data_pad is None:
stride = infer_stride(data, kernel, output)
else:
default_schedule(op)
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding, output.dtype)
sch = _get_schedule(wkl)
_AVX_SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec,
kernel, kernel_vec, conv_out, output, outs[0])
traverse(outs[0].op)
return s
......
......@@ -85,13 +85,16 @@ def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
kw = tvm.reduce_axis((0, kernel_width), name='kw')
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_vec[n, ic//sch.ic_bn, oh*HSTR+kh, ic%sch.ic_bn, ow*WSTR+kw] *
kernel_vec[oc_chunk, ic//sch.ic_bn, kh, kw, ic%sch.ic_bn, oc_block],
tvm.sum(data_vec[n, ic//sch.ic_bn, oh*HSTR+kh, ic%sch.ic_bn, ow*WSTR+kw]
.astype(out_dtype) *
kernel_vec[oc_chunk, ic//sch.ic_bn, kh, kw, ic%sch.ic_bn, oc_block]
.astype(out_dtype),
axis=[ic, kh, kw]),
name='conv')
unpack = tvm.compute(unpack_shape,
lambda n, c, h, w: conv[n, c // sch.oc_bn, h, w, c % sch.oc_bn],
lambda n, c, h, w: conv[n, c // sch.oc_bn, h, w, c % sch.oc_bn]
.astype(out_dtype),
name='output_unpack',
tag='conv2d_nchw')
return unpack
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment