Commit ecb0a7ea by mbarrett97, committed by Tianqi Chen

[TOPI] Added support for Mali Bifrost target (#4047)

parent a21904a5
@@ -506,6 +506,19 @@ def vta(model='unknown', options=None):
return ret
def bifrost(model='unknown', options=None):
"""Return an ARM Mali GPU target (Bifrost architecture).
Parameters
----------
model : str
The model of this device
options : str or list of str
Additional options
"""
opts = ["-device=bifrost", '-model=%s' % model]
opts = _merge_opts(opts, options)
return _api_internal._TargetCreate("opencl", *opts)
def create(target_str):
"""Get a target given target string.
......
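For reference, a hedged sketch of how the new target helper could be used once this change lands (the model string 'g71' is an illustrative value, not something the diff prescribes):

import tvm

# Build an OpenCL target for a Mali Bifrost GPU.
target = tvm.target.bifrost(model='g71')
print(target)  # expected to contain: opencl -device=bifrost -model=g71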
@@ -28,6 +28,7 @@ from . import x86
from . import cuda
from . import arm_cpu
from . import mali
from . import bifrost
from . import intel_graphics
from . import opengl
from . import util
......
@@ -552,6 +552,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
if "-device=arm_cpu" in target.options:
tile_size = 4
VC = cfg['tile_k'].size[-1]
elif "-device=bifrost" in target.options:
tile_size = 2
VC = 0
else:
from ..mali.conv2d import _pick_tile_size
tile_size = _pick_tile_size(tinfos[0], tinfos[1])
@@ -559,21 +562,28 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
tile_size=tile_size)
weight = F.reshape(weight,
newshape=(KH + tile_size - 1,
KW + tile_size - 1,
idxd(CO, VC), VC, CI))
weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
if VC > 0:
weight = F.reshape(weight,
newshape=(KH + tile_size - 1,
KW + tile_size - 1,
idxd(CO, VC), VC, CI))
weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
new_weight = tvm.placeholder((KH + tile_size - 1,
KW + tile_size -1,
idxd(CO, VC), CI, VC),
kernel.dtype)
else:
weight = F.reshape(weight,
newshape=(KH + tile_size - 1, KW + tile_size - 1, CO, CI))
new_weight = tvm.placeholder(
(KH + tile_size - 1, KW + tile_size -1, CO, CI), kernel.dtype
)
copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size
# Store the same config for the altered operator (workload)
new_data = data
new_weight = tvm.placeholder((KH + tile_size - 1,
KW + tile_size - 1,
idxd(CO, VC), CI, VC),
kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation,
new_attrs[data_layout_key], out_dtype, tile_size],
......
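A quick sanity check of the shape arithmetic in the new Bifrost branch above (the channel counts are illustrative):

# With -device=bifrost, tile_size = 2 and VC = 0, so a 3x3 kernel of shape
# (CO, CI, 3, 3) is reshaped to (KH + tile_size - 1, KW + tile_size - 1, CO, CI).
KH = KW = 3
tile_size = 2
CO, CI = 64, 32          # hypothetical channel counts
alpha = KH + tile_size - 1
print((alpha, alpha, CO, CI))  # (4, 4, 64, 32)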
# pylint: disable=redefined-builtin, wildcard-import
"""ARM Mali GPU specific declaration and schedules."""
from __future__ import absolute_import as _abs
from .gemm import *
from .conv2d import *
from .dense import *
from .depthwise_conv2d import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""conv2d schedule on ARM Mali (Bifrost) GPU"""
import tvm
from tvm import autotvm
from .gemm import decl_winograd_gemm, schedule_gemm
from .transforms import tile_and_bind, tile_and_bind3d
from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
from ..util import traverse_inline, get_const_int, get_const_tuple
from ..nn import conv2d, conv2d_winograd_without_weight_transform, \
get_pad_tuple, pad, conv2d_alter_layout, dilate
from ..nn.winograd_util import winograd_transform_matrices
# reuse some compute declarations from ARM CPU
from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nchw
from ..arm_cpu.conv2d import _alter_conv2d_layout_arm
@autotvm.register_topi_compute(conv2d, 'bifrost', ['direct'])
def conv2d_bifrost(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
"""TOPI compute callback for conv2d
Parameters
----------
cfg: ConfigEntity
The config for this template
data : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
kernel : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width] or
pre-packed 5-D with shape [num_filter_chunk, in_channel, filter_height,
filter_width, num_filter_block]
strides : list of two ints
[stride_height, stride_width]
padding : list of two ints
[pad_height, pad_width]
dilation : list of two ints
[dilation_height, dilation_width]
layout : str
layout of data
out_dtype: str
The output type. This is used for mixed precision.
Returns
-------
output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
if layout == 'NCHW':
return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
dilation, out_dtype, num_tile=3)
else:
raise ValueError("Unsupported layout {}".format(layout))
@autotvm.register_topi_schedule(schedule_conv2d_nchw, 'bifrost', ['direct', 'winograd'])
def schedule_conv2d_nchw_bifrost(cfg, outs):
"""TOPI schedule callback for conv2d
Parameters
----------
cfg: ConfigEntity
The configuration of this template
outs: Array of Tensor
The computation graph description of convolution2d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv2d
"""
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
# schedule conv2d
if 'spatial_conv2d_output' in op.tag:
output = op.output(0)
conv = op.input_tensors[0]
data_vec = conv.op.input_tensors[0]
data_pad = data_vec.op.input_tensors[0]
s[data_pad].compute_inline()
kernel_vec = conv.op.input_tensors[1]
if kernel_vec.op.name == 'kernel_vec':
kernel = kernel_vec.op.input_tensors[0]
else:
kernel = kernel_vec
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
_schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec)
if 'winograd_conv2d_output' in op.tag:
_schedule_winograd(cfg, s, op)
traverse_inline(s, outs[0].op, _callback)
return s
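A hedged usage sketch of the compute and schedule registered above, assuming this module is importable as topi.bifrost and that dispatch goes through the generic TOPI entry points (the shapes are illustrative; outside of tuning a fallback AutoTVM config is used, with a warning):

import tvm
import topi

data = tvm.placeholder((1, 16, 32, 32), name='data')
kernel = tvm.placeholder((32, 16, 3, 3), name='kernel')
with tvm.target.create('opencl -device=bifrost'):
    # strides=1, padding=1, dilation=1
    conv = topi.nn.conv2d(data, kernel, 1, 1, 1, layout='NCHW', out_dtype='float32')
    s = topi.generic.schedule_conv2d_nchw([conv])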
def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
"""schedule the spatial packing for conv2d"""
data = s[data_vec].op.input_tensors[0]
max_unroll = 16
vec_size = [1, 2, 4, 8, 16]
# get tunable parameters (they are defined in compute)
BC, TC, VC = cfg["tile_co"].size
BH, TH, VH = cfg["tile_oh"].size
BW, TW, VW = cfg["tile_ow"].size
# schedule padding
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
s[data_pad].compute_inline()
# schedule data packing
if isinstance(data_vec.op, tvm.tensor.ComputeOp) and data_vec.op.name == 'data_vec_undilated':
_, h, w, ci, _, _, vh, vw = s[data_vec].op.axis
else:
_, h, w, ci, vh, vw = s[data_vec].op.axis
tile_and_bind3d(s, data_vec, h, w, ci, 1)
if vh.dom.extent.value < max_unroll:
s[data_vec].unroll(vh)
if vw.dom.extent.value < max_unroll:
s[data_vec].unroll(vw)
if isinstance(kernel_vec.op, tvm.tensor.ComputeOp) and kernel_vec.name == 'kernel_vec':
if autotvm.GLOBAL_SCOPE.in_tuning:
# kernel packing will be pre-computed during compilation, so we skip
# this part to make tuning records correct
s[kernel_vec].pragma(s[kernel_vec].op.axis[0], 'debug_skip_region')
else:
max_threads = tvm.target.current_target(allow_none=False).max_num_threads
co, ci, kh, kw, vc = s[kernel_vec].op.axis
fused = s[kernel_vec].fuse(co, ci, kh, kw, vc)
fused, vec = s[kernel_vec].split(fused, VC)
bb, tt = s[kernel_vec].split(fused, max_threads)
s[kernel_vec].bind(bb, tvm.thread_axis("blockIdx.x"))
s[kernel_vec].bind(tt, tvm.thread_axis("threadIdx.x"))
if VC in vec_size:
s[kernel_vec].vectorize(vec)
# schedule convolution
n, c, h, w, vh, vw, vc = s[conv].op.axis
kc, kh, kw = s[conv].op.reduce_axis
cfg["reorder_0"].apply(s, conv, [n, c, h, w, kc, kh, kw, vh, vw, vc])
tile_and_bind3d(s, conv, c, h, w, TC, TH, TW)
cfg["ann_reduce"].apply(s, conv, [kh, kw],
axis_lens=[get_const_int(kernel_vec.shape[2]),
get_const_int(kernel_vec.shape[3])],
max_unroll=max_unroll)
cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
axis_lens=[VH, VW, VC],
max_unroll=max_unroll,
vec_size=vec_size,
cfg=cfg)
# schedule output
if output.op not in s.outputs: # has bias
s[output].compute_inline()
output = s.outputs[0]
_, co, oh, ow = s[output].op.axis
tile_and_bind3d(s, output, co, oh, ow, TC, TH, TW)
return s
@autotvm.register_topi_compute(conv2d, 'bifrost', ['winograd'])
def conv2d_bifrost_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
"""Use Winograd as the convolution method"""
return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
def _decl_winograd_kernel_transform(kernel, tile_size, G):
"""Declare a Winograd kernel transform
This exists separately to allow for precomputation; the precomputation
will most often happen on the CPU.
Parameters
----------
kernel : tvm.Tensor
The kernel to transform
tile_size : int
The size of the tile to use for the Winograd filter
G : tvm.Tensor
The Winograd kernel transform matrix
Returns
-------
U : tvm.Tensor
Transformed kernel
"""
CO, CI, KH, KW = [get_const_int(x) for x in kernel.shape]
# Only support 32 bit floats
out_dtype = 'float32'
alpha = G.shape[0]
K = CO
C = CI
def upround(x, align):
return (x + align - 1) // align * align
ALIGN = 16
K_round = upround(K, ALIGN)
# Padded Kernel [K_round, C, KH, KW]
# Pad the number of kernels to multiple of ALIGN
padded_kernel = tvm.compute((K_round, C, KH, KW),
lambda k, c, h, w:
tvm.if_then_else(k < K,
kernel[k][c][h][w],
tvm.const(0, out_dtype)),
name='padded_kernel')
# U [alpha, alpha, K_round, C]
# Perform the kernel transform
r_kh = tvm.reduce_axis((0, KH), 'r_kh')
r_kw = tvm.reduce_axis((0, KW), 'r_kw')
U = tvm.compute((alpha, alpha, K_round, C),
lambda eps, nu, k, c:
tvm.sum(padded_kernel[k][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw],
axis=[r_kh, r_kw]),
name='U')
return U
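A small worked example of the alignment logic above (the channel counts are illustrative):

# The number of output channels is padded up to a multiple of ALIGN = 16 before
# the transform, so U ends up with shape (alpha, alpha, K_round, C).
def upround(x, align):
    return (x + align - 1) // align * align

CO, CI = 36, 16               # hypothetical channel counts
K_round = upround(CO, 16)     # 48
alpha = 3 + 2 - 1             # 4, for a 3x3 kernel and tile_size = 2
print((alpha, alpha, K_round, CI))  # (4, 4, 48, 16)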
def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size=2):
"""Declare a winograd convolution - only tile_size=2 is currently supported"""
N, CI, IH, IW = get_const_tuple(data.shape)
if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
if int(kernel.shape[2]) == 3:
if dilation_h != 1 or dilation_w != 1:
kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
pre_computed = False
CO, _, KH, KW = get_const_tuple(kernel.shape)
else:
assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation"
pre_computed = True
H_CAT, W_CAT, CO, CI = get_const_tuple(kernel.shape)
KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
r = KW
m = tile_size
alpha = m + r - 1
A, B, G = winograd_transform_matrices(m, r, out_dtype)
K = CO
C = CI
H = (IH + 2 * HPAD - 3) // HSTR + 1
W = (IW + 2 * WPAD - 3) // WSTR + 1
nH, nW = (H + m-1) // m, (W + m-1) // m
P = N * nH * nW
def upround(x, align):
return (x + align - 1) // align * align
ALIGN = 16
P_round = upround(P, ALIGN)
K_round = upround(K, ALIGN)
# CONFIG
cfg.define_knob("data_transform_wgx", [1, 2, 4, 8, 16, 32, 64])
cfg.define_knob("data_transform_wgy", [1, 2, 4, 8, 16, 32, 64])
# Pack input tile
input_tile = tvm.compute((N, C, H + 2, W + 2),
lambda n, c, h, w:
data_pad[n][c][h][w],
name='d')
if pre_computed:
U = kernel
else:
U = _decl_winograd_kernel_transform(kernel, tile_size, G)
# V [alpha * alpha, C, P_round]
# Perform the image transform
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
V = tvm.compute((alpha * alpha, C, P_round),
lambda epsnu, c, b:
tvm.sum(input_tile[b // (nH*nW)][c][b // nW % nH * m + r_eps][b % nW * m +r_nu]\
* B[r_eps][epsnu // alpha] * B[r_nu][epsnu % alpha],
axis=[r_eps, r_nu]),
name='V')
# Winograd GEMM is a wrapper around batched GEMM to convert U to a 3D Tensor
_, M = decl_winograd_gemm(cfg, U, V)
# Y [K, P, m, m]
# Winograd output transform
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw:
tvm.sum(M[r_eps * alpha + r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw],
axis=[r_eps, r_nu]), name='Y')
# Output [N, K, H, W]
# Unpack back to NCHW format
# The last term is zero; it keeps the padded extents of M from being narrowed by bound inference
output = tvm.compute((N, K, H, W), lambda n, k, h, w:
Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m]
+ tvm.const(0, out_dtype) * M[(alpha*alpha)-1][K_round-1][P_round-1],
name='output', tag='winograd_conv2d_output')
return output
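To make the tiling arithmetic in _decl_winograd concrete, a small worked example with illustrative sizes:

# A 1x32x56x56 input, 3x3 kernel, stride 1, padding 1 (example values only).
N, IH, IW = 1, 56, 56
HPAD = WPAD = 1
HSTR = WSTR = 1
m, r = 2, 3                             # tile_size and kernel size
alpha = m + r - 1                       # 4x4 input tiles
H = (IH + 2 * HPAD - 3) // HSTR + 1     # 56 output rows
W = (IW + 2 * WPAD - 3) // WSTR + 1     # 56 output columns
nH, nW = (H + m - 1) // m, (W + m - 1) // m   # 28 x 28 output tiles
P = N * nH * nW                         # 784 tiles; already a multiple of ALIGN = 16
print(H, W, nH, nW, P)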
def _schedule_winograd(cfg, s, op):
"""Schedule Winograd convolution for Bifrost"""
# Get ops and tensors
output = op.output(0)
Y = op.input_tensors[0]
M, A = s[Y].op.input_tensors
U_3D, V = s[M].op.input_tensors
U = s[U_3D].op.input_tensors[0]
d, B = s[V].op.input_tensors
data_pad = s[d].op.input_tensors[0]
if isinstance(U.op, tvm.tensor.ComputeOp):
padded_kernel, G = s[U].op.input_tensors
kernel = s[padded_kernel].op.input_tensors[0]
s[G].compute_inline()
eps, _, _, _ = s[U].op.axis
y, _, _, _ = s[padded_kernel].op.axis
if autotvm.GLOBAL_SCOPE.in_tuning:
# Kernel transformation will be pre-computed during compilation, so we skip
# this part to make tuning records correct
s[U].pragma(eps, 'debug_skip_region')
s[padded_kernel].pragma(y, 'debug_skip_region')
else:
# Pad kernel
y, x, ky, kx = s[padded_kernel].op.axis
s[padded_kernel].unroll(ky)
s[padded_kernel].unroll(kx)
tile_and_bind(s, padded_kernel, y, x, 1, 8)
# Transform kernel
eps, nu, k, c = s[U].op.axis
s[U].reorder(k, c, eps, nu)
r_kh, r_kw = s[U].op.reduce_axis
_ = [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]]
yo, xo, yi, xi = tile_and_bind(s, U, k, c, 1, 4)
# Dilation
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
# Pad data
s[data_pad].compute_inline()
# Pack data
n, c, h, w = s[d].op.axis
w, wi = s[d].split(w, 4)
s[d].unroll(wi)
b = s[d].fuse(n, c)
tile_and_bind3d(s, d, b, h, w, 1, 4, 2)
# Transform data
bIL_d = s.cache_read(d, 'local', [V])
s[B].compute_inline()
epsnu, c, b = s[V].op.axis
r_eps, r_nu = s[V].op.reduce_axis
s[V].reorder(b, c, epsnu, r_nu, r_eps)
_ = [s[V].unroll(x) for x in [epsnu, r_eps, r_nu]]
yo, xo, yi, xi = tile_and_bind(
s, V, b, c, cfg["data_transform_wgy"].val, cfg["data_transform_wgx"].val
)
s[bIL_d].compute_at(s[V], xi)
n, c, h, w = s[bIL_d].op.axis
s[bIL_d].unroll(h)
s[bIL_d].vectorize(w)
# Batched GEMM
# Inline the 4D -> 3D tensor transform on the kernel
s[U_3D].compute_inline()
U_transform, V_transform = schedule_gemm(
cfg, s, U_3D, V, M, batched=True, schedule_transforms=True
)
# Inverse transform
CR_M = s.cache_read(M, 'local', [Y])
CW_Y = s.cache_write(Y, 'local')
s[A].compute_inline()
k, b, vh, vw = s[Y].op.axis
fused = s[Y].fuse(vh, vw)
s[Y].vectorize(fused)
yo, xo, yi, xi = tile_and_bind(s, Y, k, b, 1, 4)
s[CR_M].compute_at(s[Y], xi)
k, b, epsnu = s[CR_M].op.axis
s[CR_M].unroll(k)
s[CW_Y].compute_at(s[Y], xi)
k, b, vh, vw = s[CW_Y].op.axis
r_eps, r_nu = s[CW_Y].op.reduce_axis
_ = [s[CW_Y].unroll(x) for x in [vh, vw, r_eps, r_nu]]
# Schedule output and fusion
if output.op not in s.outputs:
s[output].compute_inline()
output = s.outputs[0]
_, k, h, w = s[output].op.axis
tile_and_bind3d(s, output, k, h, w, 1, 2, 2)
##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'bifrost', ['winograd'])
def conv2d_winograd_ww(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
"""TOPI compute callback"""
return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
@autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
'bifrost', ['winograd'])
def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
"""TOPI schedule callback"""
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if 'winograd_conv2d_output' in op.tag:
_schedule_winograd(cfg, s, op)
traverse_inline(s, outs[0].op, _callback)
return s
##### REGISTER ALTER OP LAYOUT #####
@conv2d_alter_layout.register(["bifrost"])
def _alter_conv2d_layout(attrs, inputs, tinfos, F):
try:
return _alter_conv2d_layout_arm(attrs, inputs, tinfos, F)
except KeyError: # to filter out fallback opencl templates
return None
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable
"""dense schedule on ARM Mali GPU"""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from .. import generic, nn
from ..util import traverse_inline
autotvm.register_topi_compute(nn.dense, 'bifrost', 'direct', nn.dense.fdefault)
@autotvm.register_topi_schedule(generic.schedule_dense, 'bifrost', 'direct')
def schedule_dense(cfg, outs):
"""Schedule for dense operator.
Parameters
----------
cfg: ConfigEntity
The config entity for this template
outs: Array of Tensor
The computation graph description of dense
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for dense.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if op.tag == 'dense':
vec_size = [1, 2, 4, 8, 16]
max_unroll = 32
dense = op.output(0)
output = outs[0]
y, x = s[output].op.axis
c = s[dense].op.reduce_axis[0]
##### space definition begin #####
cfg.define_split('tile_y', y, num_outputs=3)
cfg.define_split('tile_x', x, num_outputs=3)
cfg.define_split('c_unroll', c, num_outputs=2, max_factor=64)
# fallback support
if cfg.is_fallback:
ref_log = autotvm.tophub.load_reference_log(
'mali', 'rk3399', 'dense', 'direct')
cfg.fallback_with_reference_log(ref_log)
##### space definition end #####
if dense.op in s.outputs:
dense = s.cache_write(output, 'local')
by, ty, yi = cfg['tile_y'].apply(s, output, y)
bx, tx, xi = cfg['tile_x'].apply(s, output, x)
s[output].bind(by, tvm.thread_axis('blockIdx.y'))
s[output].bind(bx, tvm.thread_axis('blockIdx.x'))
s[output].bind(ty, tvm.thread_axis('threadIdx.y'))
s[output].bind(tx, tvm.thread_axis('threadIdx.x'))
if cfg['tile_y'].size[-1] < max_unroll:
s[output].unroll(yi)
if cfg['tile_x'].size[-1] in vec_size:
s[output].vectorize(xi)
s[dense].compute_at(s[output], tx)
k = s[dense].op.reduce_axis[0]
y, x = s[dense].op.axis
k, k_unroll = cfg['c_unroll'].apply(s, dense, k)
s[dense].reorder(k, k_unroll, y, x)
s[dense].unroll(k_unroll)
if cfg['tile_y'].size[-1] < max_unroll:
s[dense].unroll(y)
if cfg['tile_x'].size[-1] in vec_size:
s[dense].vectorize(x)
traverse_inline(s, outs[0].op, _callback)
return s
def fuse_and_bind(s, tensor, axis=None, num_thread=None):
""" fuse all the axis and bind to GPU threads """
axis = axis or s[tensor].op.axis
fused = s[tensor].fuse(*axis)
bx, tx = s[tensor].split(fused, num_thread)
s[tensor].bind(bx, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(tx, tvm.thread_axis("threadIdx.x"))
return bx, tx
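A hedged sketch of invoking the dense compute and schedule registered above (the shapes are assumptions; with no tuning log, the fallback path above will try to load the Mali rk3399 reference log from TopHub):

import tvm
import topi

data = tvm.placeholder((16, 512), name='data')
weight = tvm.placeholder((128, 512), name='weight')
with tvm.target.create('opencl -device=bifrost'):
    out = topi.nn.dense(data, weight)
    s = topi.generic.schedule_dense([out])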
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""depthwise_conv2d schedule on ARM Mali GPU"""
from __future__ import absolute_import as _abs
import tvm
from .. import generic
from .. import util
from .. import tag
@generic.schedule_depthwise_conv2d_nchw.register(["bifrost"])
def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for depthwise_conv2d nchw forward.
Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise_conv2d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for depthwise_conv2d nchw.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(pad_data, kernel, conv):
raw_data = s[pad_data].op.input_tensors[0]
if conv.op not in s.outputs: # has bias or relu
output = outs[0]
else: # no bias or relu
output = conv
def tile_and_bind3d(tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
""" tile and bind 3d """
y_factor = y_factor or z_factor
x_factor = x_factor or y_factor
zo, zi = s[tensor].split(z, z_factor)
yo, yi = s[tensor].split(y, y_factor)
xo, xi = s[tensor].split(x, x_factor)
s[tensor].bind(zo, tvm.thread_axis("blockIdx.z"))
s[tensor].bind(zi, tvm.thread_axis("threadIdx.z"))
s[tensor].bind(yo, tvm.thread_axis("blockIdx.y"))
s[tensor].bind(yi, tvm.thread_axis("threadIdx.y"))
s[tensor].bind(xo, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(xi, tvm.thread_axis("threadIdx.x"))
return zo, zi, yo, yi, xo, xi
# set tunable parameters
VH = 1
VW = 1
num_thread = 4
while util.get_const_int(conv.shape[3]) % (VW * 2) == 0 and VW * 2 <= 4:
VW = VW * 2
while util.get_const_int(conv.shape[2]) % (VH * 2) == 0 and VH * 2 <= 2:
VH = VH * 2
if raw_data.dtype == 'float16':
if util.get_const_int(conv.shape[3]) % (VW * 2) == 0:
VW *= 2
num_thread *= 2
else:
num_thread *= 2
# schedule padding
_, c, y, x = s[pad_data].op.axis
tile_and_bind3d(pad_data, c, y, x, num_thread, 1, 1)
# schedule conv
di, dj = s[conv].op.reduce_axis
s[conv].unroll(di)
s[conv].unroll(dj)
_, c, y, x = s[output].op.axis
y, x, yi, xi = s[output].tile(y, x, VH, VW)
s[output].unroll(yi)
s[output].vectorize(xi)
_, _, _, _, _, ji = tile_and_bind3d(output, c, y, x, num_thread, 1, 1)
if conv.op not in s.outputs:
_, c, y, x = s[conv].op.axis
y, x, yi, xi = s[conv].tile(y, x, VH, VW)
s[conv].unroll(yi)
s[conv].vectorize(xi)
s[conv].compute_at(s[output], ji)
def traverse(op):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule depthwise_conv2d
if op.tag == 'depthwise_conv2d_nchw':
pad_data = op.input_tensors[0]
kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
s[kernel].compute_inline()
conv = op.output(0)
_schedule(pad_data, kernel, conv)
traverse(outs[0].op)
return s
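A hedged usage sketch for the depthwise schedule registered above (shapes are illustrative):

import tvm
import topi

data = tvm.placeholder((1, 32, 56, 56), name='data')
kernel = tvm.placeholder((32, 1, 3, 3), name='kernel')
with tvm.target.create('opencl -device=bifrost'):
    # stride=1, padding=1, dilation=1
    out = topi.nn.depthwise_conv2d_nchw(data, kernel, 1, 1, 1)
    s = topi.generic.schedule_depthwise_conv2d_nchw([out])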
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""GEMM schedules for Mali Bifrost"""
import tvm
from .transforms import tile_and_bind, tile_and_bind3d, interleave_transpose, \
transpose_interleave
from .. import util
def decl_gemm(cfg, A, B):
"""Declare a single GEMM computation for Mali Bifrost GPUs
Parameters
----------
cfg : Config
Schedule configuration
A : tvm.Tensor
2D Tensor, shape [n, k]
B : tvm.Tensor
2D Tensor, shape [k, m]
Returns
-------
C : tvm.Tensor
2D Tensor, shape [n, m]
"""
cfg.define_knob("work_group_x", [1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 64])
cfg.define_knob("work_group_y", [1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 64])
cfg.define_knob("unroll_k_factor", [1, 2, 4])
cfg.define_knob("A_interleave", [1, 4, 8, 16, 24, 32, 48, 64])
cfg.define_knob("B_interleave", [1, 4, 8, 16, 32])
cfg.define_knob("split_k_factor", [1, 4, 16])
# The shared k axis must have the same extent in A and B
assert util.get_const_int(A.shape[1]) == util.get_const_int(B.shape[0])
n = A.shape[0]
m = B.shape[1]
k_size = util.get_const_int(A.shape[1])
unroll_gemm = cfg["split_k_factor"].val
if unroll_gemm == 1:
# The no-unrolling case must use the same set of tensors so the schedule stays consistent
# Create identity tensors to take the place of A_unrolled, B_unrolled and R
A_unrolled = tvm.compute((n, k_size), lambda i, j: A[i, j], name="A_unrolled")
B_unrolled = tvm.compute((k_size, m), lambda i, j: B[i, j], name="B_unrolled")
# Declare standard GEMM
k = tvm.reduce_axis((0, A.shape[1]), name='k')
C = tvm.compute((n, m), lambda i, j:
tvm.sum(A_unrolled[i, k] * B_unrolled[k, j], axis=k), name='C')
R = tvm.compute((n, m), lambda i, j: C[i, j], name="R")
else:
unrolled_k_size = k_size // unroll_gemm
# Unroll the two input matrices along the shared k axis
A_unrolled = tvm.compute((unroll_gemm, n, unrolled_k_size), lambda b, i, j:
A[i][unrolled_k_size * b + j], name='A_unrolled')
B_unrolled = tvm.compute((unroll_gemm, unrolled_k_size, m), lambda b, i, j:
B[unrolled_k_size * b + i][j], name='B_unrolled')
# Declare a batched GEMM
k = tvm.reduce_axis((0, unrolled_k_size), name='k')
C = tvm.compute((unroll_gemm, n, m), lambda b, i, j:
tvm.sum(A_unrolled[b][i][k] * B_unrolled[b][k][j], axis=k), name='C')
# Then declare a reduction to reduce the sub matrices
k = tvm.reduce_axis((0, unroll_gemm), name='k')
R = tvm.compute((n, m), lambda i, j:
tvm.sum(C[k][i][j], axis=k), name='R')
return R
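A hedged sketch of exercising decl_gemm directly with a fallback AutoTVM config (the import path topi.bifrost.gemm and the matrix shapes are assumptions):

import tvm
from tvm import autotvm
from topi.bifrost.gemm import decl_gemm

A = tvm.placeholder((128, 256), name='A')
B = tvm.placeholder((256, 128), name='B')
cfg = autotvm.get_config()   # a fallback config is returned outside of tuning
R = decl_gemm(cfg, A, B)     # split_k_factor defaults to 1, i.e. a plain GEMM
s = tvm.create_schedule(R.op)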
def decl_batched_gemm(cfg, A, B):
"""Declare a batched GEMM computation for Mali Bifrost GPUs
Parameters
----------
cfg : Config
Schedule configuration
A : tvm.Tensor
3D Tensor, shape [b, n, k]
B : tvm.Tensor
3D Tensor, shape [b, k, m]
Returns
-------
C : tvm.Tensor
3D Tensor, shape [b, n, m]
"""
# The shared b and k axes must have the same extent in A and B
assert util.get_const_int(A.shape[2]) == util.get_const_int(B.shape[1])
assert util.get_const_int(A.shape[0]) == util.get_const_int(B.shape[0])
cfg.define_knob("work_group_x", [1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 64])
cfg.define_knob("work_group_y", [1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 64])
cfg.define_knob("unroll_k_factor", [1, 2, 4])
cfg.define_knob("A_interleave", [1, 4, 8, 16, 32, 64])
cfg.define_knob("B_interleave", [1, 4, 8, 16, 32])
n = A.shape[1]
m = B.shape[2]
k_size = util.get_const_int(A.shape[2])
b_size = util.get_const_int(A.shape[0])
# Declare a batched GEMM
k = tvm.reduce_axis((0, k_size), name='k')
C = tvm.compute((b_size, n, m), lambda b, i, j:
tvm.sum(A[b][i][k] * B[b][k][j], axis=k), name='C')
return C
def decl_winograd_gemm(cfg, A, B):
"""Declare a winograd GEMM for Mali Bifrost GPUs
Winograd uses a batched GEMM, but the kernel operand A here is 4D.
This wraps decl_batched_gemm to provide it with 3D tensors
Parameters
----------
cfg : Config
Schedule configuration
A : tvm.Tensor
4D Tensor, shape [a, a, n, k]
B : tvm.Tensor
3D Tensor, shape [a * a, k, m]
Returns
-------
A_3D : tvm.Tensor
3D Tensor, shape [a * a, n, k], A reshaped for the batched GEMM
C : tvm.Tensor
3D Tensor, shape [a * a, n, m], the batched GEMM result
"""
alpha = util.get_const_int(A.shape[0])
n = util.get_const_int(A.shape[2])
k = util.get_const_int(A.shape[3])
A_3D = tvm.compute((alpha * alpha, n, k), lambda b, i, j:
A[b // alpha][b % alpha][i][j], name='A_3D')
C = decl_batched_gemm(cfg, A_3D, B)
return A_3D, C
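Illustrative shape bookkeeping for the 4D -> 3D flattening performed here (the numbers are examples):

# U of shape (alpha, alpha, K_round, C) is flattened to A_3D of shape
# (alpha * alpha, K_round, C); V has shape (alpha * alpha, C, P_round), so the
# batched GEMM result M has shape (alpha * alpha, K_round, P_round).
alpha, K_round, C, P_round = 4, 64, 32, 784
print((alpha * alpha, K_round, C))        # A_3D
print((alpha * alpha, C, P_round))        # V (the B operand)
print((alpha * alpha, K_round, P_round))  # M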
def schedule_gemm(cfg, s, A, B, C, batched=False, schedule_transforms=True):
"""Schedule GEMM, single and batched
Parameters
----------
cfg : Config
Schedule configuration
s : tvm.schedule.Schedule
Operator schedule
A : tvm.Tensor
2D/3D Tensor, shape [n, k]/[b, n, k]
B : tvm.Tensor
2D/3D Tensor, shape [k, m]/[b, k, m]
C : tvm.Tensor
2D/3D Tensor, shape [n, m]/[b, n, m]
batched : bool
Whether the GEMM is batched
Returns
-------
trans_inter : tvm.Tensor
The transposed and interleaved copy of A used by the schedule
inter_trans : tvm.Tensor
The interleaved and transposed copy of B used by the schedule
"""
block_size_x = 4
block_size_y = 4
warp_size_x = 2
warp_size_y = 2
work_group_x = cfg["work_group_x"].val
work_group_y = cfg["work_group_y"].val
k_unroll = cfg["unroll_k_factor"].val
if not batched:
y_index, x_index = (0, 1)
else:
y_index, x_index = (1, 2)
trans_inter, A_transposed_interleaved = transpose_interleave(
s, A, cfg["A_interleave"].val, y_index, x_index, [C], batched=batched
)
inter_trans, B_interleaved_transposed = interleave_transpose(
s, B, cfg["B_interleave"].val, y_index, x_index, [C], batched=batched
)
if schedule_transforms:
# Schedule A
y, x = s[trans_inter].op.axis
y, x, yi, xi = s[trans_inter].tile(y, x, 1, 8)
s[trans_inter].unroll(yi)
s[trans_inter].unroll(xi)
tile_and_bind(s, trans_inter, y, x, 1, 4)
# Schedule B
y, x = s[inter_trans].op.axis
xo, xi = s[inter_trans].split(x, 4)
s[inter_trans].vectorize(xi)
tile_and_bind(s, inter_trans, y, xo, 4, 4)
# Schedule C
CR_A = s.cache_read(A_transposed_interleaved, 'local', [C])
CR_B = s.cache_read(B_interleaved_transposed, 'local', [C])
CW_C = s.cache_write(C, 'local')
if not batched:
y, x = s[C].op.axis
else:
z, y, x = s[C].op.axis
y, x, yt, xt = s[C].tile(y, x, block_size_y, block_size_x)
s[C].unroll(yt)
s[C].vectorize(xt)
# Tile the global work space to generate 'square' warps -> 2x2 for warp size of 4
y, x, wy, wx = s[C].tile(y, x, warp_size_y, warp_size_x)
x = s[C].fuse(x, wy, wx)
if not batched:
yo, xo, yi, xi = tile_and_bind(s, C, y, x, work_group_y, work_group_x)
else:
# For batched GEMM bind batch to z axis
zo, yo, xo, zi, yi, xi = tile_and_bind3d(s, C, z, y, x, 1, work_group_y, work_group_x)
s[CW_C].compute_at(s[C], xi)
if not batched:
y, x = s[CW_C].op.axis
else:
_, y, x = s[CW_C].op.axis
y, x, yt, xt = s[CW_C].tile(y, x, block_size_y, block_size_x)
k = s[CW_C].op.reduce_axis[0]
s[CW_C].reorder(k, yt, xt)
ko, ki = s[CW_C].split(k, k_unroll)
s[CW_C].unroll(ki)
s[CW_C].unroll(yt)
s[CW_C].unroll(xt)
if not batched:
i, j = s[CR_A].op.axis
else:
_, i, j = s[CR_A].op.axis
s[CR_A].reorder(j, i)
s[CR_A].compute_at(s[CW_C], ki)
s[CR_A].unroll(j)
s[CR_A].vectorize(i)
if not batched:
i, j = s[CR_B].op.axis
else:
_, i, j = s[CR_B].op.axis
s[CR_B].compute_at(s[CW_C], ki)
s[CR_B].unroll(i)
s[CR_B].vectorize(j)
return trans_inter, inter_trans
def schedule_unrollable_gemm(cfg, s, A, B, C, R):
"""Schedule a GEMM that can be unrolled by a constant factor
along its inner dimension
Parameters
----------
cfg : Config
Schedule configuration
s : tvm.schedule.Schedule
Operator schedule
A : tvm.Tensor
2D/3D Tensor, shape [n, k]/[b, n, k]
B : tvm.Tensor
2D/3D Tensor, shape [k, m]/[b, k, m]
C : tvm.Tensor
2D/3D Tensor, shape [n, m]/[b, n, m]
R : tvm.Tensor
2D Tensor, shape [n, m]
"""
# If the GEMM is 2D, no unrolling has taken place
# Use non-batched GEMM schedule and inline identity matrices A, B and R
if len(C.op.axis) == 2:
s[A].compute_inline()
s[B].compute_inline()
schedule_gemm(cfg, s, A, B, C)
s[R].compute_inline()
# GEMM is 3D, use batched GEMM schedule, inline A and B and schedule R
else:
s[A].compute_inline()
s[B].compute_inline()
schedule_gemm(cfg, s, A, B, C, batched=True)
CR_C = s.cache_read(C, 'local', [R])
y, x = s[R].op.axis
xo, xi = s[R].split(x, 4)
k = s[R].op.reduce_axis[0]
s[R].reorder(k, xi)
ko, ki = s[R].split(k, 4)
s[R].unroll(xi)
s[R].unroll(ki)
tile_and_bind(s, R, y, xo, 1, 2)
s[CR_C].compute_at(s[R], ko)
_, y, x = s[CR_C].op.axis
s[CR_C].unroll(y)
s[CR_C].vectorize(x)
def get_unrollable_gemm_ops(R):
"""Get all GEMM operators from the final reduction
This is a helper function to more easily recover all of the GEMM stages
from the final reduction tensor
Parameters
----------
R : tvm.Tensor
Reduced tensor, final stage of GEMM
Returns
-------
A_unrolled : tvm.Tensor
Matrix A unrolled along k
B_unrolled: tvm.Tensor
Matrix B unrolled along k
C : tvm.Tensor
Result of batched GEMM
R : tvm.Tensor
Reduction of C, result of unrollable GEMM
"""
C = R.op.input_tensors[0]
A_unrolled, B_unrolled = C.op.input_tensors
return A_unrolled, B_unrolled, C, R
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Utility scheduling functions for the Bifrost schedules"""
from __future__ import absolute_import as _abs
import tvm
def fuse_and_bind(s, tensor, axis=None, num_thread=None):
"""Fuse all the axis and bind to GPU threads"""
axis = axis or s[tensor].op.axis
fused = s[tensor].fuse(*axis)
max_threads = tvm.target.current_target(allow_none=False).max_num_threads
bx, tx = s[tensor].split(fused, num_thread or max_threads)
s[tensor].bind(bx, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(tx, tvm.thread_axis("threadIdx.x"))
return bx, tx
def tile_and_bind(s, tensor, y, x, y_factor, x_factor=None):
"""Tile and bind to GPU threads"""
x_factor = x_factor or y_factor
yo, xo, yi, xi = s[tensor].tile(y, x, y_factor, x_factor)
s[tensor].bind(xo, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(xi, tvm.thread_axis("threadIdx.x"))
s[tensor].bind(yo, tvm.thread_axis("blockIdx.y"))
s[tensor].bind(yi, tvm.thread_axis("threadIdx.y"))
return yo, xo, yi, xi
def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
"""Tile and bind 3d"""
y_factor = y_factor or z_factor
x_factor = x_factor or y_factor
zo, zi = s[tensor].split(z, z_factor)
yo, yi = s[tensor].split(y, y_factor)
xo, xi = s[tensor].split(x, x_factor)
s[tensor].bind(zo, tvm.thread_axis("blockIdx.z"))
s[tensor].bind(zi, tvm.thread_axis("threadIdx.z"))
s[tensor].bind(yo, tvm.thread_axis("blockIdx.y"))
s[tensor].bind(yi, tvm.thread_axis("threadIdx.y"))
s[tensor].bind(xo, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(xi, tvm.thread_axis("threadIdx.x"))
return zo, yo, xo, zi, yi, xi
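A hedged sketch of tile_and_bind3d applied to a toy elementwise compute (the import path topi.bifrost.transforms is an assumption):

import tvm
from topi.bifrost.transforms import tile_and_bind3d

A = tvm.placeholder((8, 16, 16), name='A')
B = tvm.compute(A.shape, lambda z, y, x: A[z, y, x] * 2.0, name='B')
s = tvm.create_schedule(B.op)
z, y, x = s[B].op.axis
# Split each axis into a (blockIdx, threadIdx) pair with factors 2, 4, 4.
tile_and_bind3d(s, B, z, y, x, 2, 4, 4)
print(tvm.lower(s, [A, B], simple_mode=True))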
def pack_tensor(s, tensor, factor, readers):
"""Do transform X[n, m] -> X[n / factor, m, factor]"""
tmp = s.cache_read(tensor, 'global', readers)
y, x = s[tmp].op.axis
yo, yi = s[tmp].split(y, factor)
s[tmp].reorder(yo, x, yi)
s[tmp].compute_inline()
return s.cache_write(tmp, 'global'), tmp
def transpose(s, tensor, y_index, x_index, readers):
"""Do transform X[n, m] -> X[m, n]"""
tmp = s.cache_read(tensor, 'global', readers)
y, x = s[tmp].op.axis[y_index], s[tmp].op.axis[x_index]
s[tmp].reorder(x, y)
s[tmp].compute_inline()
A_transpose = s.cache_write(tmp, "global")
CR_A = s.cache_read(tensor, 'local', [A_transpose])
CW_A_transpose = s.cache_write(A_transpose, 'local')
y, x = s[A_transpose].op.axis[y_index], s[A_transpose].op.axis[x_index]
yo, xo, yi, xi = s[A_transpose].tile(y, x, 4, 4)
s[A_transpose].unroll(yi)
s[A_transpose].vectorize(xi)
_, _, _, xi = tile_and_bind(s, A_transpose, yo, xo, 32, 2)
s[CW_A_transpose].compute_at(s[A_transpose], xi)
y, x = s[CW_A_transpose].op.axis[y_index], s[CW_A_transpose].op.axis[x_index]
s[CW_A_transpose].unroll(x)
s[CW_A_transpose].unroll(y)
s[CR_A].compute_at(s[A_transpose], xi)
y, x = s[CR_A].op.axis[y_index], s[CR_A].op.axis[x_index]
s[CR_A].unroll(y)
s[CR_A].vectorize(x)
return tmp
def interleave_transpose(s, tensor, width, y_index, x_index, readers, batched=False):
"""Interleave the tensor, then transpose it"""
tmp = s.cache_read(tensor, 'global', readers)
y, x = s[tmp].op.axis[y_index], s[tmp].op.axis[x_index]
xo, xi = s[tmp].split(x, width)
s[tmp].reorder(xo, y, xi)
s[tmp].fuse(y, xi)
if batched:
z = s[tmp].op.axis[0]
s[tmp].fuse(z, xo)
s[tmp].compute_inline()
return s.cache_write(tmp, 'global'), tmp
def transpose_interleave(s, tensor, width, y_index, x_index, readers, batched=False):
"""Transpose the tensor, then interleave it"""
tmp = s.cache_read(tensor, 'global', readers)
y, x = s[tmp].op.axis[y_index], s[tmp].op.axis[x_index]
yo, yi = s[tmp].split(y, width)
s[tmp].reorder(yo, x, yi)
s[tmp].fuse(x, yi)
if batched:
z = s[tmp].op.axis[0]
s[tmp].fuse(z, yo)
s[tmp].compute_inline()
return s.cache_write(tmp, 'global'), tmp