Commit c6a1241e by ziheng, committed by Tianqi Chen

[TOPI] Add out_dtype argument for conv2d; Add x86 schedules (#646)

* [TOPI] Add out_dtype argument for conv2d; Add x86 schedules

* Fix

* Fix lint

* Fix
parent d7354628
@@ -14,6 +14,7 @@
 from .reduction import *
 from .transform import *
 from .broadcast import *
 from . import nn
+from . import x86
 from . import cuda
 from . import rasp
 from . import testing
...
@@ -9,7 +9,7 @@
 from ..util import simplify

 # workload description of conv2d
 Workload = namedtuple('Workload',
-                      ['height', 'width', 'in_filter', 'out_filter',
+                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter',
                        'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

 # schedule description of spatial
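Note: workloads serve as lookup keys for schedules, so recording the dtypes lets two precision variants of the same shape resolve to different schedules. A minimal sketch of that effect (plain Python, not part of this commit; values illustrative):

from collections import namedtuple

# Same field layout as the Workload namedtuple above.
Workload = namedtuple('Workload',
                      ['in_dtype', 'out_dtype', 'height', 'width',
                       'in_filter', 'out_filter', 'hkernel', 'wkernel',
                       'hpad', 'wpad', 'hstride', 'wstride'])

w_fp32 = Workload('float32', 'float32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)
w_int8 = Workload('int8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)

# namedtuples hash by value, so identical shapes with different dtypes
# can map to different entries in a dict like _CONV_SCHEDULE.
assert w_fp32 != w_int8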
@@ -22,36 +22,36 @@
 Im2ColPack = namedtuple('Im2ColPack',

 _WORKLOADS = [
     # workloads of resnet18 on imagenet
-    Workload(224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
-    Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
-    Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
-    Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
-    Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
-    Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
-    Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
-    Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
-    Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
-    Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
-    Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
-    Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+    Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
+    Workload('float32', 'float32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+    Workload('float32', 'float32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+    Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+    Workload('float32', 'float32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+    Workload('float32', 'float32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+    Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+    Workload('float32', 'float32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+    Workload('float32', 'float32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+    Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+    Workload('float32', 'float32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
     # workloads of mobile net on imagenet
-    Workload(224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
-    Workload(112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
-    Workload(56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
-    Workload(56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
-    Workload(28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
-    Workload(28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
-    Workload(14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
-    Workload(14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
-    Workload(7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
-    Workload(7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
+    Workload('float32', 'float32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
 ]

 # platform specific schedule
 _CONV_SCHEDULE = {}
 @tvm.target.generic_func
-def conv2d(data, kernel, stride, padding, layout='NCHW'):
+def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
     """Conv2D operator.

     Parameters
@@ -79,14 +79,14 @@ def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
     # search platform specific declaration first
     # default declaration
     if layout == 'NCHW':
-        return conv2d_nchw(data, kernel, stride, padding)
+        return conv2d_nchw(data, kernel, stride, padding, out_dtype)
     elif layout == 'HWCN':
-        return conv2d_hwcn(data, kernel, stride, padding)
+        return conv2d_hwcn(data, kernel, stride, padding, out_dtype)
     else:
         raise ValueError("not support this layout {} yet".format(layout))
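Note: for orientation, a hedged sketch of how a caller exercises the new argument (shapes and dtypes are illustrative, not from this commit):

import tvm
import topi

# int8 inputs accumulated into int32: the dispatcher forwards out_dtype
# to conv2d_nchw, which casts both operands before tvm.sum.
data = tvm.placeholder((1, 3, 224, 224), name='data', dtype='int8')
kernel = tvm.placeholder((64, 3, 7, 7), name='kernel', dtype='int8')
out = topi.nn.conv2d(data, kernel, stride=2, padding=3,
                     layout='NCHW', out_dtype='int32')
assert out.dtype == 'int32'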
-def _get_workload(data, kernel, stride, padding):
+def _get_workload(data, kernel, stride, padding, out_dtype):
     """ Get the workload structure. """
     _, CI, IH, IW = [x.value for x in data.shape]
     CO, _, KH, KW = [x.value for x in kernel.shape]
@@ -95,7 +95,8 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
         HSTR, WSTR = stride
     else:
         HSTR, WSTR = stride, stride
-    return Workload(IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
+    assert data.dtype == kernel.dtype, "Do not support inputs with different data types now."
+    return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
 @tvm.target.generic_func
@@ -108,10 +109,10 @@ def _get_schedule(wkl):
     # This return has no use, merely to supress pylint warning
     return wkl
-def _spatial_pack(data, kernel, stride, padding):
+def _spatial_pack(data, kernel, stride, padding, out_dtype):
     """ Compute convolution with pack on spatial axes. """
     assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
-    wkl = _get_workload(data, kernel, stride, padding)
+    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
     sch = _get_schedule(wkl)

     H, W = wkl.height, wkl.width
@@ -158,8 +159,8 @@ def _spatial_pack(data, kernel, stride, padding, out_dtype):
     dw = tvm.reduce_axis((0, KW), name='dw')

     conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
-        tvm.sum(data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw] *
-                kernel_vec[co, ci, dh, dw, vc],
+        tvm.sum(data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw].astype(out_dtype) *
+                kernel_vec[co, ci, dh, dw, vc].astype(out_dtype),
                 axis=[ci, dh, dw]), name='conv')

     output = tvm.compute(oshape, lambda n, co, h, w:
@@ -169,10 +170,10 @@ def _spatial_pack(data, kernel, stride, padding, out_dtype):
     return output
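Note: the `.astype(out_dtype)` on both operands is what widens the accumulator. A self-contained sketch of the same pattern on a 1-D dot product (illustrative, not from this commit):

import tvm

n = 16
out_dtype = 'int32'
a = tvm.placeholder((n,), name='a', dtype='int8')
b = tvm.placeholder((n,), name='b', dtype='int8')
k = tvm.reduce_axis((0, n), name='k')

# Casting inside the reduction makes tvm.sum accumulate in int32 rather
# than overflow-prone int8; the result tensor inherits out_dtype.
dot = tvm.compute((1,),
                  lambda i: tvm.sum(a[k].astype(out_dtype) *
                                    b[k].astype(out_dtype), axis=k),
                  name='dot')
assert dot.dtype == out_dtype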
-def _im2col_pack(data, kernel, stride, padding):
+def _im2col_pack(data, kernel, stride, padding, out_dtype):
     """ Compute convolution with im2col pack layout. """
     assert data.shape[0].value == 1, "im2col pack convolution only support batch size=1"
-    wkl = _get_workload(data, kernel, stride, padding)
+    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
     sch = _get_schedule(wkl)

     N = 1
@@ -234,7 +235,7 @@ def _im2col_pack(data, kernel, stride, padding, out_dtype):
     return output
-def conv2d_nchw(Input, Filter, stride, padding):
+def conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
     """Convolution operator in NCHW layout.

     Parameters
@@ -280,11 +281,12 @@ def conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
     return tvm.compute(
         (batch, out_channel, out_height, out_width),
         lambda nn, ff, yy, xx: tvm.sum(
-            temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx] * Filter[ff, rc, ry, rx],
+            temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) *
+            Filter[ff, rc, ry, rx].astype(out_dtype),
             axis=[rc, ry, rx]), tag="conv2d_nchw")
-def conv2d_hwcn(Input, Filter, stride, padding):
+def conv2d_hwcn(Input, Filter, stride, padding, out_dtype='float32'):
     """Convolution operator in HWCN layout.

     Parameters
@@ -329,8 +331,8 @@ def conv2d_hwcn(Input, Filter, stride, padding, out_dtype='float32'):
     Output = tvm.compute(
         (out_height, out_width, out_channel, batch),
         lambda yy, xx, ff, nn: tvm.sum(
-            PaddedInput[yy * stride_h + ry, xx * stride_w + rx, rc, nn] * Filter[ry, rx, rc, ff],
-            axis=[ry, rx, rc]),
+            PaddedInput[yy * stride_h + ry, xx * stride_w + rx, rc, nn].astype(out_dtype) *
+            Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
         name="Conv2dOutput", tag="conv2d_hwcn")
     return Output
...
@@ -9,7 +9,7 @@
 from .util import get_pad_tuple
 from ..util import simplify

-def depthwise_conv2d_nchw(Input, Filter, stride, padding):
+def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
     """Depthwise convolution nchw forward operator.

     Parameters
@@ -51,8 +51,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
     Output = tvm.compute(
         (batch, out_channel, out_height, out_width),
         lambda b, c, i, j: tvm.sum(
-            (PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] *
-             Filter[c/channel_multiplier, c%channel_multiplier, di, dj]),
+            (PaddedInput[b, c/channel_multiplier, i*stride_h+di, j*stride_w+dj].astype(out_dtype) *
+             Filter[c/channel_multiplier, c%channel_multiplier, di, dj].astype(out_dtype)),
             axis=[di, dj]),
         name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
     return Output
...
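Note: the depthwise variant follows the same cast-before-accumulate pattern. A hedged usage sketch (shapes, dtypes, and scalar stride/padding are illustrative assumptions, not from this commit):

import tvm
import topi

# float16 storage with float32 accumulation for a MobileNet-style
# depthwise layer (channel_multiplier = 1).
data = tvm.placeholder((1, 32, 112, 112), name='data', dtype='float16')
kernel = tvm.placeholder((32, 1, 3, 3), name='kernel', dtype='float16')
out = topi.nn.depthwise_conv2d_nchw(data, kernel, stride=1, padding=1,
                                    out_dtype='float32')
assert out.dtype == 'float32'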
@@ -12,6 +12,7 @@
 from ..nn.util import infer_pad, infer_stride
 from .. import generic

 _SCHEDULES = [
+    # float32 imagenet
     SpatialPack(1, 8, 4, 1, 4, True),
     SpatialPack(1, 7, 4, 2, 4, True),
     SpatialPack(1, 4, 8, 4, 1, True),
@@ -25,6 +26,7 @@ _SCHEDULES = [
     Im2ColPack(7, 4, 1, 8, False),
     Im2ColPack(7, 4, 1, 16, False),

+    # float32 mobilenet
     SpatialPack(2, 2, 4, 28, 1, True),
     SpatialPack(1, 4, 8, 14, 1, False),
     SpatialPack(1, 2, 16, 8, 1, True),
@@ -47,12 +49,12 @@ def _schedule_conv2d(wkl):

 @conv2d.register("rasp")
-def _declaration_conv2d(data, kernel, stride, padding, layout):
+def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype):
     assert layout == 'NCHW', "only support NCHW convolution on rasp"
     assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
-    wkl = _get_workload(data, kernel, stride, padding)
+    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
     sch = _get_schedule(wkl)
-    return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding)
+    return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype)
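Note: the `@conv2d.register("rasp")` override works through tvm.target.generic_func: the default body runs unless a registration matches the current target. A toy sketch of the mechanism (illustrative names, not from this commit):

import tvm

@tvm.target.generic_func
def compute(x):
    # fallback used when no target-specific override matches
    return 'default'

@compute.register('cpu')
def _compute_cpu(x):
    return 'cpu override'

print(compute(0))                # 'default' -- no target in scope
with tvm.target.create('llvm'):  # llvm targets carry the 'cpu' key
    print(compute(0))            # 'cpu override'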
 def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
@@ -64,10 +66,8 @@ def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
         stride = infer_stride(data, kernel, output)
     else:
         stride = infer_stride(data_pad, kernel, output)
-    wkl = _get_workload(data, kernel, stride, padding)
-    with tvm.target.rasp():
-        sch = _get_schedule(wkl)
+    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
+    sch = _get_schedule(wkl)

     H, W = wkl.height, wkl.width
     CI, CO = wkl.in_filter, wkl.out_filter
@@ -172,7 +172,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
         stride = infer_stride(data, kernel, output)
     else:
         stride = infer_stride(data_pad, kernel, output)
-    wkl = _get_workload(data, kernel, stride, padding)
+    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
     with _target.rasp():
         sch = _get_schedule(wkl)
@@ -280,7 +280,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
     return s
-@generic.schedule_conv2d_nchw.register(["cpu", "rasp"])
+@generic.schedule_conv2d_nchw.register(["rasp"])
 def schedule_conv2d(outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
@@ -294,6 +294,7 @@ def schedule_conv2d(outs):
             for tensor in op.input_tensors:
                 if tensor.op.input_tensors:
                     traverse(tensor.op)
+        if 'spatial_conv_output' in op.tag:
             output = op.output(0)
             conv_out = op.input_tensors[0]
...
@@ -8,22 +8,22 @@
 from ..nn.util import infer_pad, infer_stride, get_pad_tuple
 from .. import generic

 _Workload = namedtuple('Workload',
-                       ['height', 'width', 'channel', 'multiplier',
+                       ['in_dtype', 'out_dtype', 'height', 'width', 'channel', 'multiplier',
                         'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

 _Schedule = namedtuple('Schedule', ['vh', 'vw', 'vc', 'bc', 'unroll'])

 # workloads of depthwise conv mobile net on imagenet
 _WORKLOADS = [
-    _Workload(112, 112, 32, 1, 3, 3, 1, 1, 1, 1),
-    _Workload(112, 112, 64, 1, 3, 3, 1, 1, 2, 2),
-    _Workload(56, 56, 128, 1, 3, 3, 1, 1, 1, 1),
-    _Workload(56, 56, 128, 1, 3, 3, 1, 1, 2, 2),
-    _Workload(28, 28, 256, 1, 3, 3, 1, 1, 1, 1),
-    _Workload(28, 28, 256, 1, 3, 3, 1, 1, 2, 2),
-    _Workload(14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
-    _Workload(14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
-    _Workload(14, 14, 1024, 1, 3, 3, 1, 1, 1, 1),
+    _Workload('float32', 'float32', 112, 112, 32, 1, 3, 3, 1, 1, 1, 1),
+    _Workload('float32', 'float32', 112, 112, 64, 1, 3, 3, 1, 1, 2, 2),
+    _Workload('float32', 'float32', 56, 56, 128, 1, 3, 3, 1, 1, 1, 1),
+    _Workload('float32', 'float32', 56, 56, 128, 1, 3, 3, 1, 1, 2, 2),
+    _Workload('float32', 'float32', 28, 28, 256, 1, 3, 3, 1, 1, 1, 1),
+    _Workload('float32', 'float32', 28, 28, 256, 1, 3, 3, 1, 1, 2, 2),
+    _Workload('float32', 'float32', 14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
+    _Workload('float32', 'float32', 14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
+    _Workload('float32', 'float32', 7, 7, 1024, 1, 3, 3, 1, 1, 1, 1),
 ]
 _SCHEDULES = [
@@ -35,10 +35,10 @@ _SCHEDULES = [
     _Schedule(1, 1, 4, 2, True),
     _Schedule(1, 1, 8, 8, True),
     _Schedule(1, 1, 4, 1, False),
-    _Schedule(2, 1, 4, 16, False),
+    _Schedule(1, 1, 4, 4, False),
 ]
-def _get_workload(data, kernel, stride, padding):
+def _get_workload(data, kernel, stride, padding, out_dtype):
     _, C, IH, IW = [x.value for x in data.shape]
     _, MT, KH, KW = [x.value for x in kernel.shape]
     HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
@@ -46,7 +46,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
         HSTR, WSTR = stride
     else:
         HSTR, WSTR = stride, stride
-    return _Workload(IH, IW, C, MT, KH, KW, HPAD, WPAD, HSTR, WSTR)
+    return _Workload(data.dtype, out_dtype, IH, IW, C, MT, KH, KW, HPAD, WPAD, HSTR, WSTR)
 def _schedule(s, data, data_pad, kernel, output, last):
@@ -55,7 +55,7 @@ def _schedule(s, data, data_pad, kernel, output, last):
         stride = infer_stride(data, kernel, output)
     else:
         stride = infer_stride(data_pad, kernel, output)
-    wkl = _get_workload(data, kernel, stride, padding)
+    wkl = _get_workload(data, kernel, stride, padding, output.dtype)

     if wkl not in _WORKLOADS:
         return s
...
+# pylint: disable=redefined-builtin, wildcard-import
+"""x86 specific declaration and schedules."""
+from __future__ import absolute_import as _abs
+from .conv2d import schedule_conv2d
+# pylint: disable=invalid-name,unused-variable
+"""Conv2D schedule on x86"""
+import tvm
+from .. import generic
+from .. import tag
+
+@generic.schedule_conv2d_nchw.register(["cpu"])
+def schedule_conv2d(outs):
+    """Create schedule for tensors"""
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def traverse(op):
+        """Traverse operators from computation graph"""
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(op.tag):
+            if op not in s.outputs:
+                s[op].compute_inline()
+            for tensor in op.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+
+        if 'conv2d_nchw' in op.tag:
+            conv = op.output(0)
+            kernel = op.input_tensors[1]
+            data = op.input_tensors[0]
+            data_pad = None
+            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+                data_pad = data
+                data = data_pad.op.input_tensors[0]
+            C = conv
+            n, c, h, w = C.op.axis
+            s[C].parallel(c)
+            s[C].pragma(n, "parallel_launch_point")
+
+    traverse(outs[0].op)
+    return s
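Note: a hedged end-to-end sketch of the new x86 path (shapes illustrative; assumes an llvm-enabled build). Under a cpu target the generic dispatcher picks the schedule registered above, which parallelizes over output channels; rasp keeps its own registration:

import tvm
import topi

data = tvm.placeholder((1, 64, 56, 56), name='data')
kernel = tvm.placeholder((64, 64, 3, 3), name='kernel')

with tvm.target.create('llvm'):
    # default float32 declaration; out_dtype defaults to 'float32'
    conv = topi.nn.conv2d(data, kernel, stride=1, padding=1, layout='NCHW')
    s = topi.generic.schedule_conv2d_nchw([conv])

func = tvm.build(s, [data, kernel, conv], target='llvm')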