Commit ecb0a7ea by mbarrett97 Committed by Tianqi Chen

[TOPI] Added support for Mali Bifrost target (#4047)

parent a21904a5
......@@ -506,6 +506,19 @@ def vta(model='unknown', options=None):
return ret
def bifrost(model='unknown', options=None):
"""Return an ARM Mali GPU target (Bifrost architecture).
Parameters
----------
options : str or list of str
Additional options
"""
opts = ["-device=bifrost", '-model=%s' % model]
opts = _merge_opts(opts, options)
return _api_internal._TargetCreate("opencl", *opts)
def create(target_str):
"""Get a target given target string.
......
......@@ -28,6 +28,7 @@ from . import x86
from . import cuda
from . import arm_cpu
from . import mali
from . import bifrost
from . import intel_graphics
from . import opengl
from . import util
......
......@@ -552,6 +552,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
if "-device=arm_cpu" in target.options:
tile_size = 4
VC = cfg['tile_k'].size[-1]
elif "-device=bifrost" in target.options:
tile_size = 2
VC = 0
else:
from ..mali.conv2d import _pick_tile_size
tile_size = _pick_tile_size(tinfos[0], tinfos[1])
......@@ -559,21 +562,28 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
tile_size=tile_size)
weight = F.reshape(weight,
newshape=(KH + tile_size - 1,
KW + tile_size - 1,
idxd(CO, VC), VC, CI))
weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
if VC > 0:
weight = F.reshape(weight,
newshape=(KH + tile_size - 1,
KW + tile_size - 1,
idxd(CO, VC), VC, CI))
weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
new_weight = tvm.placeholder((KH + tile_size - 1,
KW + tile_size -1,
idxd(CO, VC), CI, VC),
kernel.dtype)
else:
weight = F.reshape(weight,
newshape=(KH + tile_size - 1, KW + tile_size - 1, CO, CI))
new_weight = tvm.placeholder(
(KH + tile_size - 1, KW + tile_size -1, CO, CI), kernel.dtype
)
copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size
# Store the same config for the altered operator (workload)
new_data = data
new_weight = tvm.placeholder((KH + tile_size - 1,
KH + tile_size -1,
idxd(CO, VC), CI, VC),
kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation,
new_attrs[data_layout_key], out_dtype, tile_size],
......
# pylint: disable=redefined-builtin, wildcard-import
"""ARM Mali GPU specific declaration and schedules."""
from __future__ import absolute_import as _abs
from .gemm import *
from .conv2d import *
from .dense import *
from .depthwise_conv2d import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable
"""dense schedule on ARM Mali GPU"""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from .. import generic, nn
from ..util import traverse_inline
autotvm.register_topi_compute(nn.dense, 'bifrost', 'direct', nn.dense.fdefault)
@autotvm.register_topi_schedule(generic.schedule_dense, 'bifrost', 'direct')
def schedule_dense(cfg, outs):
"""Schedule for dense operator.
Parameters
----------
cfg: ConfigEntity
The config entity for this template
outs: Array of Tensor
The computation graph description of dense
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for dense.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if op.tag == 'dense':
vec_size = [1, 2, 4, 8, 16]
max_unroll = 32
dense = op.output(0)
output = outs[0]
y, x = s[output].op.axis
c = s[dense].op.reduce_axis[0]
##### space definition begin #####
cfg.define_split('tile_y', y, num_outputs=3)
cfg.define_split('tile_x', x, num_outputs=3)
cfg.define_split('c_unroll', c, num_outputs=2, max_factor=64)
# fallback support
if cfg.is_fallback:
ref_log = autotvm.tophub.load_reference_log(
'mali', 'rk3399', 'dense', 'direct')
cfg.fallback_with_reference_log(ref_log)
##### space definition end #####
if dense.op in s.outputs:
dense = s.cache_write(output, 'local')
by, ty, yi = cfg['tile_y'].apply(s, output, y)
bx, tx, xi = cfg['tile_x'].apply(s, output, x)
s[output].bind(by, tvm.thread_axis('blockIdx.y'))
s[output].bind(bx, tvm.thread_axis('blockIdx.x'))
s[output].bind(ty, tvm.thread_axis('threadIdx.y'))
s[output].bind(tx, tvm.thread_axis('threadIdx.x'))
if cfg['tile_y'].size[-1] < max_unroll:
s[output].unroll(yi)
if cfg['tile_x'].size[-1] in vec_size:
s[output].vectorize(xi)
s[dense].compute_at(s[output], tx)
k = s[dense].op.reduce_axis[0]
y, x = s[dense].op.axis
k, k_unroll = cfg['c_unroll'].apply(s, dense, k)
s[dense].reorder(k, k_unroll, y, x)
s[dense].unroll(k_unroll)
if cfg['tile_y'].size[-1] < max_unroll:
s[dense].unroll(y)
if cfg['tile_x'].size[-1] in vec_size:
s[dense].vectorize(x)
traverse_inline(s, outs[0].op, _callback)
return s
def fuse_and_bind(s, tensor, axis=None, num_thread=None):
""" fuse all the axis and bind to GPU threads """
axis = axis or s[tensor].op.axis
fused = s[tensor].fuse(*axis)
bx, tx = s[tensor].split(fused, num_thread)
s[tensor].bind(bx, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(tx, tvm.thread_axis("threadIdx.x"))
return bx, tx
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""depthwise_conv2d schedule on ARM Mali GPU"""
from __future__ import absolute_import as _abs
import tvm
from .. import generic
from .. import util
from .. import tag
@generic.schedule_depthwise_conv2d_nchw.register(["bifrost"])
def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for depthwise_conv2d nchw forward.
Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise_conv2d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for depthwise_conv2d nchw.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(pad_data, kernel, conv):
raw_data = s[pad_data].op.input_tensors[0]
if conv.op not in s.outputs: # has bias or relu
output = outs[0]
else: # no bias or relu
output = conv
def tile_and_bind3d(tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
""" tile and bind 3d """
y_factor = y_factor or z_factor
x_factor = x_factor or y_factor
zo, zi = s[tensor].split(z, z_factor)
yo, yi = s[tensor].split(y, y_factor)
xo, xi = s[tensor].split(x, x_factor)
s[tensor].bind(zo, tvm.thread_axis("blockIdx.z"))
s[tensor].bind(zi, tvm.thread_axis("threadIdx.z"))
s[tensor].bind(yo, tvm.thread_axis("blockIdx.y"))
s[tensor].bind(yi, tvm.thread_axis("threadIdx.y"))
s[tensor].bind(xo, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(xi, tvm.thread_axis("threadIdx.x"))
return zo, zi, yo, yi, xo, xi
# set tunable parameters
VH = 1
VW = 1
num_thread = 4
while util.get_const_int(conv.shape[3]) % (VW * 2) == 0 and VW * 2 <= 4:
VW = VW * 2
while util.get_const_int(conv.shape[2]) % (VH * 2) == 0 and VH * 2 <= 2:
VH = VH * 2
if raw_data.dtype == 'float16':
if util.get_const_int(conv.shape[3]) % (VW * 2) == 0:
VW *= 2
num_thread *= 2
else:
num_thread *= 2
# schedule padding
_, c, y, x = s[pad_data].op.axis
tile_and_bind3d(pad_data, c, y, x, num_thread, 1, 1)
# schedule conv
di, dj = s[conv].op.reduce_axis
s[conv].unroll(di)
s[conv].unroll(dj)
_, c, y, x = s[output].op.axis
y, x, yi, xi = s[output].tile(y, x, VH, VW)
s[output].unroll(yi)
s[output].vectorize(xi)
_, _, _, _, _, ji = tile_and_bind3d(output, c, y, x, num_thread, 1, 1)
if conv.op not in s.outputs:
_, c, y, x = s[conv].op.axis
y, x, yi, xi = s[conv].tile(y, x, VH, VW)
s[conv].unroll(yi)
s[conv].vectorize(xi)
s[conv].compute_at(s[output], ji)
def traverse(op):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule depthwise_conv2d
if op.tag == 'depthwise_conv2d_nchw':
pad_data = op.input_tensors[0]
kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
s[kernel].compute_inline()
conv = op.output(0)
_schedule(pad_data, kernel, conv)
traverse(outs[0].op)
return s
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Utility scheduling functions for the Bifrost schedules"""
from __future__ import absolute_import as _abs
import tvm
def fuse_and_bind(s, tensor, axis=None, num_thread=None):
"""Fuse all the axis and bind to GPU threads"""
axis = axis or s[tensor].op.axis
fused = s[tensor].fuse(*axis)
max_threads = tvm.target.current_target(allow_none=False).max_num_threads
bx, tx = s[tensor].split(fused, num_thread or max_threads)
s[tensor].bind(bx, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(tx, tvm.thread_axis("threadIdx.x"))
return bx, tx
def tile_and_bind(s, tensor, y, x, y_factor, x_factor=None):
"""Tile and bind to GPU threads"""
x_factor = x_factor or y_factor
yo, xo, yi, xi = s[tensor].tile(y, x, y_factor, x_factor)
s[tensor].bind(xo, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(xi, tvm.thread_axis("threadIdx.x"))
s[tensor].bind(yo, tvm.thread_axis("blockIdx.y"))
s[tensor].bind(yi, tvm.thread_axis("threadIdx.y"))
return yo, xo, yi, xi
def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
"""Tile and bind 3d"""
y_factor = y_factor or z_factor
x_factor = x_factor or y_factor
zo, zi = s[tensor].split(z, z_factor)
yo, yi = s[tensor].split(y, y_factor)
xo, xi = s[tensor].split(x, x_factor)
s[tensor].bind(zo, tvm.thread_axis("blockIdx.z"))
s[tensor].bind(zi, tvm.thread_axis("threadIdx.z"))
s[tensor].bind(yo, tvm.thread_axis("blockIdx.y"))
s[tensor].bind(yi, tvm.thread_axis("threadIdx.y"))
s[tensor].bind(xo, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(xi, tvm.thread_axis("threadIdx.x"))
return zo, yo, xo, zi, yi, xi
def pack_tensor(s, tensor, factor, readers):
"""Do transform X[n, m] -> X[n / factor, m, factor]"""
tmp = s.cache_read(tensor, 'global', readers)
y, x = s[tmp].op.axis
yo, yi = s[tmp].split(y, factor)
s[tmp].reorder(yo, x, yi)
s[tmp].compute_inline()
return s.cache_write(tmp, 'global'), tmp
def transpose(s, tensor, y_index, x_index, readers):
"""Do transform X[n, m] -> X[m, n]"""
tmp = s.cache_read(tensor, 'global', readers)
y, x = s[tmp].op.axis[y_index], s[tmp].op.axis[x_index]
s[tmp].reorder(x, y)
s[tmp].compute_inline()
A_transpose = s.cache_write(tmp, "global")
CR_A = s.cache_read(tensor, 'local', [A_transpose])
CW_A_transpose = s.cache_write(A_transpose, 'local')
y, x = s[A_transpose].op.axis[y_index], s[A_transpose].op.axis[x_index]
yo, xo, yi, xi = s[A_transpose].tile(y, x, 4, 4)
s[A_transpose].unroll(yi)
s[A_transpose].vectorize(xi)
_, _, _, xi = tile_and_bind(s, A_transpose, yo, xo, 32, 2)
s[CW_A_transpose].compute_at(s[A_transpose], xi)
y, x = s[CW_A_transpose].op.axis[y_index], s[CW_A_transpose].op.axis[x_index]
s[CW_A_transpose].unroll(x)
s[CW_A_transpose].unroll(y)
s[CR_A].compute_at(s[A_transpose], xi)
y, x = s[CR_A].op.axis[y_index], s[CR_A].op.axis[x_index]
s[CR_A].unroll(y)
s[CR_A].vectorize(x)
return tmp
def interleave_transpose(s, tensor, width, y_index, x_index, readers, batched=False):
"""Interleave the tensor, then transpose it"""
tmp = s.cache_read(tensor, 'global', readers)
y, x = s[tmp].op.axis[y_index], s[tmp].op.axis[x_index]
xo, xi = s[tmp].split(x, width)
s[tmp].reorder(xo, y, xi)
s[tmp].fuse(y, xi)
if batched:
z = s[tmp].op.axis[0]
s[tmp].fuse(z, xo)
s[tmp].compute_inline()
return s.cache_write(tmp, 'global'), tmp
def transpose_interleave(s, tensor, width, y_index, x_index, readers, batched=False):
"""Transpose the tensor, then interleave it"""
tmp = s.cache_read(tensor, 'global', readers)
y, x = s[tmp].op.axis[y_index], s[tmp].op.axis[x_index]
yo, yi = s[tmp].split(y, width)
s[tmp].reorder(yo, x, yi)
s[tmp].fuse(x, yi)
if batched:
z = s[tmp].op.axis[0]
s[tmp].fuse(z, yo)
s[tmp].compute_inline()
return s.cache_write(tmp, 'global'), tmp
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment