Commit 97e333ca by Balint Cristian, committed by Lianmin Zheng

Add Winograd matrices computation. (#3553)

parent ef909df1
@@ -20,20 +20,19 @@ from __future__ import absolute_import as _abs
 import warnings
-import numpy as np
 import tvm
 from tvm import autotvm
 import tvm.contrib.nnpack
 from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
     schedule_conv2d_winograd_nnpack_without_weight_transform
-from ..util import traverse_inline, get_const_tuple, const_matrix
+from ..util import traverse_inline, get_const_tuple
 from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
     conv2d_winograd_without_weight_transform, \
     conv2d_winograd_nnpack_without_weight_transform, \
     depthwise_conv2d_nchw
 from ..nn.util import get_const_int, get_pad_tuple
+from ..nn.winograd_util import winograd_transform_matrices

 @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
 def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
@@ -330,57 +329,14 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
     HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
     assert layout == 'NCHW'
-    assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
+    assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
     data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
-    if tile_size == 4:
-        G_data = np.array([
-            [1 / 4.0, 0, 0],
-            [-1 / 6.0, -1 / 6.0, -1 / 6.0],
-            [-1 / 6.0, 1 / 6.0, -1 / 6.0],
-            [1 / 24.0, 1 / 12.0, 1 / 6.0],
-            [1 / 24.0, -1 / 12.0, 1 / 6.0],
-            [0, 0, 1]], dtype=np.float32)
-        B_data = np.array([
-            [4, 0, 0, 0, 0, 0],
-            [0, -4, 4, -2, 2, 4],
-            [-5, -4, -4, -1, -1, 0],
-            [0, 1, -1, 2, -2, -5],
-            [1, 1, 1, 1, 1, 0],
-            [0, 0, 0, 0, 0, 1]], out_dtype)
-        A_data = np.array([
-            [1, 0, 0, 0],
-            [1, 1, 1, 1],
-            [1, -1, 1, -1],
-            [1, 2, 4, 8],
-            [1, -2, 4, -8],
-            [0, 0, 0, 1]], out_dtype)
-    elif tile_size == 2:
-        G_data = np.array([
-            [1, 0, 0],
-            [1.0/2, 1.0/2, 1.0/2],
-            [1.0/2, -1.0/2, 1.0/2],
-            [0, 0, 1]], np.float32)
-        B_data = np.array([
-            [1, 0, 0, 0],
-            [0, 1, -1, 1],
-            [-1, 1, 1, 0],
-            [0, 0, 0, -1]], out_dtype)
-        A_data = np.array([
-            [1, 0],
-            [1, 1],
-            [1, -1],
-            [0, -1]], out_dtype)
-    else:
-        raise ValueError("Unsupported tile size for winograd: " + str(tile_size))
-    m = A_data.shape[1]
-    r = 3
+    r = KW
+    m = tile_size
     alpha = m + r - 1
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
     K = CO
     C = CI
@@ -405,7 +361,6 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
     if pre_computed:
         U = kernel
     else:
-        G = const_matrix(G_data, 'G')
         r_kh = tvm.reduce_axis((0, KH), 'r_kh')
         r_kw = tvm.reduce_axis((0, KW), 'r_kw')
        U = tvm.compute((alpha, alpha, K // VK, C, VK), lambda eps, nu, k, c, kk:
@@ -413,7 +368,6 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
                                 G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U')
     # transform image
-    B = const_matrix(B_data, 'B')
     r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
     r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
     V = tvm.compute((alpha, alpha, P // VP, C, VP), lambda eps, nu, b, c, bb:
@@ -427,7 +381,6 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
                             V[eps][nu][b // VP][c][b % VP], axis=c), name='M')
     # inverse transform
-    A = const_matrix(A_data, 'A')
     r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
     r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
     Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw:
......
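The constants deleted above are the classic F(2,3)/F(4,3) tables from Lavin & Gray [arXiv:1509.09308]; they implement Y = Aᵀ[(G g Gᵀ) ⊙ (Bᵀ d B)] A. A minimal NumPy sketch (1-D case, independent of TVM, not part of the diff) checking that identity with the removed F(2,3) matrices:

```python
import numpy as np

# The F(2,3) constants removed above, in TVM's orientation.
G = np.array([[1.0, 0, 0],
              [0.5, 0.5, 0.5],
              [0.5, -0.5, 0.5],
              [0, 0, 1]])
B = np.array([[1, 0, 0, 0],
              [0, 1, -1, 1],
              [-1, 1, 1, 0],
              [0, 0, 0, -1.0]])
A = np.array([[1, 0],
              [1, 1],
              [1, -1],
              [0, -1.0]])

g = np.random.rand(3)   # 1-D filter, r = 3 taps
d = np.random.rand(4)   # 1-D input tile, alpha = 2 + 3 - 1 = 4

# m = 2 outputs from alpha = 4 inputs using only alpha multiplications.
y_wino = A.T.dot(G.dot(g) * B.T.dot(d))
# Direct 1-D correlation producing the same two outputs.
y_direct = np.array([d[0:3].dot(g), d[1:4].dot(g)])
assert np.allclose(y_wino, y_direct)
```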
@@ -17,15 +17,14 @@
 # pylint: disable=invalid-name,unused-variable,unused-argument
 """Winograd template for cuda backend"""
-import numpy as np
 import tvm
 from tvm import autotvm
 from .. import nn
 from ..nn import conv2d, group_conv2d_nchw, conv2d_winograd_without_weight_transform
-from ..util import get_const_int, get_const_tuple, const_matrix, traverse_inline
+from ..util import get_const_int, get_const_tuple, traverse_inline
 from ..generic import schedule_conv2d_winograd_without_weight_transform
+from ..nn.winograd_util import winograd_transform_matrices

 def _infer_tile_size(data, kernel):
@@ -54,7 +53,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
         CO, CI, KH, KW = get_const_tuple(kernel.shape)
         HPAD, WPAD, _, _ = nn.get_pad_tuple(padding, kernel)
         HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
-        assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 and KH == 3 and KW == 3
+        assert HSTR == 1 and WSTR == 1 and KH == KW
     else:  # kernel tensor is pre-transformed. this op is created by
            # alter op layout, do not check
         # dilation is not supported
@@ -65,54 +64,11 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
     data_pad = nn.pad(data, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad")
-    if tile_size == 4:
-        G_data = np.array([
-            [1 / 4.0, 0, 0],
-            [-1 / 6.0, -1 / 6.0, -1 / 6.0],
-            [-1 / 6.0, 1 / 6.0, -1 / 6.0],
-            [1 / 24.0, 1 / 12.0, 1 / 6.0],
-            [1 / 24.0, -1 / 12.0, 1 / 6.0],
-            [0, 0, 1]], dtype=np.float32)
-        B_data = np.array([
-            [4, 0, 0, 0, 0, 0],
-            [0, -4, 4, -2, 2, 4],
-            [-5, -4, -4, -1, -1, 0],
-            [0, 1, -1, 2, -2, -5],
-            [1, 1, 1, 1, 1, 0],
-            [0, 0, 0, 0, 0, 1]], out_dtype)
-        A_data = np.array([
-            [1, 0, 0, 0],
-            [1, 1, 1, 1],
-            [1, -1, 1, -1],
-            [1, 2, 4, 8],
-            [1, -2, 4, -8],
-            [0, 0, 0, 1]], out_dtype)
-    elif tile_size == 2:
-        G_data = np.array([
-            [1, 0, 0],
-            [1.0/2, 1.0/2, 1.0/2],
-            [1.0/2, -1.0/2, 1.0/2],
-            [0, 0, 1]], np.float32)
-        B_data = np.array([
-            [1, 0, 0, 0],
-            [0, 1, -1, 1],
-            [-1, 1, 1, 0],
-            [0, 0, 0, -1]], out_dtype)
-        A_data = np.array([
-            [1, 0],
-            [1, 1],
-            [1, -1],
-            [0, -1]], out_dtype)
-    else:
-        raise ValueError("Unsupported tile size for winograd: " + str(tile_size))
-    m = A_data.shape[1]
-    r = 3
+    r = KW
+    m = tile_size
     alpha = m + r - 1
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
     H = (H + 2 * HPAD - KH) // HSTR + 1
     W = (W + 2 * WPAD - KW) // WSTR + 1
     nH, nW = (H + m-1) // m, (W + m-1) // m
@@ -120,7 +76,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
     # transform kernel
     if not pre_computed:
-        G = const_matrix(G_data, 'G')
         r_kh = tvm.reduce_axis((0, KH), name='r_kh')
         r_kw = tvm.reduce_axis((0, KW), name='r_kw')
         kernel_pack = tvm.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co:
@@ -136,7 +91,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
                     [p % nW * m + nu], name='d')
     # transform data
-    B = const_matrix(B_data)
     r_a = tvm.reduce_axis((0, alpha), 'r_a')
     r_b = tvm.reduce_axis((0, alpha), 'r_a')
     data_pack = tvm.compute((alpha, alpha, CI, P), lambda eps, nu, ci, p:
@@ -151,7 +105,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
                         axis=[ci]), name='bgemm')
     # inverse transform
-    A = const_matrix(A_data)
     r_a = tvm.reduce_axis((0, alpha), 'r_a')
     r_b = tvm.reduce_axis((0, alpha), 'r_a')
     inverse = tvm.compute((CO, P, m, m), lambda co, p, vh, vw:
......
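To make the tiling arithmetic in the hunk above concrete, a plain-Python sketch of how the tile count P that feeds the batched GEMM (`bgemm`) is derived. The workload values come from the resnet-18 case in the tests below; tile size m = 4 is an assumption for illustration:

```python
# 1x64x56x56 input, 3x3 kernel, stride 1, pad 1, with Winograd F(4, 3).
N, H, W = 1, 56, 56
KH = KW = 3
HPAD = WPAD = HSTR = WSTR = 1
m = 4                                          # assumed tile size
r = KW
alpha = m + r - 1                              # 6x6 transformed tiles
H = (H + 2 * HPAD - KH) // HSTR + 1            # 56 output rows
W = (W + 2 * WPAD - KW) // WSTR + 1            # 56 output cols
nH, nW = (H + m - 1) // m, (W + m - 1) // m    # 14 x 14 tiles per image
P = N * nH * nW                                # 196 tiles in the batched GEMM
print(alpha, nH, nW, P)                        # 6 14 14 196
```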
@@ -16,16 +16,15 @@
 # under the License.
 # pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
 """conv2d schedule on ARM Mali GPU"""
-import numpy as np
 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import get_factors
 from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
-from ..util import traverse_inline, get_const_int, get_const_tuple, const_matrix
+from ..util import traverse_inline, get_const_int, get_const_tuple
 from ..nn import conv2d, conv2d_winograd_without_weight_transform, \
     get_pad_tuple, pad, conv2d_alter_layout
+from ..nn.winograd_util import winograd_transform_matrices

 # reuse some compute declarations from ARM CPU
 from ..arm_cpu.conv2d import _decl_spatial_pack, _alter_conv2d_layout_arm
@@ -226,57 +225,13 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
     HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
     assert layout == 'NCHW'
-    assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
+    assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
     data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
-    if tile_size == 4:
-        G_data = np.array([
-            [1 / 4.0, 0, 0],
-            [-1 / 6.0, -1 / 6.0, -1 / 6.0],
-            [-1 / 6.0, 1 / 6.0, -1 / 6.0],
-            [1 / 24.0, 1 / 12.0, 1 / 6.0],
-            [1 / 24.0, -1 / 12.0, 1 / 6.0],
-            [0, 0, 1]], out_dtype)
-        B_data = np.array([
-            [4, 0, 0, 0, 0, 0],
-            [0, -4, 4, -2, 2, 4],
-            [-5, -4, -4, -1, -1, 0],
-            [0, 1, -1, 2, -2, -5],
-            [1, 1, 1, 1, 1, 0],
-            [0, 0, 0, 0, 0, 1]], out_dtype)
-        A_data = np.array([
-            [1, 0, 0, 0],
-            [1, 1, 1, 1],
-            [1, -1, 1, -1],
-            [1, 2, 4, 8],
-            [1, -2, 4, -8],
-            [0, 0, 0, 1]], out_dtype)
-    elif tile_size == 2:
-        G_data = np.array([
-            [1, 0, 0],
-            [1.0/2, 1.0/2, 1.0/2],
-            [1.0/2, -1.0/2, 1.0/2],
-            [0, 0, 1]], out_dtype)
-        B_data = np.array([
-            [1, 0, 0, 0],
-            [0, 1, -1, 1],
-            [-1, 1, 1, 0],
-            [0, 0, 0, -1]], out_dtype)
-        A_data = np.array([
-            [1, 0],
-            [1, 1],
-            [1, -1],
-            [0, -1]], out_dtype)
-    else:
-        raise ValueError("Unsupported tile size for winograd: " + str(tile_size))
-    m = A_data.shape[1]
-    r = 3
+    r = KW
+    m = tile_size
     alpha = m + r - 1
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
     H = (IH + 2 * HPAD - 3) // HSTR + 1
     W = (IW + 2 * WPAD - 3) // WSTR + 1
@@ -321,7 +276,6 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
     if pre_computed:
         U = kernel
     else:
-        G = const_matrix(G_data, 'G')
         r_kh = tvm.reduce_axis((0, KH), 'r_kh')
         r_kw = tvm.reduce_axis((0, KW), 'r_kw')
         U = tvm.compute((alpha, alpha, CO // bna, CI, bna), lambda eps, nu, co, ci, vco:
@@ -329,7 +283,6 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
                         axis=[r_kh, r_kw]), name='U')
     # transform image
-    B = const_matrix(B_data, 'B')
     r_a = tvm.reduce_axis((0, alpha), 'r_a')
     r_b = tvm.reduce_axis((0, alpha), 'r_b')
     V = tvm.compute((alpha, alpha, P_round // bnb, CI, bnb), lambda eps, nu, p, ci, vp:
@@ -342,7 +295,6 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
         tvm.sum(U[eps][nu][co // bna][ci][co % bna] *
                 V[eps][nu][p // bnb][ci][p % bnb], axis=ci), name='M')
-    A = const_matrix(A_data, 'A')
     r_a = tvm.reduce_axis((0, alpha), 'r_a')
     r_b = tvm.reduce_axis((0, alpha), 'r_b')
     Y = tvm.compute((CO, P, m, m), lambda co, p, vh, vw:
......
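The same replacement lands in the Mali template. As background on why `tile_size` is worth tuning at all: per m x m output tile, direct convolution needs m·m·r·r multiplications versus alpha·alpha for F(m, r). A quick plain-Python sketch of the trade-off (not part of the diff; transform overhead is amortized and ignored here):

```python
# Multiplications per m x m output tile: direct r x r convolution versus
# the alpha x alpha elementwise product of Winograd F(m, r).
for m, r in [(2, 3), (4, 3), (6, 3)]:
    alpha = m + r - 1
    direct, wino = m * m * r * r, alpha * alpha
    print("F(%d,%d): %3d -> %2d mults (%.2fx)" % (m, r, direct, wino,
                                                  float(direct) / wino))
# F(2,3):  36 -> 16 mults (2.25x)
# F(4,3): 144 -> 36 mults (4.00x)
# F(6,3): 324 -> 64 mults (5.06x)
```

Larger tiles save more arithmetic but, per the error table in the new winograd_util.py below, cost numerical accuracy.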
@@ -19,12 +19,12 @@
 """Conv2D operators"""
 from __future__ import absolute_import as _abs
 from collections import namedtuple
-import numpy as np
 import tvm
 from .pad import pad
 from .util import get_pad_tuple
-from ..util import simplify, const_matrix, get_const_tuple
+from ..util import simplify, get_const_tuple
+from .winograd_util import winograd_transform_matrices

 # workload description of conv2d
 Workload = namedtuple('Workload',
@@ -425,7 +425,7 @@ def conv2d_winograd_weight_transform(kernel, tile_size):
     Parameters
     ----------
     kernel: Tensor
-        The raw kernel tensor with layout "NCHW". Only 3x3 kernel is supported for now
+        The raw kernel tensor with layout "NCHW".
     tile_size: int
         Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)
@@ -434,34 +434,15 @@ def conv2d_winograd_weight_transform(kernel, tile_size):
     output : tvm.Tensor
         4-D with shape [alpha, alpha, CO, CI]
     """
-    K = 3
     shape = get_const_tuple(kernel.shape)
-    assert shape[2:] == (K, K), "Only support 3x3 kernel"
+    assert shape[2] == shape[3], "Only support NxN kernel"
+    K = shape[3]
     r = tile_size + K - 1
     shape = (r, r) + shape[:2]
-    if tile_size == 2:
-        G_data = np.array([
-            [1, 0, 0],
-            [1.0/2, 1.0/2, 1.0/2],
-            [1.0/2, -1.0/2, 1.0/2],
-            [0, 0, 1],
-        ], dtype=kernel.dtype)
-    elif tile_size == 4:
-        G_data = np.array([
-            [1 / 4.0, 0, 0],
-            [-1 / 6.0, -1 / 6.0, -1 / 6.0],
-            [-1 / 6.0, 1 / 6.0, -1 / 6.0],
-            [1 / 24.0, 1 / 12.0, 1 / 6.0],
-            [1 / 24.0, -1 / 12.0, 1 / 6.0],
-            [0, 0, 1]
-        ], dtype=kernel.dtype)
-    else:
-        raise ValueError("Unsupoorted tile size:" + tile_size)
-    G = const_matrix(G_data, 'G')
+    _, _, G = winograd_transform_matrices(tile_size, K, kernel.dtype)
     r_kh = tvm.reduce_axis((0, K), name='r_kh')
     r_kw = tvm.reduce_axis((0, K), name='r_kw')
     return tvm.compute(shape, lambda eps, nu, co, ci:
......
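With the hardcoded 3x3 tables gone, the weight transform accepts any NxN kernel the matrix generator supports. A usage sketch with hypothetical shapes (assumes `topi` from a build containing this commit):

```python
import tvm
import topi

# 5x5 kernels now pass the NxN assert: F(4, 5) gives alpha = 4 + 5 - 1 = 8.
kernel = tvm.placeholder((64, 64, 5, 5), name='kernel')   # CO, CI, KH, KW
U = topi.nn.conv2d_winograd_weight_transform(kernel, tile_size=4)
print(U.shape)   # [8, 8, 64, 64] -> [alpha, alpha, CO, CI]
```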
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
""" Utility functions for implementing Winograd convolutions
[*] Fast Algorithms for Convolutional Neural Networks
Andrew Lavin, Scott Gray
https://arxiv.org/abs/1509.09308
https://github.com/andravin/wincnn
"""
from operator import mul
from functools import reduce
import numpy as np
from ..util import const_matrix
# pylint: disable=invalid-name
def _cook_toom_convolution(a, n, r):
    """Compute Cook-Toom convolution A, B, G matrices for n outputs,
    r filter taps and interpolation points a (wincnn-style construction)."""

    def _F_m(a, n):
        # Diagonal scaling matrix of products prod(a_i - a_k), extended with
        # an extra row/column for the pseudo-point at infinity.
        f = lambda j, i: reduce(mul, ((a[i]-a[k] if k != i else 1) for k in range(0, n-1)), 1)
        F = np.fromfunction(np.vectorize(f), (1, n-1), dtype=int)
        F = np.diagflat(F)
        F = np.append(F, np.zeros((n-1, 1), dtype=int), axis=1)
        f = lambda i, j: (1 if j == (n-1) else 0)
        z = np.fromfunction(np.vectorize(f), (1, n), dtype=int)
        return np.append(F, z, axis=0)

    def _A_m(a, m, n):
        # Vandermonde matrix a_i**j, with a final row selecting the
        # highest-degree coefficient (the point at infinity).
        f = lambda i, j: a[i]**j
        A = np.fromfunction(np.vectorize(f), (m-1, n), dtype=int)
        f = lambda i, j: (1 if j == (n-1) else 0)
        z = np.fromfunction(np.vectorize(f), (1, n), dtype=int)
        return np.append(A, z, axis=0)

    def _B_m(a, n):
        # Data-transform matrix built from the Lagrange basis coefficients.
        f = lambda j, i: reduce(mul, ((a[i]-a[k] if k != i else 1) for k in range(0, n-1)), 1)
        Ff = np.fromfunction(np.vectorize(f), (1, n-1), dtype=int)
        f = lambda i, nth: (reduce(mul, [(np.poly1d([1, -a[k]]) if k != i else 1) \
                                         for k in range(0, n-1)], 1)).coef[n-1-nth-1]/Ff[0, i]
        F = np.fromfunction(np.vectorize(f), (n-1, n-1), dtype=int)
        f = lambda i, j: -a[i]**(n-1)
        t = np.fromfunction(np.vectorize(f), (n-1, 1), dtype=int)
        T = np.append(np.eye(n-1), t, axis=1)
        return np.append(F.T.dot(T), np.array([np.eye(n)[n-1]]), axis=0)

    alpha = n + r - 1
    f = _F_m(a, alpha)
    if f[0, 0] < 0:
        f[0, :] *= -1
    A = _A_m(a, alpha, n)
    G = _A_m(a, alpha, r).T
    G = G.dot(np.linalg.inv(f)).T
    B = _B_m(a, alpha)
    B = B.dot(f.T)
    return (A, B, G)
def _interpolation_points(degree):
    """Propose filter points"""
    assert 2 < degree < 18

    # Default interpolation lookup table
    #
    # [1] Error Analysis and Improving the Accuracy of Winograd Convolution for Deep Neural Networks
    #     Barbara Barabasz, Andrew Anderson, Kirk M. Soodhalter, David Gregg
    #     https://arxiv.org/abs/1803.10986
    #
    # pylint: disable=bad-whitespace,line-too-long
    in_pts = [
        # {invalid}
        [],
        #01 {E=4.63E-08 on conv2d [1]}
        [],
        #02 {E=7.65E-08 on F( 2,3) [1]}
        [0, -1, 1],
        #03 {E=2.35E-07 on F( 3,3) [1]}
        [0, -1, 1, 1/2],
        #04 {E=3.29E-07 on F( 4,3) [1]}
        [0, -1, 1, 1/2, -2],
        #05 {E=6.81E-07 on F( 5,3) [1]}
        [0, -1, 1, 1/2, -2, -1/2],
        #06 {E=8.79E-07 on F( 6,3) [1]}
        [0, -1, 1, 1/2, -1/2, 2, -2],
        #07 {E=3.71E-06 on F( 7,3) [1]}
        [0, -1, 1, 1/2, -1/2, 2, -2, -1/4],
        #08 {E=7.35E-06 on F( 8,3) [1]}
        [0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4],
        #09 {E=2.20E-05 on F( 9,3) [1]}
        [0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 3/4, -4/3],
        #10 {E=3.22E-05 on F(10,3) [1]}
        [0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 3/4, -4/3],
        #11 {E=1.09E-04 on F(11,3) [1]}
        [0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 3/4, -4/3, 1/4],
        #12 {E=1.99E-04 on F(12,3) [1]}
        [0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 1/4, -3/4, 4/3, -4],
        #13 {E=5.54E-04 on F(13,3) [1]}
        [0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 1/4, -3/4, 4/3, 3/4, -4/3],
        #14 {E=8.80E-04 on F(14,3) [1]}
        [0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 1/4, -3/4, 4/3, -4, 3/4, -4/3],
        #15 {E=1.07E-02 on F(15,3) [1]}
        [0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 1/4, -3/4, 4/3, -4, 2/3, -3/2, 3/2],
        #16 {E=1.93E-02 on F(16,3) [1]}
        [0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 1/4, -3/4, 4/3, -4, 2/3, -3/2, -2/3, 3/2]
    ]   # pylint: enable=bad-whitespace,line-too-long

    return np.array(in_pts[degree-1], dtype=np.float64)
def winograd_transform_matrices(tile_size, kernel_size, out_dtype):
    """Compute the A, B and G transform matrices for the given tile size
    and kernel size, returned as `tvm.Expr` constant matrices."""
    if not 1 < tile_size < 9:
        raise ValueError("Unsupported tile size for Winograd: {}".format(tile_size))
    if not 2 < kernel_size < 8:
        raise ValueError("Unsupported kernel size for Winograd: {}".format(kernel_size))

    degree = tile_size + kernel_size - 2

    intp_pts = _interpolation_points(degree)
    A_data, B_data, G_data = _cook_toom_convolution(intp_pts, tile_size, kernel_size)

    return (
        const_matrix(A_data.astype(out_dtype), "A"),
        const_matrix(B_data.astype(out_dtype), "B"),
        const_matrix(G_data.astype(out_dtype), "G"),
    )
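A usage sketch for the new module (assuming it is importable as `topi.nn.winograd_util`, matching the imports added in the diffs above): the public function returns TVM constant tensors, and the underlying NumPy generator can be sanity-checked against the 1-D Winograd identity.

```python
import numpy as np
from topi.nn.winograd_util import (_cook_toom_convolution,
                                   _interpolation_points,
                                   winograd_transform_matrices)

# Transform matrices for F(4,3) as tvm tensors; alpha = 4 + 3 - 1 = 6.
A, B, G = winograd_transform_matrices(4, 3, 'float32')
print(A.shape, B.shape, G.shape)   # [6, 4] [6, 6] [6, 3]

# Numerical check of the generator: F(2,3) has degree 2 + 3 - 2 = 3,
# i.e. the points [0, -1, 1], and should satisfy the 1-D identity
# y = A^T [(G g) * (B^T d)] for any filter g and input tile d.
a = _interpolation_points(3)                 # array([ 0., -1.,  1.])
A_d, B_d, G_d = _cook_toom_convolution(a, 2, 3)
g, d = np.random.rand(3), np.random.rand(4)
y = A_d.T.dot(G_d.dot(g) * B_d.T.dot(d))
assert np.allclose(y, [d[0:3].dot(g), d[1:4].dot(g)])
```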
@@ -103,11 +103,18 @@ def test_conv2d_nchw():
     autotvm.DispatchContext.current.silent = True
     with WinogradFallback():
+        # inception v3 workloads
+        verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3, devices=['cuda'])
+        verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3, devices=['cuda'])
+        verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3, devices=['cuda'])
         # resnet 18 workloads
         verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
         verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
         verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
         verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
+        verify_conv2d_nchw(1, 48, 35, 64, 5, 1, 2, devices=['cuda'])
         # batch size = 2
         verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1)
......