Unverified Commit f4286cc7 by Shawn-Inspur Committed by GitHub

[TOPI][Tensor Core] Conv2d and Dense ops support on Tensor Core (#5099)

* [TOPI][Tensor Core] Optimization of CNNs on Tensor Core #6004

* update conv2d test

* # pylint: dense_tensorcore.py

* modify

* modify conv2d

* modify the unclear comment,add shape assertion in conv2d compute,combine general gemm intrinsic

* add shape assertion in conv2d compute, combine general gemm intrinsic

Co-authored-by: libaihong <libaihong@inspur.com>
Co-authored-by: libaihong <61525430+libaihong@users.noreply.github.com>
parent 949dca4d
...@@ -17,7 +17,9 @@ ...@@ -17,7 +17,9 @@
"""Definition of CUDA/GPU operator strategy.""" """Definition of CUDA/GPU operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import topi import topi
import tvm
from tvm.te import SpecializedCondition from tvm.te import SpecializedCondition
from tvm.contrib import nvcc
from .generic import * from .generic import *
from .. import op as _op from .. import op as _op
from .... import get_global_func from .... import get_global_func
...@@ -112,13 +114,23 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): ...@@ -112,13 +114,23 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.cuda.conv2d_hwcn), wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn), wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
name="conv2d_hwcn.cuda") name="conv2d_hwcn.cuda")
# TODO(@alexgl-github): Re-enable this after fix the conv2d_nhwc for cuda elif layout == "NHWC":
# elif layout == "NHWC": assert kernel_layout == "HWIO"
# assert kernel_layout == "HWIO" strategy.add_implementation(
# strategy.add_implementation( wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
# wrap_compute_conv2d(topi.cuda.conv2d_nhwc), wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
# wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc), name="conv2d_nhwc.cuda")
# name="conv2d_nhwc.cuda") N, _, _, _ = get_const_tuple(data.shape)
_, _, CI, CO = get_const_tuple(kernel.shape)
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
(N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
(N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc_tensorcore),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
name="conv2d_nhwc_tensorcore.cuda",
plevel=20)
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i" assert kernel_layout == "OIHW4o4i"
strategy.add_implementation( strategy.add_implementation(
...@@ -279,6 +291,9 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target): ...@@ -279,6 +291,9 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
def dense_strategy_cuda(attrs, inputs, out_type, target): def dense_strategy_cuda(attrs, inputs, out_type, target):
"""dense cuda strategy""" """dense cuda strategy"""
strategy = _op.OpStrategy() strategy = _op.OpStrategy()
data, weights = inputs
b, i = get_const_tuple(data.shape)
o, _ = get_const_tuple(weights.shape)
if out_type.dtype == "int8": if out_type.dtype == "int8":
strategy.add_implementation( strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_int8), wrap_compute_dense(topi.cuda.dense_int8),
...@@ -289,13 +304,21 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): ...@@ -289,13 +304,21 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_dense(topi.cuda.dense_small_batch), wrap_compute_dense(topi.cuda.dense_small_batch),
wrap_topi_schedule(topi.cuda.schedule_dense_small_batch), wrap_topi_schedule(topi.cuda.schedule_dense_small_batch),
name="dense_small_batch.cuda") name="dense_small_batch.cuda")
b = inputs[0].shape[0]
with SpecializedCondition(b >= 32): with SpecializedCondition(b >= 32):
strategy.add_implementation( strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_large_batch), wrap_compute_dense(topi.cuda.dense_large_batch),
wrap_topi_schedule(topi.cuda.schedule_dense_large_batch), wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
name="dense_large_batch.cuda", name="dense_large_batch.cuda",
plevel=5) plevel=5)
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \
or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \
or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_tensorcore),
wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
name="dense_tensorcore.cuda",
plevel=20)
if target.target_name == "cuda" and "cublas" in target.libs: if target.target_name == "cuda" and "cublas" in target.libs:
strategy.add_implementation( strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_cublas), wrap_compute_dense(topi.cuda.dense_cublas),
......
...@@ -32,7 +32,10 @@ def residual_unit(data, ...@@ -32,7 +32,10 @@ def residual_unit(data,
stride, stride,
dim_match, dim_match,
name, name,
bottle_neck=True): bottle_neck=True,
data_layout="NCHW",
kernel_layout="IOHW"
):
"""Return ResNet Unit symbol for building ResNet """Return ResNet Unit symbol for building ResNet
Parameters Parameters
...@@ -67,42 +70,50 @@ def residual_unit(data, ...@@ -67,42 +70,50 @@ def residual_unit(data,
kernel_size=(1, 1), kernel_size=(1, 1),
strides=stride, strides=stride,
padding=(0, 0), padding=(0, 0),
name=name + '_conv1') name=name + '_conv1',
data_layout=data_layout,
kernel_layout=kernel_layout)
bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = relay.nn.relu(data=bn2) act2 = relay.nn.relu(data=bn2)
conv2 = layers.conv2d( conv2 = layers.conv2d(
data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3), data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), name=name + '_conv2') strides=(1, 1), padding=(1, 1), name=name + '_conv2',
data_layout=data_layout, kernel_layout=kernel_layout)
bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + '_bn3') bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + '_bn3')
act3 = relay.nn.relu(data=bn3) act3 = relay.nn.relu(data=bn3)
conv3 = layers.conv2d( conv3 = layers.conv2d(
data=act3, channels=num_filter, kernel_size=(1, 1), data=act3, channels=num_filter, kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), name=name + '_conv3') strides=(1, 1), padding=(0, 0), name=name + '_conv3',
data_layout=data_layout, kernel_layout=kernel_layout)
if dim_match: if dim_match:
shortcut = data shortcut = data
else: else:
shortcut = layers.conv2d( shortcut = layers.conv2d(
data=act1, channels=num_filter, kernel_size=(1, 1), data=act1, channels=num_filter, kernel_size=(1, 1),
strides=stride, name=name+'_sc') strides=stride, name=name+'_sc',
data_layout=data_layout, kernel_layout=kernel_layout)
return relay.add(conv3, shortcut) return relay.add(conv3, shortcut)
bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1') bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = relay.nn.relu(data=bn1) act1 = relay.nn.relu(data=bn1)
conv1 = layers.conv2d( conv1 = layers.conv2d(
data=act1, channels=num_filter, kernel_size=(3, 3), data=act1, channels=num_filter, kernel_size=(3, 3),
strides=stride, padding=(1, 1), name=name + '_conv1') strides=stride, padding=(1, 1), name=name + '_conv1',
data_layout=data_layout, kernel_layout=kernel_layout)
bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = relay.nn.relu(data=bn2) act2 = relay.nn.relu(data=bn2)
conv2 = layers.conv2d( conv2 = layers.conv2d(
data=act2, channels=num_filter, kernel_size=(3, 3), data=act2, channels=num_filter, kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), name=name + '_conv2') strides=(1, 1), padding=(1, 1), name=name + '_conv2',
data_layout=data_layout, kernel_layout=kernel_layout)
if dim_match: if dim_match:
shortcut = data shortcut = data
else: else:
shortcut = layers.conv2d( shortcut = layers.conv2d(
data=act1, channels=num_filter, kernel_size=(1, 1), data=act1, channels=num_filter, kernel_size=(1, 1),
strides=stride, name=name+'_sc') strides=stride, name=name+'_sc',
data_layout=data_layout, kernel_layout=kernel_layout)
return relay.add(conv2, shortcut) return relay.add(conv2, shortcut)
...@@ -112,6 +123,7 @@ def resnet(units, ...@@ -112,6 +123,7 @@ def resnet(units,
num_classes, num_classes,
data_shape, data_shape,
bottle_neck=True, bottle_neck=True,
layout="NCHW",
dtype="float32"): dtype="float32"):
"""Return ResNet Program. """Return ResNet Program.
...@@ -135,9 +147,16 @@ def resnet(units, ...@@ -135,9 +147,16 @@ def resnet(units,
bottle_neck : bool bottle_neck : bool
Whether apply bottleneck transformation. Whether apply bottleneck transformation.
layout: str
The data layout for conv2d
dtype : str dtype : str
The global data type. The global data type.
""" """
data_layout = layout
kernel_layout = "OIHW" if layout == "NCHW" else "HWIO"
num_unit = len(units) num_unit = len(units)
assert num_unit == num_stages assert num_unit == num_stages
data = relay.var("data", shape=data_shape, dtype=dtype) data = relay.var("data", shape=data_shape, dtype=dtype)
...@@ -146,27 +165,32 @@ def resnet(units, ...@@ -146,27 +165,32 @@ def resnet(units,
if height <= 32: # such as cifar10 if height <= 32: # such as cifar10
body = layers.conv2d( body = layers.conv2d(
data=data, channels=filter_list[0], kernel_size=(3, 3), data=data, channels=filter_list[0], kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), name="conv0") strides=(1, 1), padding=(1, 1), name="conv0",
data_layout=data_layout, kernel_layout=kernel_layout)
else: # often expected to be 224 such as imagenet else: # often expected to be 224 such as imagenet
body = layers.conv2d( body = layers.conv2d(
data=data, channels=filter_list[0], kernel_size=(7, 7), data=data, channels=filter_list[0], kernel_size=(7, 7),
strides=(2, 2), padding=(3, 3), name="conv0") strides=(2, 2), padding=(3, 3), name="conv0",
data_layout=data_layout, kernel_layout=kernel_layout)
body = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn0') body = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn0')
body = relay.nn.relu(data=body) body = relay.nn.relu(data=body)
body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1)) body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1),
layout=data_layout)
for i in range(num_stages): for i in range(num_stages):
body = residual_unit( body = residual_unit(
body, filter_list[i+1], (1 if i == 0 else 2, 1 if i == 0 else 2), body, filter_list[i+1], (1 if i == 0 else 2, 1 if i == 0 else 2),
False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck) False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck,
data_layout=data_layout, kernel_layout=kernel_layout)
for j in range(units[i]-1): for j in range(units[i]-1):
body = residual_unit( body = residual_unit(
body, filter_list[i+1], (1, 1), True, body, filter_list[i+1], (1, 1), True,
name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck) name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck,
data_layout=data_layout, kernel_layout=kernel_layout)
bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn1') bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn1')
relu1 = relay.nn.relu(data=bn1) relu1 = relay.nn.relu(data=bn1)
# Although kernel is not used here when global_pool=True, we should put one # Although kernel is not used here when global_pool=True, we should put one
pool1 = relay.nn.global_avg_pool2d(data=relu1) pool1 = relay.nn.global_avg_pool2d(data=relu1, layout=data_layout)
flat = relay.nn.batch_flatten(data=pool1) flat = relay.nn.batch_flatten(data=pool1)
fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1') fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1')
net = relay.nn.softmax(data=fc1) net = relay.nn.softmax(data=fc1)
...@@ -177,6 +201,7 @@ def get_net(batch_size, ...@@ -177,6 +201,7 @@ def get_net(batch_size,
num_classes, num_classes,
num_layers=50, num_layers=50,
image_shape=(3, 224, 224), image_shape=(3, 224, 224),
layout="NCHW",
dtype="float32", dtype="float32",
**kwargs): **kwargs):
""" """
...@@ -229,6 +254,7 @@ def get_net(batch_size, ...@@ -229,6 +254,7 @@ def get_net(batch_size,
num_classes=num_classes, num_classes=num_classes,
data_shape=data_shape, data_shape=data_shape,
bottle_neck=bottle_neck, bottle_neck=bottle_neck,
layout=layout,
dtype=dtype) dtype=dtype)
...@@ -236,6 +262,7 @@ def get_workload(batch_size=1, ...@@ -236,6 +262,7 @@ def get_workload(batch_size=1,
num_classes=1000, num_classes=1000,
num_layers=18, num_layers=18,
image_shape=(3, 224, 224), image_shape=(3, 224, 224),
layout="NCHW",
dtype="float32", dtype="float32",
**kwargs): **kwargs):
"""Get benchmark workload for resnet """Get benchmark workload for resnet
...@@ -254,6 +281,9 @@ def get_workload(batch_size=1, ...@@ -254,6 +281,9 @@ def get_workload(batch_size=1,
image_shape : tuple, optional image_shape : tuple, optional
The input image shape The input image shape
layout: str
The data layout for conv2d
dtype : str, optional dtype : str, optional
The data type The data type
...@@ -273,5 +303,6 @@ def get_workload(batch_size=1, ...@@ -273,5 +303,6 @@ def get_workload(batch_size=1,
num_layers=num_layers, num_layers=num_layers,
image_shape=image_shape, image_shape=image_shape,
dtype=dtype, dtype=dtype,
layout=layout,
**kwargs) **kwargs)
return create_workload(net) return create_workload(net)
...@@ -201,6 +201,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -201,6 +201,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
func = tir::ThreadSync(func, "shared"); func = tir::ThreadSync(func, "shared");
func = tir::ThreadSync(func, "warp"); func = tir::ThreadSync(func, "warp");
func = tir::InferFragment(func);
func = tir::LowerThreadAllreduce(func, target->thread_warp_size); func = tir::LowerThreadAllreduce(func, target->thread_warp_size);
auto fsplits = tir::SplitHostDevice(func); auto fsplits = tir::SplitHostDevice(func);
fhost.push_back(fsplits[0]); fhost.push_back(fsplits[0]);
...@@ -244,6 +245,12 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -244,6 +245,12 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
<< "\n"; << "\n";
} }
for (size_t i = 0; i < fdevice.size(); ++i) {
auto func = fdevice[i];
func = tir::LowerDeviceStorageAccessInfo(func);
fdevice.Set(i, func);
}
for (size_t i = 0; i < fhost.size(); ++i) { for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i]; auto func = fhost[i];
func = tir::BindDeviceType(func, target->device_type); func = tir::BindDeviceType(func, target->device_type);
......
...@@ -43,3 +43,5 @@ from .ssd import * ...@@ -43,3 +43,5 @@ from .ssd import *
from .nms import get_valid_counts, non_max_suppression from .nms import get_valid_counts, non_max_suppression
from .rcnn import * from .rcnn import *
from .sort import * from .sort import *
from .conv2d_nhwc_tensorcore import *
from .dense_tensorcore import *
...@@ -24,6 +24,7 @@ from .. import nn, generic ...@@ -24,6 +24,7 @@ from .. import nn, generic
from ..nn.util import get_pad_tuple from ..nn.util import get_pad_tuple
from ..util import get_const_tuple, traverse_inline from ..util import get_const_tuple, traverse_inline
from .conv2d_direct import schedule_direct_cuda from .conv2d_direct import schedule_direct_cuda
from .conv2d_nhwc import schedule_conv2d_nhwc_direct
@autotvm.register_topi_compute("conv2d_nchw.cuda") @autotvm.register_topi_compute("conv2d_nchw.cuda")
...@@ -46,24 +47,22 @@ def schedule_conv2d_nchw(cfg, outs): ...@@ -46,24 +47,22 @@ def schedule_conv2d_nchw(cfg, outs):
return s return s
# TODO(@alexgl-github): It's invalid to call schedule_direct_cuda for NHWC layout @autotvm.register_topi_compute("conv2d_nhwc.cuda")
# as it assumes the input layout to be NCHW. Please fix this. def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
# @autotvm.register_topi_compute("conv2d_nhwc.cuda") """Compute conv2d with NHWC layout"""
# def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'): return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
# return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
#
# @autotvm.register_topi_schedule("conv2d_nhwc.cuda")
# @autotvm.register_topi_schedule("conv2d_nhwc.cuda") def schedule_conv2d_nhwc(cfg, outs):
# def schedule_conv2d_nhwc(cfg, outs): """Create the schedule for conv2d_nhwc"""
# outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
# s = te.create_schedule([x.op for x in outs]) s = te.create_schedule([x.op for x in outs])
# def _callback(op):
# def _callback(op): if op.tag == 'conv2d_nhwc':
# if op.tag == 'conv2d_nhwc': schedule_conv2d_nhwc_direct(cfg, s, op.output(0))
# schedule_direct_cuda(cfg, s, op.output(0)) traverse_inline(s, outs[0].op, _callback)
# return s
# traverse_inline(s, outs[0].op, _callback)
# return s
@autotvm.register_topi_compute("conv2d_cudnn.cuda") @autotvm.register_topi_compute("conv2d_cudnn.cuda")
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument
"""Direct conv2d in NHWC layout"""
import tvm
from tvm import te
from tvm import autotvm
from ..util import get_const_tuple
def schedule_conv2d_nhwc_direct(cfg, s, Conv):
"""schedule optimized for NHWC direct conv2d"""
pad_data, kernel = s[Conv].op.input_tensors
s[pad_data].compute_inline()
if isinstance(kernel.op, tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
s[kernel].compute_inline()
if Conv.op in s.outputs:
output = Conv
OL = s.cache_write(Conv, 'local')
else:
output = s.outputs[0].output(0)
s[Conv].set_scope('local')
OL = Conv
# create cache stage
AA = s.cache_read(pad_data, 'shared', [OL])
WW = s.cache_read(kernel, "shared", [OL])
AL = s.cache_read(AA, "local", [OL])
WL = s.cache_read(WW, "local", [OL])
# Schedule for autotvm
cfg.define_knob("tile_n", [2, 4, 8])
cfg.define_knob("tile_c", [2, 4, 8])
cfg.define_knob("num_thread_n", [4, 8, 16])
cfg.define_knob("num_thread_c", [4, 8, 16])
cfg.define_knob("vthread_n", [1, 2])
cfg.define_knob("vthread_c", [1, 2])
cfg.define_knob("step", [16, 3, 32, 64])
# fallback support
target = tvm.target.Target.current()
if cfg.is_fallback:
ref_log = autotvm.tophub.load_reference_log(
target.target_name, target.model, 'conv2d_nhwc.cuda')
cfg.fallback_with_reference_log(ref_log)
tile_n = cfg["tile_n"].val
tile_c = cfg["tile_c"].val
num_thread_n = cfg["num_thread_n"].val
num_thread_c = cfg["num_thread_c"].val
vthread_n = cfg["vthread_n"].val
vthread_c = cfg["vthread_c"].val
step = cfg["step"].val
block_factor_c = tile_c * num_thread_c * vthread_c
offset = 8
A_align = step + offset
W_align = block_factor_c + offset
block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
block_z = te.thread_axis("blockIdx.z")
thread_x = te.thread_axis((0, num_thread_c), "threadIdx.x")
thread_y = te.thread_axis((0, num_thread_n), "threadIdx.y")
thread_xz = te.thread_axis((0, vthread_c), "vthread", name="vx")
thread_yz = te.thread_axis((0, vthread_n), "vthread", name="vy")
# Schedule for output
ni, hi, wi, fi = s[output].op.axis
bz = s[output].fuse(hi, wi)
tx, fi = s[output].split(fi, factor=tile_c)
txz, tx = s[output].split(tx, factor=num_thread_c)
bx, txz = s[output].split(txz, factor=vthread_c)
ty, ni = s[output].split(ni, factor=tile_n)
tyz, ty = s[output].split(ty, factor=num_thread_n)
by, tyz = s[output].split(tyz, factor=vthread_n)
s[output].reorder(bz, by, bx, tyz, txz, ty, tx, ni, fi)
s[output].bind(bz, block_z)
s[output].bind(by, block_y)
s[output].bind(bx, block_x)
s[output].bind(tyz, thread_yz)
s[output].bind(txz, thread_xz)
s[output].bind(ty, thread_y)
s[output].bind(tx, thread_x)
# Schedule local computation
s[OL].compute_at(s[output], tx)
ni, yi, xi, fi = s[OL].op.axis
ry, rx, rc = s[OL].op.reduce_axis
rco, rci = s[OL].split(rc, factor=step)
s[OL].reorder(rco, ry, rx, rci, ni, fi)
s[AA].compute_at(s[OL], rx)
s[WW].compute_at(s[OL], rx)
s[AL].compute_at(s[OL], rci)
s[WL].compute_at(s[OL], rci)
# Schedule for data's share memory
ni, yi, xi, ci = s[AA].op.axis
s[AA].reorder(yi, xi, ni, ci)
s[AA].storage_align(xi, A_align - 1, A_align)
t = s[AA].fuse(ni, ci)
ty, tx = s[AA].split(t, factor=num_thread_c)
_, ty = s[AA].split(ty, factor=num_thread_n)
s[AA].bind(tx, thread_x)
s[AA].bind(ty, thread_y)
# Schedule for kernel's share memory
_, _, ic, o = s[WW].op.axis
t = s[WW].fuse(ic, o)
s[WW].storage_align(ic, W_align - 1, W_align)
ty, tx = s[WW].split(t, factor=num_thread_c)
_, ty = s[WW].split(ty, factor=num_thread_n)
s[WW].bind(tx, thread_x)
s[WW].bind(ty, thread_y)
N, OH, OW, CO = get_const_tuple(output.shape)
KH, KW, CI, _ = get_const_tuple(kernel.shape)
cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-function-args
# pylint: disable=too-many-statements, unused-argument, too-many-arguments
"""Tensorcore template for cuda backend"""
import numpy as np
import tvm
from tvm import te
from tvm import autotvm
from ..util import get_const_tuple, traverse_inline, simplify
from ..nn.pad import pad
from ..nn.util import get_pad_tuple
from .tensor_intrin import intrin_wmma_load_matrix_A
from .tensor_intrin import intrin_wmma_load_matrix_W
from .tensor_intrin import intrin_wmma_store_matrix
from .tensor_intrin import intrin_wmma_gemm
def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype):
"""Compute declaration for tensorcore"""
assert isinstance(stride, int) or len(stride) == 2
assert isinstance(dilation, int) or len(dilation) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
batch, in_height, in_width, in_channel = get_const_tuple(Input.shape)
kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape)
assert (batch % 16 == 0 and in_channel % 16 == 0 and num_filter % 16 == 0) or \
(batch % 8 == 0 and in_channel % 16 == 0 and num_filter % 32 == 0) or \
(batch % 32 == 0 and in_channel % 16 == 0 and num_filter % 8 == 0), \
"The shape of (batch, in_channel, num_filter) "\
"must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
# compute the output shape
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = num_filter
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
pad_before = [0, pad_top, pad_left, 0]
pad_after = [0, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
rc = te.reduce_axis((0, in_channel), name='rc')
ry = te.reduce_axis((0, kernel_h), name='ry')
rx = te.reduce_axis((0, kernel_w), name='rx')
# convert data type of input feature maps and weights
TransPaddedInput = te.compute(
PaddedInput.shape,
lambda h, w, i, o: PaddedInput[h, w, i, o].astype('float16'))
TransFilter = te.compute(
Filter.shape, lambda h, w, i, o: Filter[h, w, i, o].astype('float16'))
Output = te.compute(
(batch, out_height, out_width, out_channel),
lambda nn, yy, xx, ff: te.sum(
TransPaddedInput[nn, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
TransFilter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
name="Conv2dOutput", tag="conv2d_nhwc_tensorcore")
return Output
def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
"""Schedule tensorcore template"""
kh, kw, ic = s[Conv].op.reduce_axis
out_dtype = Conv.dtype
trans_paddata, kernel = s[Conv].op.input_tensors
in_dtype = trans_paddata.dtype
batch, _, _, _ = get_const_tuple(Conv.shape)
_, _, _, out_channels = get_const_tuple(kernel.shape)
paddata = s[trans_paddata].op.input_tensors
# inline the pad and dtype transform
s[trans_paddata].compute_inline()
s[kernel].compute_inline()
s[paddata[0]].compute_inline()
# Designate the memory hierarchy
AS = s.cache_read(trans_paddata, 'shared', [Conv])
WS = s.cache_read(kernel, 'shared', [Conv])
AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
ConvF = s.cache_write(Conv, 'wmma.accumulator')
if Conv.op in s.outputs:
output = Conv
ConvS = s.cache_read(ConvF, 'shared', [Conv])
OL = ConvS
else:
output = s.outputs[0].output(0)
s[Conv].set_scope('shared')
OL = Conv
# Schedule for autotvm
cfg.define_knob("block_row_warps", [1, 2, 4])
cfg.define_knob("block_col_warps", [1, 2, 4])
cfg.define_knob("warp_row_tiles", [1, 2, 4])
cfg.define_knob("warp_col_tiles", [1, 2, 4])
cfg.define_knob("chunk", [1, 2, 4, 8])
cfg.define_knob("offset", [0, 8])
cfg.define_knob("vector_width", [1, 2, 4, 8])
if (batch % 16 == 0 and out_channels % 16 == 0):
cfg.define_knob("wmma_m", [16, 8, 32])
elif (batch % 8 == 0 and out_channels % 32 == 0):
cfg.define_knob("wmma_m", [8, 16, 32])
elif (batch % 32 == 0 and out_channels % 8 == 0):
cfg.define_knob("wmma_m", [32, 16, 8])
# fallback support
target = tvm.target.Target.current()
if cfg.is_fallback:
ref_log = autotvm.tophub.load_reference_log(
target.target_name, target.model, 'conv2d_nhwc_tensorcore.cuda')
cfg.fallback_with_reference_log(ref_log)
block_row_warps = cfg["block_row_warps"].val
block_col_warps = cfg["block_col_warps"].val
warp_row_tiles = cfg["warp_row_tiles"].val
warp_col_tiles = cfg["warp_col_tiles"].val
chunk = cfg["chunk"].val
offset = cfg["offset"].val
wmma_m = cfg["wmma_m"].val
vector_width = cfg["vector_width"].val
wmma_k = 16
if wmma_m == 16:
wmma_n = 16
elif wmma_m == 8:
wmma_n = 32
elif wmma_m == 32:
wmma_n = 8
warp_size = 32
block_x = te.thread_axis('blockIdx.x')
block_y = te.thread_axis('blockIdx.y')
block_z = te.thread_axis('blockIdx.z')
thread_x = te.thread_axis('threadIdx.x')
thread_y = te.thread_axis('threadIdx.y')
thread_z = te.thread_axis('threadIdx.z')
# Define the intrin strides
def get_strides(extents):
return [np.prod(extents[i:]).tolist() for i in range(len(extents))]
AS_align = chunk * wmma_k + offset
WS_align = warp_col_tiles * block_col_warps * wmma_n + offset
block_factor_n = wmma_m * warp_row_tiles * block_row_warps
block_factor_o = wmma_n * warp_col_tiles * block_col_warps
CS_align = block_factor_o + offset
AS_strides = get_strides([1, 1, AS_align, 1])
AL_strides = get_strides([1, 1, wmma_k, 1])
WS_strides = get_strides([WS_align, 1])
WL_strides = get_strides([wmma_n * warp_col_tiles, 1])
CL_strides = get_strides([1, 1, wmma_n * warp_col_tiles, 1])
CS_strides = get_strides([1, 1, CS_align, 1])
# Schedule for output
nc, hc, wc, oc = output.op.axis
block_k = s[output].fuse(hc, wc)
s[output].bind(block_k, block_z)
block_i, nc = s[output].split(nc, factor=block_factor_n)
block_j, oc = s[output].split(oc, factor=block_factor_o)
s[output].reorder(block_k, block_i, block_j, nc, oc)
t = s[output].fuse(nc, oc)
t, ti = s[output].split(t, factor=vector_width)
t, tx = s[output].split(t, factor=warp_size)
t, ty = s[output].split(t, factor=block_row_warps)
t, tz = s[output].split(t, factor=block_col_warps)
s[output].bind(block_i, block_x)
s[output].bind(block_j, block_y)
s[output].bind(tz, thread_z)
s[output].bind(ty, thread_y)
s[output].bind(tx, thread_x)
s[output].vectorize(ti)
# Schedule wmma store
s[OL].compute_at(s[output], block_j)
nc, hc, wc, oc = OL.op.axis
s[OL].reorder(hc, wc, nc, oc)
s[OL].storage_align(wc, CS_align - 1, CS_align)
oc, ooc = s[OL].split(oc, factor=wmma_n)
oc, oci = s[OL].split(oc, factor=warp_col_tiles)
_, oc = s[OL].split(oc, factor=block_col_warps)
nc, nnc = s[OL].split(nc, factor=wmma_m)
nc, nci = s[OL].split(nc, factor=warp_row_tiles)
_, nc = s[OL].split(nc, factor=block_row_warps)
s[OL].reorder(nc, oc, nci, oci, nnc, ooc)
s[OL].bind(nc, thread_y)
s[OL].bind(oc, thread_z)
# Schedule wmma computation
s[ConvF].compute_at(s[OL], oc)
n, h, w, o = ConvF.op.axis
n, nnf = s[ConvF].split(n, factor=wmma_m)
o, oof = s[ConvF].split(o, factor=wmma_n)
ic, ii = s[ConvF].split(ic, factor=wmma_k)
ko, ki = s[ConvF].split(ic, factor=chunk)
s[ConvF].reorder(kh, kw, ko, ki, n, o, nnf, oof, ii)
s[AF].compute_at(s[ConvF], ki)
s[WF].compute_at(s[ConvF], ki)
# Schedule wmma load
n, h, w, i = AF.op.axis
n, nn = s[AF].split(n, factor=wmma_m)
i, ii = s[AF].split(i, factor=wmma_k)
s[AF].reorder(n, i, nn, ii)
kh, kw, i, o = WF.op.axis
i, ii = s[WF].split(i, factor=wmma_k)
o, oo = s[WF].split(o, factor=wmma_n)
s[WF].reorder(o, i, oo)
s[WF].reorder(i, o, ii, oo)
s[WS].compute_at(s[ConvF], ko)
s[AS].compute_at(s[ConvF], ko)
# Schedule for data's share memory
n, h, w, i = AS.op.axis
s[AS].reorder(h, w, n, i)
s[AS].storage_align(w, AS_align - 1, AS_align)
t = s[AS].fuse(n, i)
t, ti = s[AS].split(t, factor=vector_width)
t, tx = s[AS].split(t, factor=warp_size)
t, ty = s[AS].split(t, factor=block_row_warps)
_, tz = s[AS].split(t, factor=block_col_warps)
s[AS].bind(ty, thread_y)
s[AS].bind(tz, thread_z)
s[AS].bind(tx, thread_x)
s[AS].vectorize(ti)
# Schedule for kernel's share memory
kh, kw, ic, o = WS.op.axis
t = s[WS].fuse(ic, o)
s[WS].storage_align(ic, WS_align - 1, WS_align)
t, ti = s[WS].split(t, factor=vector_width)
t, tx = s[WS].split(t, factor=warp_size)
t, ty = s[WS].split(t, factor=block_row_warps)
_, tz = s[WS].split(t, factor=block_col_warps)
s[WS].bind(ty, thread_y)
s[WS].bind(tz, thread_z)
s[WS].bind(tx, thread_x)
s[WS].vectorize(ti)
shape = (wmma_m, wmma_n, wmma_k)
# tensorize the wmma process
AS_shape = (wmma_m, 1, 1, wmma_k)
AL_shape = (wmma_m, 1, 1, wmma_k)
WS_shape = (wmma_k, wmma_n)
WL_shape = (wmma_k, wmma_n)
CL_shape = (wmma_m, 1, 1, wmma_n)
CS_shape = (wmma_m, 1, 1, wmma_n)
AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype)
WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype)
k_gemm = te.reduce_axis((0, wmma_k), name="k")
CL_compute = te.compute(CL_shape, lambda ii, t0, t1, jj:
te.sum(AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * \
WL_gemm[k_gemm, jj].astype(out_dtype), axis=k_gemm),
name='C')
s[AF].tensorize(nn, intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape,
"row_major", AS_shape, AL_shape, in_dtype))
s[WF].tensorize(ii, intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape,
"row_major", WS_shape, WL_shape, in_dtype))
s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides,
shape, out_dtype, CL_shape, CS_shape))
s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides,
WL_strides, CL_strides, shape))
N, OH, OW, CO = get_const_tuple(output.shape)
KH, KW, CI, _ = get_const_tuple(kernel.shape)
cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
@autotvm.register_topi_compute("conv2d_nhwc_tensorcore.cuda")
def conv2d_nhwc_tensorcore(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""Compute conv2d with tensorcore for NCHW layout"""
return nhwc_tensorcore_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype)
@autotvm.register_topi_schedule("conv2d_nhwc_tensorcore.cuda")
def schedule_conv2d_nhwc_tensorcore(cfg, outs):
"""TOPI schedule callback"""
s = te.create_schedule([x.op for x in outs])
def _callback(op):
if 'conv2d_nhwc_tensorcore' in op.tag:
schedule_nhwc_tensorcore_cuda(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument
"""Compute and Schedule definition for dense tensorcore with cuda backend"""
from __future__ import absolute_import as _abs
import tvm
from tvm import te
import tvm.autotvm as autotvm
from .. import tag
from ..util import traverse_inline, get_const_tuple
from .tensor_intrin import intrin_wmma_load_matrix_A, \
intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
@autotvm.register_topi_compute("dense_tensorcore.cuda")
def dense_tensorcore(cfg, data, weight, bias=None, out_dtype=None):
"""Dense tensorcore operator on CUDA"""
matmul = dense_tensorcore_cuda(data, weight, bias, out_dtype)
return matmul
@autotvm.register_topi_schedule("dense_tensorcore.cuda")
def schedule_dense_tensorcore(cfg, outs):
"""Schedule dense operator using Tensorcore"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
def _callback(op):
if op.tag == 'dense_tensorcore':
_schedule_dense_tensorcore(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
def dense_tensorcore_cuda(data, weight, bias=None, out_dtype=None):
"""Dense tensorcore operator on CUDA"""
assert len(data.shape) == 2 and len(weight.shape) == 2, \
"only support 2-dim dense"
if bias is not None:
assert len(bias.shape) == 1
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
assert ((batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0) or \
(batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0) or \
(batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0)), \
"The shape of (batch, in_dim, out_dim) "\
"must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
k = te.reduce_axis((0, in_dim), name='k')
data_16 = te.compute((batch, in_dim), lambda b, i: data[b, i].astype('float16'))
weight_16 = te.compute((out_dim, in_dim), lambda o, i: weight[o, i].astype('float16'))
matmul = te.compute((batch, out_dim), \
lambda i, j: te.sum(data_16[i, k].astype(out_dtype) * \
weight_16[j, k].astype(out_dtype), axis=k), \
name='T_dense', tag='dense_tensorcore')
if bias is not None:
matmul = te.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j].astype(out_dtype), \
tag=tag.BROADCAST)
return matmul
def _schedule_dense_tensorcore(cfg, s, C):
"""Schedule dense operator using Tensorcore"""
A, B = s[C].op.input_tensors
batch, out_dim = get_const_tuple(C.shape)
out_dtype = C.dtype
s[A].compute_inline()
s[B].compute_inline()
# Explicit memory access
AS = s.cache_read(A, 'shared', [C])
BS = s.cache_read(B, 'shared', [C])
AF = s.cache_read(AS, 'wmma.matrix_a', [C])
BF = s.cache_read(BS, 'wmma.matrix_b', [C])
CF = s.cache_write(C, 'wmma.accumulator')
CS = s.cache_read(CF, 'shared', [C])
# fallback support
target = tvm.target.Target.current()
if cfg.is_fallback:
ref_log = autotvm.tophub.load_reference_log(
target.target_name, target.model, 'dense_tensorcore.cuda')
cfg.fallback_with_reference_log(ref_log)
# Deal with op fusion, such as bias and relu
if C.op not in s.outputs:
s[C].compute_inline()
C = s.outputs[0].output(0)
# create tuning space
cfg.define_knob("block_row_warps", [1, 2, 4])
cfg.define_knob("block_col_warps", [1, 2, 4])
cfg.define_knob("warp_row_tiles", [1, 2, 4])
cfg.define_knob("warp_col_tiles", [1, 2, 4])
cfg.define_knob("chunk", [1, 2, 4, 8])
cfg.define_knob("offset", [0, 8])
cfg.define_knob("offsetCS", [0, 8])
cfg.define_knob("vec", [1, 2, 4, 8])
#Ensure that the default parameters are applicable when autotvm is not in use
if (batch % 32 == 0 and out_dim % 8 == 0):
cfg.define_knob("wmma_m", [32, 16, 8])
elif (batch%16 == 0 and out_dim % 16 == 0):
cfg.define_knob("wmma_m", [16, 8, 32])
elif (batch % 8 == 0 and out_dim % 32 == 0):
cfg.define_knob("wmma_m", [8, 16, 32])
warp_size = 32
wmma_k = 16
block_row_warps = cfg["block_row_warps"].val
block_col_warps = cfg["block_col_warps"].val
warp_row_tiles = cfg["warp_row_tiles"].val
warp_col_tiles = cfg["warp_col_tiles"].val
chunk = cfg["chunk"].val
offset = cfg["offset"].val
offsetCS = cfg["offsetCS"].val
wmma_m = cfg["wmma_m"].val
vec = cfg["vec"].val
if wmma_m == 16:
wmma_n = 16
elif wmma_m == 8:
wmma_n = 32
elif wmma_m == 32:
wmma_n = 8
#Define the stride of intrin functions
AS_align = chunk * wmma_k + offset
BS_align = chunk * wmma_k + offset
CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
AS_stride = [AS_align, 1]
BS_stride = [BS_align, 1]
AF_stride = [wmma_k, 1]
BF_stride = [wmma_k, 1]
CF_stride = [warp_col_tiles * wmma_n, 1]
CS_stride = [CS_align, 1]
block_x = te.thread_axis('blockIdx.x')
block_y = te.thread_axis('blockIdx.y')
thread_x = te.thread_axis('threadIdx.x')
thread_y = te.thread_axis('threadIdx.y')
thread_z = te.thread_axis('threadIdx.z')
#Schedule for dense computation
block_factor_b = wmma_m * warp_row_tiles * block_row_warps
block_factor_o = wmma_n * warp_col_tiles * block_col_warps
b, o = C.op.axis
block_i, bc = s[C].split(b, factor=block_factor_b)
block_j, oc = s[C].split(o, factor=block_factor_o)
s[C].reorder(block_i, block_j, bc, oc)
t = s[C].fuse(bc, oc)
t, vi = s[C].split(t, factor=vec)
t, tx = s[C].split(t, factor=warp_size)
t, ty = s[C].split(t, factor=block_row_warps)
t, tz = s[C].split(t, factor=block_col_warps)
s[C].bind(block_i, block_x)
s[C].bind(block_j, block_y)
s[C].bind(tz, thread_z)
s[C].bind(ty, thread_y)
s[C].bind(tx, thread_x)
s[C].vectorize(vi)
#Schedule for wmma store
s[CS].compute_at(s[C], block_j)
bb, oo = CS.op.axis
s[CS].storage_align(bb, CS_align - 1, CS_align)
bb, bbi = s[CS].split(bb, factor=wmma_m)
oo, ooi = s[CS].split(oo, factor=wmma_n)
bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
s[CS].reorder(bb, oo, bbii, ooii, bbi, ooi)
#Schedule for wmma computation
s[CF].compute_at(s[CS], oo)
warp_i, warp_j = CF.op.axis
warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
k, = CF.op.reduce_axis
k, _k = s[CF].split(k, factor=wmma_k)
ko, ki = s[CF].split(k, factor=chunk)
s[CF].reorder(ko, ki, warp_i, warp_j, _ii, _jj, _k)
#Schedule for wmma_matrix_a load
s[AF].compute_at(s[CF], ki)
b, i = AF.op.axis
b, b_ii = s[AF].split(b, factor=wmma_m)
i, i_jj = s[AF].split(i, factor=wmma_k)
s[AF].reorder(b, i, b_ii, i_jj)
#Schedule for wmma_matrix_b load
s[BF].compute_at(s[CF], ki)
o, i = BF.op.axis
o, o_ii = s[BF].split(o, factor=wmma_n)
i, i_ii = s[BF].split(i, factor=wmma_k)
s[BF].reorder(o, i, o_ii, i_ii)
#Schedule for A's(B's) shared memory load
def shared_shedule(stage, strides):
s[stage].compute_at(s[CF], ko)
xo, yo = stage.op.axis
s[stage].storage_align(xo, strides - 1, strides)
t = s[stage].fuse(xo, yo)
t, vi = s[stage].split(t, factor=vec)
t, tx = s[stage].split(t, factor=warp_size)
t, ty = s[stage].split(t, factor=block_row_warps)
_, tz = s[stage].split(t, factor=block_col_warps)
s[stage].bind(ty, thread_y)
s[stage].bind(tz, thread_z)
s[stage].bind(tx, thread_x)
s[stage].vectorize(vi)
shared_shedule(AS, AS_align)
shared_shedule(BS, BS_align)
shape = (wmma_m, wmma_n, wmma_k)
in_dtype = 'float16'
AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype)
BL_gemm = te.placeholder((wmma_n, wmma_k), name='BL_gemm', dtype=in_dtype)
k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm')
CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj:
te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) *\
BL_gemm[jj, k_gemm].astype(out_dtype),\
axis=k_gemm), name='CL_compute')
#lower the computation loops down to TensorCore hardware intrinsics
#by mapping the dense tensorcore to tensor intrinsics
s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A( \
AF_stride, AS_stride, shape, "row_major",\
(wmma_m, wmma_k), (wmma_m, wmma_k), 'float16'))
s[BF].tensorize(o_ii, intrin_wmma_load_matrix_W( \
BF_stride, BS_stride, shape, "col_major",\
(wmma_n, wmma_k), (wmma_n, wmma_k), 'float16'))
s[CF].tensorize(_ii, intrin_wmma_gemm( \
AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape))
s[CS].tensorize(bbi, intrin_wmma_store_matrix( \
CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n)))
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=invalid-name, unnecessary-lambda, too-many-arguments
"""Tensor intrinsics on CUDA.""" """Tensor intrinsics on CUDA."""
#pylint: disable=invalid-name
import tvm import tvm
from tvm import te from tvm import te
...@@ -77,3 +77,148 @@ def dp4a(x_scope='local', y_scope='local', z_scope='local'): ...@@ -77,3 +77,148 @@ def dp4a(x_scope='local', y_scope='local', z_scope='local'):
scope=scopes[t]) for t in [x, y, z]} scope=scopes[t]) for t in [x, y, z]}
return te.decl_tensor_intrin(z.op, _intrin_func, binds=binds) return te.decl_tensor_intrin(z.op, _intrin_func, binds=binds)
def intrin_wmma_load_matrix_A(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype):
"""Intrin function for loading data from shared memory to wmma.matrix_a"""
wmma_m, wmma_n, wmma_k = shape
A = te.placeholder(A_shape, name='A', dtype=in_dtype)
BA = tvm.tir.decl_buffer(A.shape, A.dtype,
scope='shared', strides=strides_from,
data_alignment=32, offset_factor=8)
C = te.compute(C_shape, lambda *i: A(*i), name='C')
BC = tvm.tir.decl_buffer(C.shape, C.dtype,
scope="wmma.matrix_a", strides=strides_dst,
data_alignment=32, offset_factor=8)
def intrin_func(ins, outs):
ib = tvm.tir.ir_builder.create()
BA = ins[0]
BC = outs[0]
row = wmma_m * wmma_k
warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_k
ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
BC.data, wmma_m, wmma_n, wmma_k, warp_index,
BA.access_ptr('r'), strides_from[0], layout))
return ib.get()
return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def intrin_wmma_load_matrix_W(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype):
"""Intrin function for loading data from shared memory to wmma.matrix_b"""
wmma_m, wmma_n, wmma_k = shape
A = te.placeholder(A_shape, name='A', dtype=in_dtype)
BA = tvm.tir.decl_buffer(A.shape, A.dtype,
scope='shared', strides=strides_from,
data_alignment=32, offset_factor=8)
C = te.compute(C_shape, lambda *i: A(*i), name='C')
BC = tvm.tir.decl_buffer(C.shape, C.dtype,
scope="wmma.matrix_b", strides=strides_dst,
data_alignment=32, offset_factor=8)
def intrin_func(ins, outs):
ib = tvm.tir.ir_builder.create()
BA = ins[0]
BC = outs[0]
row = wmma_n * wmma_k
warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_n
ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
BC.data, wmma_m, wmma_n, wmma_k, warp_index,
BA.access_ptr('r'), strides_from[0], layout))
return ib.get()
return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def intrin_wmma_store_matrix(strides_dst, strides_from, shape, out_dtype, A_shape, C_shape):
"""Intrin function for storing the results from wmma.accumulator to shared"""
wmma_m, wmma_n, wmma_k = shape
A = te.placeholder(A_shape, name='A', dtype=out_dtype)
BA = tvm.tir.decl_buffer(A.shape, A.dtype,
scope='wmma.accumulator',
strides=strides_from, data_alignment=32,
offset_factor=8)
C = te.compute(C_shape, lambda *i: A(*i), name='C')
BC = tvm.tir.decl_buffer(C.shape, C.dtype,
scope='shared', strides=strides_dst,
data_alignment=32, offset_factor=8)
def intrin_func(ins, outs):
ib = tvm.tir.ir_builder.create()
BA = ins[0]
BC = outs[0]
row = wmma_m * wmma_n
warp_index = BA.elem_offset // row + BA.elem_offset % row // wmma_n
ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync',
BA.data, wmma_m, wmma_n, wmma_k, warp_index,
BC.access_ptr('w'), strides_dst[0], 'row_major'))
return ib.get()
return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, strides_A,
strides_W, strides_Conv, shape):
"""Intrin for wmma fill_fragment and mma_sync
Parameters
----------
AL_gemm : tvm.te.placeholder
wmma matrix A
WL_gemm : tvm.te.placeholder
wmma matrix B
CL_compute : tvm.te.compute
The definition of wmma gemm
"""
wmma_m, wmma_n, wmma_k = shape
A = AL_gemm
B = WL_gemm
C = CL_compute
BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA',
scope='wmma.matrix_a', data_alignment=32,
offset_factor=8, strides=strides_A)
BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB',
scope='wmma.matrix_b', data_alignment=32,
offset_factor=8, strides=strides_W)
BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC',
scope='wmma.accumulator', data_alignment=32,
offset_factor=8, strides=strides_Conv)
def intrin_func(ins, outs):
BA, BB = ins
BC, = outs
def warp_idnex(offset, row, col):
row = row * col
return offset // row + offset % row // col
warp_index_A = warp_idnex(BA.elem_offset, wmma_m, wmma_k)
warp_index_B = warp_idnex(BB.elem_offset, wmma_k, wmma_n)
warp_index_C = warp_idnex(BC.elem_offset, wmma_m, wmma_n)
def init():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, wmma_m, wmma_n, wmma_k,
warp_index_C, 0.0))
return ib.get()
def update():
ib = tvm.tir.ir_builder.create()
ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync',
BC.data, warp_index_C,
BA.data, warp_index_A,
BB.data, warp_index_B,
BC.data, warp_index_C))
return ib.get()
return update(), init(), update()
return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})
...@@ -27,7 +27,8 @@ from topi.util import get_const_tuple ...@@ -27,7 +27,8 @@ from topi.util import get_const_tuple
_conv2d_nhwc_implement = { _conv2d_nhwc_implement = {
"generic": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc), "llvm": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc),
"cuda": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc),
"cpu": (topi.nn.conv2d_nhwc, topi.x86.schedule_conv2d_nhwc), "cpu": (topi.nn.conv2d_nhwc, topi.x86.schedule_conv2d_nhwc),
"arm_cpu": (topi.arm_cpu.conv2d_nhwc_spatial_pack, "arm_cpu": (topi.arm_cpu.conv2d_nhwc_spatial_pack,
topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
...@@ -60,9 +61,9 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -60,9 +61,9 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
return return
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
B = topi.nn.conv2d(A, W, (stride, stride), padding, fcompute, fschedule = topi.testing.dispatch(device, _conv2d_nhwc_implement)
(dilation, dilation), layout='NHWC', out_dtype=dtype) B = fcompute(A, W, stride, padding, dilation, dtype)
s = topi.generic.schedule_conv2d_nhwc([B]) s = fschedule([B])
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx) w = tvm.nd.array(w_np, ctx)
...@@ -71,8 +72,7 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -71,8 +72,7 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
func(a, w, b) func(a, w, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
# TODO(@alexgl-github): add cuda back after fix conv2d_nhwc for cuda for device in ['llvm', 'cuda']:
for device in ['llvm']:
check_device(device) check_device(device)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Example code to do convolution."""
import numpy as np
import tvm
import topi
import topi.testing
from tvm import te
from tvm.contrib.pickle_memoize import memoize
from tvm.contrib import nvcc
from topi.nn.util import get_pad_tuple
from topi.util import get_const_tuple
_conv2d_nhwc_tensorcore_implement = {
"cuda": (topi.cuda.conv2d_nhwc_tensorcore, topi.cuda.schedule_conv2d_nhwc_tensorcore)
}
def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride,
padding, dilation=1, add_bias=False, add_relu=False, devices='cuda'):
"""Test the conv2d with tensorcore for nhwc layout"""
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
padding_sum = pad_top + pad_left + pad_bottom + pad_right
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (
batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
in_height = in_width = in_size
A = te.placeholder((batch, in_height, in_width, in_channel), name='A')
W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W')
bias = te.placeholder((1, 1, 1, num_filter), name='bias')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d_nhwc.verify_conv2d_nhwc")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
if not nvcc.have_tensorcore(ctx.compute_version):
print("skip because gpu does not support Tensor Cores")
return
print("Running on target: %s" % device)
with tvm.target.create(device):
fcompute, fschedule = topi.testing.dispatch(device, _conv2d_nhwc_tensorcore_implement)
C = fcompute(A, W, stride, padding, dilation, 'float32')
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = fschedule([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, c)
rtol = 1e-3
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)
check_device(devices)
def test_conv2d_nhwc_tensorcore():
"""Test the conv2d with tensorcore for nhwc layout"""
verify_conv2d_nhwc(16, 16, 14, 16, 3, 1, 1)
verify_conv2d_nhwc(16, 128, 7, 128, 7, 1, 3)
verify_conv2d_nhwc(16, 160, 7, 160, 7, 1, 3)
verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_bias=True)
verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True)
verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True, add_bias=True)
verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, (3, 3, 2, 2))
verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, "SAME")
verify_conv2d_nhwc(16, 48, 35, 48, 5, 1, "VALID")
verify_conv2d_nhwc(16, 48, 56, 48, 3, 1, (1, 1, 1, 1))
verify_conv2d_nhwc(16, 64, 28, 64, 3, 1, (1, 1, 1, 1))
if __name__ == "__main__":
test_conv2d_nhwc_tensorcore()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument
"""Test code for dense tensorcore operator"""
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
from tvm import te
from tvm.contrib.pickle_memoize import memoize
from tvm.contrib import nvcc
_dense_implement = {
"gpu": [(topi.cuda.dense_tensorcore, topi.cuda.schedule_dense_tensorcore)]
}
def verify_dense(batch, in_dim, out_dim, use_bias=True):
"""Dense tensorcore verify function"""
A = te.placeholder((batch, in_dim), name='A')
B = te.placeholder((out_dim, in_dim), name='B')
C = te.placeholder((out_dim,), name='C')
dtype = A.dtype
# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_dense_tensorcore")
def get_ref_data():
a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
if use_bias:
d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
else:
d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
return (a_np, b_np, c_np, d_np)
# get the test data
a_np, b_np, c_np, d_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
if not nvcc.have_tensorcore(ctx.compute_version):
print("skip because gpu does not support Tensor Cores")
return
print("Running on target: %s" % device)
for fcompute, fschedule in topi.testing.dispatch(device, _dense_implement):
with tvm.target.create(device):
D = fcompute(A, B, C if use_bias else None)
D = topi.nn.relu(D)
s = fschedule([D])
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(c_np, ctx)
d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B, C, D], device, name="dense")
f(a, b, c, d)
tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-3)
for device in ['cuda']:
check_device(device)
def test_dense_tensorcore():
"""Test cases"""
verify_dense(8, 16, 32, use_bias=True)
verify_dense(16, 32, 16, use_bias=True)
verify_dense(256, 1024, 1024, use_bias=True)
verify_dense(1000, 1024, 1024, use_bias=False)
verify_dense(256, 2048, 1000, use_bias=False)
if __name__ == "__main__":
test_dense_tensorcore()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment