Commit de8d4a4d by Lianmin Zheng Committed by Tianqi Chen

[TOPI] Add winograd for mali (#898)

* add winograd for mali

* fix lint

* add padding

* fix comment
parent 01f52b1d
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
"""conv2d schedule on ARM Mali GPU""" """conv2d schedule on ARM Mali GPU"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import numpy as np
import tvm import tvm
from .. import generic from .. import generic
...@@ -63,7 +65,23 @@ def transpose(s, tensor, readers): ...@@ -63,7 +65,23 @@ def transpose(s, tensor, readers):
s[tmp].compute_inline() s[tmp].compute_inline()
return s.cache_write(tmp, "global"), tmp return s.cache_write(tmp, "global"), tmp
def const_array(data, name):
    """Convert a constant 2-D array into a tvm tensor.

    The tensor is defined with ``tvm.compute``; each element is produced by a
    chain of nested ``tvm.select`` expressions (one per entry of ``data``)
    that picks the matching constant for the requested index.

    Parameters
    ----------
    data : array-like with ``shape`` and ``dtype`` attributes
        The constant values. ``data.shape`` must unpack into exactly two
        extents ``(row, col)``. Callers in this file pass numpy arrays.

    name : str
        Name forwarded to ``tvm.compute`` for the resulting tensor.

    Returns
    -------
    tensor : the result of ``tvm.compute`` with shape ``data.shape``
    """
    row, col = data.shape
    dtype = str(data.dtype)

    def select_array(i, j):
        # Start from zero and wrap one select per element around it; the
        # select whose (ii, jj) matches the index supplies the value.
        now = tvm.const(0.0, dtype)
        for ii in range(row):
            for jj in range(col):
                now = tvm.select(tvm.all(i % row == ii, j % col == jj),
                                 tvm.const(data[ii][jj], dtype),
                                 now)
        return now

    return tvm.compute(data.shape, select_array, name=name)
@conv2d.register(["mali"])
def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'): def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
"""Conv2D operator for ARM Mali GPU backend. """Conv2D operator for ARM Mali GPU backend.
...@@ -94,10 +112,20 @@ def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32 ...@@ -94,10 +112,20 @@ def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
assert data.dtype == kernel.dtype, "Do not support inputs with different data types now." assert data.dtype == kernel.dtype, "Do not support inputs with different data types now."
out_dtype = data.dtype out_dtype = data.dtype
if util.get_const_int(kernel.shape[2]) == 1: HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
kernel_shape = util.get_const_tuple(kernel.shape)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
if (kernel_shape[2:4] == (3, 3) and (HPAD, WPAD) == (1, 1) and kernel_shape[0] >= 64 and
(HSTR, WSTR) == (1, 1)):
return _decl_winograd(data, kernel, stride, padding, layout, out_dtype)
elif kernel_shape[2:4] == (1, 1):
return _decl_im2col(data, kernel, stride, padding, layout, out_dtype) return _decl_im2col(data, kernel, stride, padding, layout, out_dtype)
else: else:
return _decl_direct(data, kernel, stride, padding, layout, out_dtype) return _decl_spatialpack(data, kernel, stride, padding, layout, out_dtype)
@generic.schedule_conv2d_nchw.register(["mali"]) @generic.schedule_conv2d_nchw.register(["mali"])
def schedule_conv2d_nchw(outs): def schedule_conv2d_nchw(outs):
...@@ -129,14 +157,17 @@ def schedule_conv2d_nchw(outs): ...@@ -129,14 +157,17 @@ def schedule_conv2d_nchw(outs):
if 'im2col_conv_output' in op.tag: if 'im2col_conv_output' in op.tag:
_schedule_im2col_conv2d(s, op) _schedule_im2col_conv2d(s, op)
if 'direct_conv_output' in op.tag: if 'spatialpack_conv_output' in op.tag:
_schedule_direct_conv2d(s, op) _schedule_spatialpack_conv2d(s, op)
if 'winograd_conv_output' in op.tag:
_schedule_winograd(s, op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
def _decl_direct(data, kernel, stride, padding, layout, out_dtype): def _decl_spatialpack(data, kernel, stride, padding, layout, out_dtype):
"""declare the direct method (spatial packing) for conv2d""" """declare the spatialpack method (spatial packing) for conv2d"""
_, CI, IH, IW = [util.get_const_int(x) for x in data.shape] _, CI, IH, IW = [util.get_const_int(x) for x in data.shape]
CO, _, KH, KW = [util.get_const_int(x) for x in kernel.shape] CO, _, KH, KW = [util.get_const_int(x) for x in kernel.shape]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
...@@ -207,12 +238,12 @@ def _decl_direct(data, kernel, stride, padding, layout, out_dtype): ...@@ -207,12 +238,12 @@ def _decl_direct(data, kernel, stride, padding, layout, out_dtype):
output = tvm.compute(oshape, lambda n, co, h, w: output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h/VH][w//VW][h%VH][w%VW][co%VC], conv[n][co//VC][h/VH][w//VW][h%VH][w%VW][co%VC],
name='output_unpack', tag='direct_conv_output') name='output_unpack', tag='spatialpack_conv_output')
return output return output
def _schedule_direct_conv2d(s, op): def _schedule_spatialpack_conv2d(s, op):
"""schedule the direct method (spatial packing) for conv2d""" """schedule the spatialpack method (spatial packing) for conv2d"""
# get ops and tensors # get ops and tensors
output = op.output(0) output = op.output(0)
output_height = util.get_const_int(output.shape[2]) output_height = util.get_const_int(output.shape[2])
...@@ -294,8 +325,6 @@ def _schedule_direct_conv2d(s, op): ...@@ -294,8 +325,6 @@ def _schedule_direct_conv2d(s, op):
_, co, oh, ow = s[output].op.axis _, co, oh, ow = s[output].op.axis
tile_and_bind3d(s, output, co, oh, ow, num_thread, 1, last) tile_and_bind3d(s, output, co, oh, ow, num_thread, 1, last)
#print(tvm.lower(s, [data, kernel, output], simple_mode=True))
def _decl_im2col(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'): def _decl_im2col(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
"""declare the Im2Col method for conv2d""" """declare the Im2Col method for conv2d"""
_, CI, IH, IW = [x.value for x in data.shape] _, CI, IH, IW = [x.value for x in data.shape]
...@@ -476,4 +505,174 @@ def _schedule_im2col_conv2d(s, op): ...@@ -476,4 +505,174 @@ def _schedule_im2col_conv2d(s, op):
s[output].vectorize(vw) s[output].vectorize(vw)
fuse_and_bind(s, output, [n, co, h, w]) fuse_and_bind(s, output, [n, co, h, w])
def _decl_winograd(data, kernel, stride, padding, layout, out_dtype):
    """Declare winograd fast convolution F(2x2, 3x3) for conv2d.

    The computation is declared as a pipeline of stages: pad the input,
    pack input tiles (``d``), transform the kernel (``U``), transform the
    tiles (``V``), batched matrix multiply (``M``), inverse transform (``Y``)
    and unpack into the output layout (``output``).

    Parameters
    ----------
    data : tvm tensor
        Input whose shape unpacks into four constant extents (N, CI, H, W).

    kernel : tvm tensor
        Filter whose shape unpacks into four constant extents
        (CO, CI, KH, KW). KH and KW must both be 3.

    stride : int or tuple/list of two ints
        Must resolve to (1, 1).

    padding : value accepted by ``get_pad_tuple``
        Must resolve to a height/width padding of (1, 1).

    layout : str
        Not read by this function.

    out_dtype : str
        dtype used for the constant transform matrices and for the zero
        constant in the output expression.

    Returns
    -------
    output : tvm tensor
        Tensor of shape (N, CO, H, W) tagged ``'winograd_conv_output'``.

    Raises
    ------
    AssertionError
        If stride/padding/kernel size do not match the supported case, or if
        CO is not a multiple of the kernel blocking factor.
    """
    # The channel count used below (C) comes from the kernel's shape.
    N, _, H, W = [util.get_const_int(x) for x in data.shape]
    CO, CI, KH, KW = [util.get_const_int(x) for x in kernel.shape]
    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride

    assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 and KH == 3 and KW == 3, \
        "winograd F(2x2, 3x3) requires stride 1, padding 1 and a 3x3 kernel"
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

    # Constant transform matrices for F(2x2, 3x3).
    B_data = np.array([
        [1, 0, 0, 0],
        [0, 1, -1, 1],
        [-1, 1, 1, 0],
        [0, 0, 0, -1]
    ], out_dtype)

    G_data = np.array([
        [1, 0, 0],
        [1.0/2, 1.0/2, 1.0/2],
        [1.0/2, -1.0/2, 1.0/2],
        [0, 0, 1],
    ], out_dtype)

    A_data = np.array([
        [1, 0],
        [1, 1],
        [1, -1],
        [0, -1],
    ], out_dtype)

    m = 2                  # output tile size
    r = 3                  # kernel size
    alpha = m + r - 1      # input tile size
    K = CO
    C = CI

    # Number of tiles along height/width (rounded up) and in total.
    # NOTE(review): for odd H or W the last tile indexes data_pad at
    # (nH-1)*m + alpha-1 == H+2, which looks one past a (1, 1)-padded
    # input -- confirm against the semantics of `pad`.
    nH, nW = (H + m-1) // m, (W + m-1) // m
    P = N * nH * nW

    # Blocking factors for the kernel (bna) and tile (bnb) dimensions.
    # _schedule_winograd recomputes the same values and must stay in sync.
    bna, bnb = 4, 4
    if data.dtype == 'float16':
        bnb *= 2
    P_round = (P + bnb - 1) // bnb * bnb
    assert K % bna == 0 and P_round % bnb == 0, \
        "output channels must be a multiple of %d" % bna

    # pack input tile; entries beyond the real tile count P are zero-filled
    input_tile = tvm.compute(
        (C, P_round // bnb, alpha, alpha, bnb),
        lambda c, b, eps, nu, bb: tvm.select(
            b * bnb + bb < P,
            data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps]
            [(b*bnb+bb) % nW * m + nu],
            tvm.const(0, data_pad.dtype)),
        name='d')

    # transform kernel
    G = const_array(G_data, 'G')
    r_kh = tvm.reduce_axis((0, KH), 'r_kh')
    r_kw = tvm.reduce_axis((0, KW), 'r_kw')
    U = tvm.compute((alpha, alpha, K // bna, C, bna), lambda eps, nu, k, c, kk:
                    tvm.sum(kernel[k * bna + kk][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw],
                            axis=[r_kh, r_kw]), name='U')

    # transform image
    B = const_array(B_data, 'B')
    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
    V = tvm.compute((alpha, alpha, P_round // bnb, C, bnb), lambda eps, nu, b, c, bb:
                    tvm.sum(input_tile[c][b][r_eps][r_nu][bb] * B[r_eps][eps] * B[r_nu][nu],
                            axis=[r_eps, r_nu]), name='V')

    # batch gemm: reduce over the channel axis for every (eps, nu) position
    c = tvm.reduce_axis((0, C), name='c')
    M = tvm.compute((alpha, alpha, K, P_round), lambda eps, nu, k, b:
                    tvm.sum(U[eps][nu][k // bna][c][k % bna] *
                            V[eps][nu][b // bnb][c][b % bnb], axis=c), name='M')

    # inverse transform (fresh reduce axes; the ones above belong to V)
    A = const_array(A_data, 'A')
    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
    Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw:
                    tvm.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw],
                            axis=[r_eps, r_nu]), name='Y')

    # unpack output
    output = tvm.compute((N, K, H, W), lambda n, k, h, w:
                         Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m]
                         # the following term is used to make the padding effective,
                         # otherwise the padding will be eliminated by bound inference
                         + tvm.const(0, out_dtype) * M[alpha-1][alpha-1][K-1][P_round-1],
                         name='output', tag='winograd_conv_output')

    return output
def _schedule_winograd(s, op):
    """Schedule winograd fast convolution F(2x2, 3x3) for conv2d.

    Walks back from the op tagged ``'winograd_conv_output'`` through the
    stages declared by ``_decl_winograd`` (Y, M, U, V, d, data_pad and the
    constant matrices A, B, G) and schedules each of them in place.

    Parameters
    ----------
    s : schedule
        The schedule to mutate.

    op : operation
        The winograd output operation; its first input tensor must be the
        inverse-transform stage ``Y``.

    Returns
    -------
    None. The schedule ``s`` is modified in place.
    """
    # get ops and tensors
    output = op.output(0)

    Y = op.input_tensors[0]
    M, A = s[Y].op.input_tensors
    U, V = s[M].op.input_tensors
    _, G = s[U].op.input_tensors       # first input is the raw kernel
    d, B = s[V].op.input_tensors
    data_pad = s[d].op.input_tensors[0]
    data = s[data_pad].op.input_tensors[0]

    # padding
    s[data_pad].compute_inline()

    # pack input tiles
    c, b, eps, nu, bb = s[d].op.axis
    s[d].reorder(eps, nu, bb)
    aha = s[d].fuse(eps, nu)
    s[d].unroll(bb)
    tile_and_bind3d(s, d, c, b, aha, 4, 1, 1)

    # transform kernel
    s[G].compute_inline()
    eps, nu, k, c, kk = s[U].op.axis
    r_kh, r_kw = s[U].op.reduce_axis
    s[U].reorder(k, c, kk, eps, nu, r_kh, r_kw)
    for axis in (eps, nu, r_kh, r_kw):
        s[U].unroll(axis)
    s[U].vectorize(kk)
    tile_and_bind(s, U, k, c, 1, 256)

    # transform image
    s[B].compute_inline()
    eps, nu, b, c, bb = s[V].op.axis
    r_eps, r_nu = s[V].op.reduce_axis
    s[V].reorder(b, c, bb, eps, nu, r_nu, r_eps)
    for axis in (eps, nu, r_eps, r_nu):
        s[V].unroll(axis)
    s[V].vectorize(bb)
    tile_and_bind(s, V, b, c, 2, 1)

    # batch gemm
    # Blocking factors; these must match the values used in _decl_winograd.
    bna, bnb = 4, 4
    if data.dtype == 'float16':
        bnb *= 2

    eps, nu, k, b = s[M].op.axis
    c = s[M].op.reduce_axis[0]
    yo, xo, yi, xi = s[M].tile(k, b, bna, bnb)
    s[M].reorder(c, yi, xi)
    c, c_unroll = s[M].split(c, 2)
    s[M].unroll(c_unroll)
    s[M].unroll(yi)
    s[M].vectorize(xi)
    z = s[M].fuse(eps, nu)
    tile_and_bind3d(s, M, z, yo, xo, 1, 8, 1)

    # inverse transform
    s[A].compute_inline()
    k, b, vh, vw = s[Y].op.axis
    r_eps, r_nu = s[Y].op.reduce_axis
    for axis in (vh, vw, r_eps, r_nu):
        s[Y].unroll(axis)
    tile_and_bind(s, Y, k, b, 4, 1)

    # schedule output
    if output.op not in s.outputs:
        # The unpack stage is consumed by a later op (the bias case):
        # inline it and schedule the schedule's real output instead.
        s[output].compute_inline()
        output = s.outputs[0]

    _, k, h, w = s[output].op.axis
    tile_and_bind3d(s, output, k, h, w, 1, 2, 2)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment