Commit c6a1241e by ziheng, committed by Tianqi Chen

[TOPI] Add out_dtype argument for conv2d; Add x86 schedules (#646)

* [TOPI] Add out_dtype argument for conv2d; Add x86 schedules

* Fix

* Fix lint

* Fix
parent d7354628
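The new out_dtype argument threads from the generic conv2d entry point down to the compute declarations. A minimal usage sketch (placeholder shapes and argument values are illustrative, not part of this commit):

import tvm
import topi

data = tvm.placeholder((1, 3, 224, 224), name='data', dtype='float32')
kernel = tvm.placeholder((64, 3, 7, 7), name='kernel', dtype='float32')
# out_dtype is the type both operands are cast to before multiply-accumulate,
# so a wider accumulator can be requested without changing the input tensors.
out = topi.nn.conv2d(data, kernel, stride=2, padding=3,
                     layout='NCHW', out_dtype='float32')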
......@@ -14,6 +14,7 @@ from .reduction import *
from .transform import *
from .broadcast import *
from . import nn
from . import x86
from . import cuda
from . import rasp
from . import testing
......
......@@ -9,7 +9,7 @@ from ..util import simplify
# workload description of conv2d
Workload = namedtuple('Workload',
['height', 'width', 'in_filter', 'out_filter',
['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
# schedule description of spatial
......@@ -22,36 +22,36 @@ Im2ColPack = namedtuple('Im2ColPack',
_WORKLOADS = [
# workloads of resnet18 on imagenet
Workload(224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('float32', 'float32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('float32', 'float32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('float32', 'float32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('float32', 'float32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
# workloads of mobile net on imagenet
Workload(224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
Workload(112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
Workload(56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
Workload(56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
Workload(28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
Workload(28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
Workload(14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
Workload(14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
Workload(7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
Workload(7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
]
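For reference, the first ResNet-18 entry above written out with the field names of the extended Workload namedtuple (purely illustrative):

wkl = Workload(in_dtype='float32', out_dtype='float32',
               height=224, width=224, in_filter=3, out_filter=64,
               hkernel=7, wkernel=7, hpad=3, wpad=3, hstride=2, wstride=2)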
# platform specific schedule
_CONV_SCHEDULE = {}
@tvm.target.generic_func
def conv2d(data, kernel, stride, padding, layout='NCHW'):
def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
"""Conv2D operator.
Parameters
......@@ -79,14 +79,14 @@ def conv2d(data, kernel, stride, padding, layout='NCHW'):
# search platform specific declaration first
# default declaration
if layout == 'NCHW':
return conv2d_nchw(data, kernel, stride, padding)
return conv2d_nchw(data, kernel, stride, padding, out_dtype)
elif layout == 'HWCN':
return conv2d_hwcn(data, kernel, stride, padding)
return conv2d_hwcn(data, kernel, stride, padding, out_dtype)
else:
raise ValueError("not support this layout {} yet".format(layout))
def _get_workload(data, kernel, stride, padding):
def _get_workload(data, kernel, stride, padding, out_dtype):
""" Get the workload structure. """
_, CI, IH, IW = [x.value for x in data.shape]
CO, _, KH, KW = [x.value for x in kernel.shape]
......@@ -95,7 +95,8 @@ def _get_workload(data, kernel, stride, padding):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
return Workload(IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
assert data.dtype == kernel.dtype, "Do not support inputs with different data types now."
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
@tvm.target.generic_func
......@@ -108,10 +109,10 @@ def _get_schedule(wkl):
# This return has no use, merely to suppress pylint warning
return wkl
def _spatial_pack(data, kernel, stride, padding):
def _spatial_pack(data, kernel, stride, padding, out_dtype):
""" Compute convolution with pack on spatial axes. """
assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
wkl = _get_workload(data, kernel, stride, padding)
wkl = _get_workload(data, kernel, stride, padding, out_dtype)
sch = _get_schedule(wkl)
H, W = wkl.height, wkl.width
......@@ -158,8 +159,8 @@ def _spatial_pack(data, kernel, stride, padding):
dw = tvm.reduce_axis((0, KW), name='dw')
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw] *
kernel_vec[co, ci, dh, dw, vc],
tvm.sum(data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw].astype(out_dtype) *
kernel_vec[co, ci, dh, dw, vc].astype(out_dtype),
axis=[ci, dh, dw]), name='conv')
output = tvm.compute(oshape, lambda n, co, h, w:
......@@ -169,10 +170,10 @@ def _spatial_pack(data, kernel, stride, padding):
return output
def _im2col_pack(data, kernel, stride, padding):
def _im2col_pack(data, kernel, stride, padding, out_dtype):
""" Compute convolution with im2col pack layout. """
assert data.shape[0].value == 1, "im2col pack convolution only support batch size=1"
wkl = _get_workload(data, kernel, stride, padding)
wkl = _get_workload(data, kernel, stride, padding, out_dtype)
sch = _get_schedule(wkl)
N = 1
......@@ -234,7 +235,7 @@ def _im2col_pack(data, kernel, stride, padding):
return output
def conv2d_nchw(Input, Filter, stride, padding):
def conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
"""Convolution operator in NCHW layout.
Parameters
......@@ -280,11 +281,12 @@ def conv2d_nchw(Input, Filter, stride, padding):
return tvm.compute(
(batch, out_channel, out_height, out_width),
lambda nn, ff, yy, xx: tvm.sum(
temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx] * Filter[ff, rc, ry, rx],
temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) *
Filter[ff, rc, ry, rx].astype(out_dtype),
axis=[rc, ry, rx]), tag="conv2d_nchw")
def conv2d_hwcn(Input, Filter, stride, padding):
def conv2d_hwcn(Input, Filter, stride, padding, out_dtype='float32'):
"""Convolution operator in HWCN layout.
Parameters
......@@ -329,8 +331,8 @@ def conv2d_hwcn(Input, Filter, stride, padding):
Output = tvm.compute(
(out_height, out_width, out_channel, batch),
lambda yy, xx, ff, nn: tvm.sum(
PaddedInput[yy * stride_h + ry, xx * stride_w + rx, rc, nn] * Filter[ry, rx, rc, ff],
axis=[ry, rx, rc]),
PaddedInput[yy * stride_h + ry, xx * stride_w + rx, rc, nn].astype(out_dtype) *
Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
name="Conv2dOutput", tag="conv2d_hwcn")
return Output
......
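The astype(out_dtype) casts added above place the cast before the multiply, so the products and the reduction accumulate in out_dtype rather than in the input type. A minimal sketch of the same pattern with hypothetical int8 inputs widened to int32 (not part of this diff):

import tvm

A = tvm.placeholder((16,), name='A', dtype='int8')
B = tvm.placeholder((16,), name='B', dtype='int8')
k = tvm.reduce_axis((0, 16), name='k')
# Casting each operand before the multiply keeps the sum from overflowing int8.
C = tvm.compute((1,), lambda i: tvm.sum(A[k].astype('int32') * B[k].astype('int32'),
                                        axis=k), name='C')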
......@@ -9,7 +9,7 @@ from .util import get_pad_tuple
from ..util import simplify
def depthwise_conv2d_nchw(Input, Filter, stride, padding):
def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
"""Depthwise convolution nchw forward operator.
Parameters
......@@ -51,8 +51,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding):
Output = tvm.compute(
(batch, out_channel, out_height, out_width),
lambda b, c, i, j: tvm.sum(
(PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] *
Filter[c/channel_multiplier, c%channel_multiplier, di, dj]),
(PaddedInput[b, c/channel_multiplier, i*stride_h+di, j*stride_w+dj].astype(out_dtype) *
Filter[c/channel_multiplier, c%channel_multiplier, di, dj].astype(out_dtype)),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
return Output
......
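The depthwise operator gains the same optional argument. A hedged call sketch (shapes and the scalar stride/padding forms are assumptions for illustration):

import tvm
import topi

data = tvm.placeholder((1, 32, 112, 112), name='data', dtype='float32')
kernel = tvm.placeholder((32, 1, 3, 3), name='kernel', dtype='float32')
out = topi.nn.depthwise_conv2d_nchw(data, kernel, stride=1, padding=1,
                                    out_dtype='float32')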
......@@ -12,6 +12,7 @@ from ..nn.util import infer_pad, infer_stride
from .. import generic
_SCHEDULES = [
# float32 imagenet
SpatialPack(1, 8, 4, 1, 4, True),
SpatialPack(1, 7, 4, 2, 4, True),
SpatialPack(1, 4, 8, 4, 1, True),
......@@ -25,6 +26,7 @@ _SCHEDULES = [
Im2ColPack(7, 4, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),
# float32 mobilenet
SpatialPack(2, 2, 4, 28, 1, True),
SpatialPack(1, 4, 8, 14, 1, False),
SpatialPack(1, 2, 16, 8, 1, True),
......@@ -47,12 +49,12 @@ def _schedule_conv2d(wkl):
@conv2d.register("rasp")
def _declaration_conv2d(data, kernel, stride, padding, layout):
def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype):
assert layout == 'NCHW', "only support NCHW convolution on rasp"
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
wkl = _get_workload(data, kernel, stride, padding)
wkl = _get_workload(data, kernel, stride, padding, out_dtype)
sch = _get_schedule(wkl)
return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding)
return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype)
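Because conv2d is a tvm.target.generic_func, the register("rasp") decorator above makes this declaration the one selected whenever the current target is rasp; the extra out_dtype parameter is simply forwarded to the chosen template. A hedged dispatch sketch (shapes illustrative):

import tvm
import topi

data = tvm.placeholder((1, 64, 56, 56), name='data', dtype='float32')
kernel = tvm.placeholder((64, 64, 3, 3), name='kernel', dtype='float32')
with tvm.target.rasp():
    # Inside the rasp target scope, topi.nn.conv2d dispatches to _declaration_conv2d.
    out = topi.nn.conv2d(data, kernel, stride=1, padding=1,
                         layout='NCHW', out_dtype='float32')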
def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
......@@ -64,10 +66,8 @@ def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding)
with tvm.target.rasp():
sch = _get_schedule(wkl)
wkl = _get_workload(data, kernel, stride, padding, output.dtype)
sch = _get_schedule(wkl)
H, W = wkl.height, wkl.width
CI, CO = wkl.in_filter, wkl.out_filter
......@@ -172,7 +172,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding)
wkl = _get_workload(data, kernel, stride, padding, output.dtype)
with _target.rasp():
sch = _get_schedule(wkl)
......@@ -280,7 +280,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
return s
@generic.schedule_conv2d_nchw.register(["cpu", "rasp"])
@generic.schedule_conv2d_nchw.register(["rasp"])
def schedule_conv2d(outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
......@@ -294,6 +294,7 @@ def schedule_conv2d(outs):
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
if 'spatial_conv_output' in op.tag:
output = op.output(0)
conv_out = op.input_tensors[0]
......
......@@ -8,22 +8,22 @@ from ..nn.util import infer_pad, infer_stride, get_pad_tuple
from .. import generic
_Workload = namedtuple('Workload',
['height', 'width', 'channel', 'multiplier',
['in_dtype', 'out_dtype', 'height', 'width', 'channel', 'multiplier',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
_Schedule = namedtuple('Schedule', ['vh', 'vw', 'vc', 'bc', 'unroll'])
# workloads of depthwise conv mobile net on imagenet
_WORKLOADS = [
_Workload(112, 112, 32, 1, 3, 3, 1, 1, 1, 1),
_Workload(112, 112, 64, 1, 3, 3, 1, 1, 2, 2),
_Workload(56, 56, 128, 1, 3, 3, 1, 1, 1, 1),
_Workload(56, 56, 128, 1, 3, 3, 1, 1, 2, 2),
_Workload(28, 28, 256, 1, 3, 3, 1, 1, 1, 1),
_Workload(28, 28, 256, 1, 3, 3, 1, 1, 2, 2),
_Workload(14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
_Workload(14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
_Workload(14, 14, 1024, 1, 3, 3, 1, 1, 1, 1),
_Workload('float32', 'float32', 112, 112, 32, 1, 3, 3, 1, 1, 1, 1),
_Workload('float32', 'float32', 112, 112, 64, 1, 3, 3, 1, 1, 2, 2),
_Workload('float32', 'float32', 56, 56, 128, 1, 3, 3, 1, 1, 1, 1),
_Workload('float32', 'float32', 56, 56, 128, 1, 3, 3, 1, 1, 2, 2),
_Workload('float32', 'float32', 28, 28, 256, 1, 3, 3, 1, 1, 1, 1),
_Workload('float32', 'float32', 28, 28, 256, 1, 3, 3, 1, 1, 2, 2),
_Workload('float32', 'float32', 14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
_Workload('float32', 'float32', 14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
_Workload('float32', 'float32', 7, 7, 1024, 1, 3, 3, 1, 1, 1, 1),
]
_SCHEDULES = [
......@@ -35,10 +35,10 @@ _SCHEDULES = [
_Schedule(1, 1, 4, 2, True),
_Schedule(1, 1, 8, 8, True),
_Schedule(1, 1, 4, 1, False),
_Schedule(2, 1, 4, 16, False),
_Schedule(1, 1, 4, 4, False),
]
def _get_workload(data, kernel, stride, padding):
def _get_workload(data, kernel, stride, padding, out_dtype):
_, C, IH, IW = [x.value for x in data.shape]
_, MT, KH, KW = [x.value for x in kernel.shape]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
......@@ -46,7 +46,7 @@ def _get_workload(data, kernel, stride, padding):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
return _Workload(IH, IW, C, MT, KH, KW, HPAD, WPAD, HSTR, WSTR)
return _Workload(data.dtype, out_dtype, IH, IW, C, MT, KH, KW, HPAD, WPAD, HSTR, WSTR)
def _schedule(s, data, data_pad, kernel, output, last):
......@@ -55,7 +55,7 @@ def _schedule(s, data, data_pad, kernel, output, last):
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding)
wkl = _get_workload(data, kernel, stride, padding, output.dtype)
if wkl not in _WORKLOADS:
return s
......
# pylint: disable=redefined-builtin, wildcard-import
"""x86 specific declaration and schedules."""
from __future__ import absolute_import as _abs
from .conv2d import schedule_conv2d
# pylint: disable=invalid-name,unused-variable
"""Conv2D schedule on x86"""
import tvm
from .. import generic
from .. import tag
@generic.schedule_conv2d_nchw.register(["cpu"])
def schedule_conv2d(outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
if 'conv2d_nchw' in op.tag:
conv = op.output(0)
kernel = op.input_tensors[1]
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
C = conv
n, c, h, w = C.op.axis
s[C].parallel(c)
s[C].pragma(n, "parallel_launch_point")
traverse(outs[0].op)
return s
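A hedged end-to-end sketch of how this new x86 schedule would typically be reached through the generic dispatcher (target string and shapes are illustrative, not part of this commit):

import tvm
import topi

data = tvm.placeholder((1, 3, 224, 224), name='data', dtype='float32')
kernel = tvm.placeholder((64, 3, 7, 7), name='kernel', dtype='float32')
conv = topi.nn.conv2d(data, kernel, stride=2, padding=3,
                      layout='NCHW', out_dtype='float32')
with tvm.target.create('llvm'):
    # The "cpu" registration above resolves schedule_conv2d_nchw to this x86 schedule.
    s = topi.generic.schedule_conv2d_nchw([conv])
func = tvm.build(s, [data, kernel, conv], target='llvm')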