Unverified commit f4286cc7 by Shawn-Inspur, committed by GitHub

[TOPI][Tensor Core] Conv2d and Dense ops support on Tensor Core (#5099)

* [TOPI][Tensor Core] Optimization of CNNs on Tensor Core #6004

* update conv2d test

* fix pylint warnings in dense_tensorcore.py

* modify

* modify conv2d

* clarify the unclear comment, add shape assertion in conv2d compute, combine general gemm intrinsic

* add shape assertion in conv2d compute, combine general gemm intrinsic

Co-authored-by: libaihong <libaihong@inspur.com>
Co-authored-by: libaihong <61525430+libaihong@users.noreply.github.com>
parent 949dca4d
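The strategy changes below dispatch NHWC conv2d and dense to Tensor Core schedules whenever the GPU reports Tensor Core support and the shapes tile into WMMA fragments. A minimal sketch of triggering the new path from Relay (the shapes and the TVM 0.7-dev API calls here are illustrative assumptions, not part of this commit):

import tvm
from tvm import relay

# A 3x3 NHWC conv2d whose batch and channel dimensions are multiples of 16, so the
# conv2d_nhwc_tensorcore implementation registered below is eligible on a Tensor Core GPU.
data = relay.var("data", shape=(16, 56, 56, 64), dtype="float32")
weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="float32")
out = relay.nn.conv2d(data, weight, padding=(1, 1),
                      data_layout="NHWC", kernel_layout="HWIO")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target="cuda")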
......@@ -17,7 +17,9 @@
"""Definition of CUDA/GPU operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import topi
import tvm
from tvm.te import SpecializedCondition
from tvm.contrib import nvcc
from .generic import *
from .. import op as _op
from .... import get_global_func
......@@ -112,13 +114,23 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
name="conv2d_hwcn.cuda")
# TODO(@alexgl-github): Re-enable this after fix the conv2d_nhwc for cuda
# elif layout == "NHWC":
# assert kernel_layout == "HWIO"
# strategy.add_implementation(
# wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
# wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
# name="conv2d_nhwc.cuda")
elif layout == "NHWC":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
name="conv2d_nhwc.cuda")
N, _, _, _ = get_const_tuple(data.shape)
_, _, CI, CO = get_const_tuple(kernel.shape)
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
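# The three accepted combinations match the WMMA fragment shapes 16x16x16, 8x32x16
# and 32x8x16 (m, n, k): batch maps to m, output channels to n, input channels to k.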
if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
(N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
(N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc_tensorcore),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
name="conv2d_nhwc_tensorcore.cuda",
plevel=20)
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
......@@ -279,6 +291,9 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
def dense_strategy_cuda(attrs, inputs, out_type, target):
"""dense cuda strategy"""
strategy = _op.OpStrategy()
data, weights = inputs
b, i = get_const_tuple(data.shape)
o, _ = get_const_tuple(weights.shape)
if out_type.dtype == "int8":
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_int8),
......@@ -289,13 +304,21 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_dense(topi.cuda.dense_small_batch),
wrap_topi_schedule(topi.cuda.schedule_dense_small_batch),
name="dense_small_batch.cuda")
b = inputs[0].shape[0]
with SpecializedCondition(b >= 32):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_large_batch),
wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
name="dense_large_batch.cuda",
plevel=5)
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
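# Same WMMA tiling rule as the conv2d strategy: (batch, out_dim, in_dim) must be
# divisible by one of the (m, n, k) fragment shapes 16x16x16, 8x32x16 or 32x8x16.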
if (i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \
or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \
or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_tensorcore),
wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
name="dense_tensorcore.cuda",
plevel=20)
if target.target_name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_cublas),
......
......@@ -32,7 +32,10 @@ def residual_unit(data,
stride,
dim_match,
name,
bottle_neck=True):
bottle_neck=True,
data_layout="NCHW",
kernel_layout="OIHW"
):
"""Return ResNet Unit symbol for building ResNet
Parameters
......@@ -67,42 +70,50 @@ def residual_unit(data,
kernel_size=(1, 1),
strides=stride,
padding=(0, 0),
name=name + '_conv1')
name=name + '_conv1',
data_layout=data_layout,
kernel_layout=kernel_layout)
bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = relay.nn.relu(data=bn2)
conv2 = layers.conv2d(
data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), name=name + '_conv2')
strides=(1, 1), padding=(1, 1), name=name + '_conv2',
data_layout=data_layout, kernel_layout=kernel_layout)
bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + '_bn3')
act3 = relay.nn.relu(data=bn3)
conv3 = layers.conv2d(
data=act3, channels=num_filter, kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), name=name + '_conv3')
strides=(1, 1), padding=(0, 0), name=name + '_conv3',
data_layout=data_layout, kernel_layout=kernel_layout)
if dim_match:
shortcut = data
else:
shortcut = layers.conv2d(
data=act1, channels=num_filter, kernel_size=(1, 1),
strides=stride, name=name+'_sc')
strides=stride, name=name+'_sc',
data_layout=data_layout, kernel_layout=kernel_layout)
return relay.add(conv3, shortcut)
bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = relay.nn.relu(data=bn1)
conv1 = layers.conv2d(
data=act1, channels=num_filter, kernel_size=(3, 3),
strides=stride, padding=(1, 1), name=name + '_conv1')
strides=stride, padding=(1, 1), name=name + '_conv1',
data_layout=data_layout, kernel_layout=kernel_layout)
bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = relay.nn.relu(data=bn2)
conv2 = layers.conv2d(
data=act2, channels=num_filter, kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), name=name + '_conv2')
strides=(1, 1), padding=(1, 1), name=name + '_conv2',
data_layout=data_layout, kernel_layout=kernel_layout)
if dim_match:
shortcut = data
else:
shortcut = layers.conv2d(
data=act1, channels=num_filter, kernel_size=(1, 1),
strides=stride, name=name+'_sc')
strides=stride, name=name+'_sc',
data_layout=data_layout, kernel_layout=kernel_layout)
return relay.add(conv2, shortcut)
......@@ -112,6 +123,7 @@ def resnet(units,
num_classes,
data_shape,
bottle_neck=True,
layout="NCHW",
dtype="float32"):
"""Return ResNet Program.
......@@ -135,9 +147,16 @@ def resnet(units,
bottle_neck : bool
Whether apply bottleneck transformation.
layout: str
The data layout for conv2d
dtype : str
The global data type.
"""
data_layout = layout
kernel_layout = "OIHW" if layout == "NCHW" else "HWIO"
num_unit = len(units)
assert num_unit == num_stages
data = relay.var("data", shape=data_shape, dtype=dtype)
......@@ -146,27 +165,32 @@ def resnet(units,
if height <= 32: # such as cifar10
body = layers.conv2d(
data=data, channels=filter_list[0], kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), name="conv0")
strides=(1, 1), padding=(1, 1), name="conv0",
data_layout=data_layout, kernel_layout=kernel_layout)
else: # often expected to be 224 such as imagenet
body = layers.conv2d(
data=data, channels=filter_list[0], kernel_size=(7, 7),
strides=(2, 2), padding=(3, 3), name="conv0")
strides=(2, 2), padding=(3, 3), name="conv0",
data_layout=data_layout, kernel_layout=kernel_layout)
body = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn0')
body = relay.nn.relu(data=body)
body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1))
body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1),
layout=data_layout)
for i in range(num_stages):
body = residual_unit(
body, filter_list[i+1], (1 if i == 0 else 2, 1 if i == 0 else 2),
False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck)
False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck,
data_layout=data_layout, kernel_layout=kernel_layout)
for j in range(units[i]-1):
body = residual_unit(
body, filter_list[i+1], (1, 1), True,
name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck)
name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck,
data_layout=data_layout, kernel_layout=kernel_layout)
bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn1')
relu1 = relay.nn.relu(data=bn1)
# Although the kernel is not used when global_pool=True, we should still provide one
pool1 = relay.nn.global_avg_pool2d(data=relu1)
pool1 = relay.nn.global_avg_pool2d(data=relu1, layout=data_layout)
flat = relay.nn.batch_flatten(data=pool1)
fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1')
net = relay.nn.softmax(data=fc1)
......@@ -177,6 +201,7 @@ def get_net(batch_size,
num_classes,
num_layers=50,
image_shape=(3, 224, 224),
layout="NCHW",
dtype="float32",
**kwargs):
"""
......@@ -229,6 +254,7 @@ def get_net(batch_size,
num_classes=num_classes,
data_shape=data_shape,
bottle_neck=bottle_neck,
layout=layout,
dtype=dtype)
......@@ -236,6 +262,7 @@ def get_workload(batch_size=1,
num_classes=1000,
num_layers=18,
image_shape=(3, 224, 224),
layout="NCHW",
dtype="float32",
**kwargs):
"""Get benchmark workload for resnet
......@@ -254,6 +281,9 @@ def get_workload(batch_size=1,
image_shape : tuple, optional
The input image shape
layout: str
The data layout for conv2d
dtype : str, optional
The data type
......@@ -273,5 +303,6 @@ def get_workload(batch_size=1,
num_layers=num_layers,
image_shape=image_shape,
dtype=dtype,
layout=layout,
**kwargs)
return create_workload(net)
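With layout now threaded through the builder functions, the benchmark ResNet can be requested directly in NHWC. A hedged usage sketch (the HWC ordering of image_shape is an assumption for illustration):

from tvm.relay.testing import resnet

# layout="NHWC" makes every conv2d in the network use data_layout NHWC and kernel_layout HWIO.
net, params = resnet.get_workload(batch_size=16, num_layers=50,
                                  image_shape=(224, 224, 3), layout="NHWC")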
......@@ -201,6 +201,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
func = tir::ThreadSync(func, "shared");
func = tir::ThreadSync(func, "warp");
func = tir::InferFragment(func);
func = tir::LowerThreadAllreduce(func, target->thread_warp_size);
auto fsplits = tir::SplitHostDevice(func);
fhost.push_back(fsplits[0]);
......@@ -244,6 +245,12 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
<< "\n";
}
for (size_t i = 0; i < fdevice.size(); ++i) {
auto func = fdevice[i];
func = tir::LowerDeviceStorageAccessInfo(func);
fdevice.Set(i, func);
}
for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = tir::BindDeviceType(func, target->device_type);
......
......@@ -43,3 +43,5 @@ from .ssd import *
from .nms import get_valid_counts, non_max_suppression
from .rcnn import *
from .sort import *
from .conv2d_nhwc_tensorcore import *
from .dense_tensorcore import *
......@@ -24,6 +24,7 @@ from .. import nn, generic
from ..nn.util import get_pad_tuple
from ..util import get_const_tuple, traverse_inline
from .conv2d_direct import schedule_direct_cuda
from .conv2d_nhwc import schedule_conv2d_nhwc_direct
@autotvm.register_topi_compute("conv2d_nchw.cuda")
......@@ -46,24 +47,22 @@ def schedule_conv2d_nchw(cfg, outs):
return s
# TODO(@alexgl-github): It's invalid to call schedule_direct_cuda for NHWC layout
# as it assumes the input layout to be NCHW. Please fix this.
# @autotvm.register_topi_compute("conv2d_nhwc.cuda")
# def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
# return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
#
#
# @autotvm.register_topi_schedule("conv2d_nhwc.cuda")
# def schedule_conv2d_nhwc(cfg, outs):
# outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
# s = te.create_schedule([x.op for x in outs])
#
# def _callback(op):
# if op.tag == 'conv2d_nhwc':
# schedule_direct_cuda(cfg, s, op.output(0))
#
# traverse_inline(s, outs[0].op, _callback)
# return s
@autotvm.register_topi_compute("conv2d_nhwc.cuda")
def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
"""Compute conv2d with NHWC layout"""
return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
@autotvm.register_topi_schedule("conv2d_nhwc.cuda")
def schedule_conv2d_nhwc(cfg, outs):
"""Create the schedule for conv2d_nhwc"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
def _callback(op):
if op.tag == 'conv2d_nhwc':
schedule_conv2d_nhwc_direct(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
@autotvm.register_topi_compute("conv2d_cudnn.cuda")
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument
"""Direct conv2d in NHWC layout"""
import tvm
from tvm import te
from tvm import autotvm
from ..util import get_const_tuple
def schedule_conv2d_nhwc_direct(cfg, s, Conv):
"""schedule optimized for NHWC direct conv2d"""
pad_data, kernel = s[Conv].op.input_tensors
s[pad_data].compute_inline()
if isinstance(kernel.op, tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
s[kernel].compute_inline()
if Conv.op in s.outputs:
output = Conv
OL = s.cache_write(Conv, 'local')
else:
output = s.outputs[0].output(0)
s[Conv].set_scope('local')
OL = Conv
# create cache stage
AA = s.cache_read(pad_data, 'shared', [OL])
WW = s.cache_read(kernel, "shared", [OL])
AL = s.cache_read(AA, "local", [OL])
WL = s.cache_read(WW, "local", [OL])
# Schedule for autotvm
cfg.define_knob("tile_n", [2, 4, 8])
cfg.define_knob("tile_c", [2, 4, 8])
cfg.define_knob("num_thread_n", [4, 8, 16])
cfg.define_knob("num_thread_c", [4, 8, 16])
cfg.define_knob("vthread_n", [1, 2])
cfg.define_knob("vthread_c", [1, 2])
cfg.define_knob("step", [16, 3, 32, 64])
# fallback support
target = tvm.target.Target.current()
if cfg.is_fallback:
ref_log = autotvm.tophub.load_reference_log(
target.target_name, target.model, 'conv2d_nhwc.cuda')
cfg.fallback_with_reference_log(ref_log)
tile_n = cfg["tile_n"].val
tile_c = cfg["tile_c"].val
num_thread_n = cfg["num_thread_n"].val
num_thread_c = cfg["num_thread_c"].val
vthread_n = cfg["vthread_n"].val
vthread_c = cfg["vthread_c"].val
step = cfg["step"].val
block_factor_c = tile_c * num_thread_c * vthread_c
offset = 8
A_align = step + offset
W_align = block_factor_c + offset
block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
block_z = te.thread_axis("blockIdx.z")
thread_x = te.thread_axis((0, num_thread_c), "threadIdx.x")
thread_y = te.thread_axis((0, num_thread_n), "threadIdx.y")
thread_xz = te.thread_axis((0, vthread_c), "vthread", name="vx")
thread_yz = te.thread_axis((0, vthread_n), "vthread", name="vy")
# Schedule for output
ni, hi, wi, fi = s[output].op.axis
bz = s[output].fuse(hi, wi)
tx, fi = s[output].split(fi, factor=tile_c)
txz, tx = s[output].split(tx, factor=num_thread_c)
bx, txz = s[output].split(txz, factor=vthread_c)
ty, ni = s[output].split(ni, factor=tile_n)
tyz, ty = s[output].split(ty, factor=num_thread_n)
by, tyz = s[output].split(tyz, factor=vthread_n)
s[output].reorder(bz, by, bx, tyz, txz, ty, tx, ni, fi)
s[output].bind(bz, block_z)
s[output].bind(by, block_y)
s[output].bind(bx, block_x)
s[output].bind(tyz, thread_yz)
s[output].bind(txz, thread_xz)
s[output].bind(ty, thread_y)
s[output].bind(tx, thread_x)
# Schedule local computation
s[OL].compute_at(s[output], tx)
ni, yi, xi, fi = s[OL].op.axis
ry, rx, rc = s[OL].op.reduce_axis
rco, rci = s[OL].split(rc, factor=step)
s[OL].reorder(rco, ry, rx, rci, ni, fi)
s[AA].compute_at(s[OL], rx)
s[WW].compute_at(s[OL], rx)
s[AL].compute_at(s[OL], rci)
s[WL].compute_at(s[OL], rci)
# Schedule for data's shared memory
ni, yi, xi, ci = s[AA].op.axis
s[AA].reorder(yi, xi, ni, ci)
s[AA].storage_align(xi, A_align - 1, A_align)
t = s[AA].fuse(ni, ci)
ty, tx = s[AA].split(t, factor=num_thread_c)
_, ty = s[AA].split(ty, factor=num_thread_n)
s[AA].bind(tx, thread_x)
s[AA].bind(ty, thread_y)
# Schedule for kernel's shared memory
_, _, ic, o = s[WW].op.axis
t = s[WW].fuse(ic, o)
s[WW].storage_align(ic, W_align - 1, W_align)
ty, tx = s[WW].split(t, factor=num_thread_c)
_, ty = s[WW].split(ty, factor=num_thread_n)
s[WW].bind(tx, thread_x)
s[WW].bind(ty, thread_y)
N, OH, OW, CO = get_const_tuple(output.shape)
KH, KW, CI, _ = get_const_tuple(kernel.shape)
cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
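For reference, the re-enabled compute and this direct schedule can also be driven at the TE level, following the same call pattern as the updated test later in this commit (shapes here are illustrative):

import tvm
from tvm import te
import topi

A = te.placeholder((16, 56, 56, 64), name="A")
W = te.placeholder((3, 3, 64, 64), name="W")
with tvm.target.create("cuda"):
    # Untuned runs fall back to the reference log / default knob values defined above.
    B = topi.cuda.conv2d_nhwc(A, W, (1, 1), (1, 1), (1, 1), "float32")
    s = topi.cuda.schedule_conv2d_nhwc([B])
func = tvm.build(s, [A, W, B], "cuda")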
......@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unnecessary-lambda, too-many-arguments
"""Tensor intrinsics on CUDA."""
#pylint: disable=invalid-name
import tvm
from tvm import te
......@@ -77,3 +77,148 @@ def dp4a(x_scope='local', y_scope='local', z_scope='local'):
scope=scopes[t]) for t in [x, y, z]}
return te.decl_tensor_intrin(z.op, _intrin_func, binds=binds)
def intrin_wmma_load_matrix_A(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype):
"""Intrin function for loading data from shared memory to wmma.matrix_a"""
wmma_m, wmma_n, wmma_k = shape
A = te.placeholder(A_shape, name='A', dtype=in_dtype)
BA = tvm.tir.decl_buffer(A.shape, A.dtype,
scope='shared', strides=strides_from,
data_alignment=32, offset_factor=8)
C = te.compute(C_shape, lambda *i: A(*i), name='C')
BC = tvm.tir.decl_buffer(C.shape, C.dtype,
scope="wmma.matrix_a", strides=strides_dst,
data_alignment=32, offset_factor=8)
def intrin_func(ins, outs):
ib = tvm.tir.ir_builder.create()
BA = ins[0]
BC = outs[0]
row = wmma_m * wmma_k
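# One matrix_a fragment covers wmma_m * wmma_k elements; derive the fragment index
# inside the warp-scope buffer from the buffer's element offset.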
warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_k
ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
BC.data, wmma_m, wmma_n, wmma_k, warp_index,
BA.access_ptr('r'), strides_from[0], layout))
return ib.get()
return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def intrin_wmma_load_matrix_W(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype):
"""Intrin function for loading data from shared memory to wmma.matrix_b"""
wmma_m, wmma_n, wmma_k = shape
A = te.placeholder(A_shape, name='A', dtype=in_dtype)
BA = tvm.tir.decl_buffer(A.shape, A.dtype,
scope='shared', strides=strides_from,
data_alignment=32, offset_factor=8)
C = te.compute(C_shape, lambda *i: A(*i), name='C')
BC = tvm.tir.decl_buffer(C.shape, C.dtype,
scope="wmma.matrix_b", strides=strides_dst,
data_alignment=32, offset_factor=8)
def intrin_func(ins, outs):
ib = tvm.tir.ir_builder.create()
BA = ins[0]
BC = outs[0]
row = wmma_n * wmma_k
warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_n
ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
BC.data, wmma_m, wmma_n, wmma_k, warp_index,
BA.access_ptr('r'), strides_from[0], layout))
return ib.get()
return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def intrin_wmma_store_matrix(strides_dst, strides_from, shape, out_dtype, A_shape, C_shape):
"""Intrin function for storing the results from wmma.accumulator to shared"""
wmma_m, wmma_n, wmma_k = shape
A = te.placeholder(A_shape, name='A', dtype=out_dtype)
BA = tvm.tir.decl_buffer(A.shape, A.dtype,
scope='wmma.accumulator',
strides=strides_from, data_alignment=32,
offset_factor=8)
C = te.compute(C_shape, lambda *i: A(*i), name='C')
BC = tvm.tir.decl_buffer(C.shape, C.dtype,
scope='shared', strides=strides_dst,
data_alignment=32, offset_factor=8)
def intrin_func(ins, outs):
ib = tvm.tir.ir_builder.create()
BA = ins[0]
BC = outs[0]
row = wmma_m * wmma_n
warp_index = BA.elem_offset // row + BA.elem_offset % row // wmma_n
ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync',
BA.data, wmma_m, wmma_n, wmma_k, warp_index,
BC.access_ptr('w'), strides_dst[0], 'row_major'))
return ib.get()
return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, strides_A,
strides_W, strides_Conv, shape):
"""Intrin for wmma fill_fragment and mma_sync
Parameters
----------
AL_gemm : tvm.te.placeholder
wmma matrix A
WL_gemm : tvm.te.placeholder
wmma matrix B
CL_compute : tvm.te.compute
The definition of wmma gemm
strides_A : list
Strides of the wmma.matrix_a buffer
strides_W : list
Strides of the wmma.matrix_b buffer
strides_Conv : list
Strides of the wmma.accumulator (output) buffer
shape : tuple
(wmma_m, wmma_n, wmma_k), the WMMA fragment shape
"""
wmma_m, wmma_n, wmma_k = shape
A = AL_gemm
B = WL_gemm
C = CL_compute
BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA',
scope='wmma.matrix_a', data_alignment=32,
offset_factor=8, strides=strides_A)
BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB',
scope='wmma.matrix_b', data_alignment=32,
offset_factor=8, strides=strides_W)
BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC',
scope='wmma.accumulator', data_alignment=32,
offset_factor=8, strides=strides_Conv)
def intrin_func(ins, outs):
BA, BB = ins
BC, = outs
def warp_index(offset, row, col):
row = row * col
return offset // row + offset % row // col
warp_index_A = warp_index(BA.elem_offset, wmma_m, wmma_k)
warp_index_B = warp_index(BB.elem_offset, wmma_k, wmma_n)
warp_index_C = warp_index(BC.elem_offset, wmma_m, wmma_n)
def init():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, wmma_m, wmma_n, wmma_k,
warp_index_C, 0.0))
return ib.get()
def update():
ib = tvm.tir.ir_builder.create()
ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync',
BC.data, warp_index_C,
BA.data, warp_index_A,
BB.data, warp_index_B,
BC.data, warp_index_C))
return ib.get()
return update(), init(), update()
return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})
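These intrinsics are consumed by the Tensor Core schedules through s[stage].tensorize(axis, intrin). A tiny standalone sketch that merely declares one of them, using assumed unit-row strides (the real schedules pass their shared-memory strides instead):

from topi.cuda.tensor_intrin import intrin_wmma_load_matrix_A

# Declare a 16x16x16 load of a row-major 16x16 float16 tile into wmma.matrix_a.
load_a = intrin_wmma_load_matrix_A([16, 1], [16, 1], (16, 16, 16), "row_major",
                                   (16, 16), (16, 16), "float16")
print(load_a)  # a TensorIntrin that a schedule can apply via tensorize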
......@@ -27,7 +27,8 @@ from topi.util import get_const_tuple
_conv2d_nhwc_implement = {
"generic": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc),
"llvm": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc),
"cuda": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc),
"cpu": (topi.nn.conv2d_nhwc, topi.x86.schedule_conv2d_nhwc),
"arm_cpu": (topi.arm_cpu.conv2d_nhwc_spatial_pack,
topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
......@@ -60,9 +61,9 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
return
print("Running on target: %s" % device)
with tvm.target.create(device):
B = topi.nn.conv2d(A, W, (stride, stride), padding,
(dilation, dilation), layout='NHWC', out_dtype=dtype)
s = topi.generic.schedule_conv2d_nhwc([B])
fcompute, fschedule = topi.testing.dispatch(device, _conv2d_nhwc_implement)
B = fcompute(A, W, stride, padding, dilation, dtype)
s = fschedule([B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
......@@ -71,8 +72,7 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
func(a, w, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
# TODO(@alexgl-github): add cuda back after fix conv2d_nhwc for cuda
for device in ['llvm']:
for device in ['llvm', 'cuda']:
check_device(device)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Example code to do convolution."""
import numpy as np
import tvm
import topi
import topi.testing
from tvm import te
from tvm.contrib.pickle_memoize import memoize
from tvm.contrib import nvcc
from topi.nn.util import get_pad_tuple
from topi.util import get_const_tuple
_conv2d_nhwc_tensorcore_implement = {
"cuda": (topi.cuda.conv2d_nhwc_tensorcore, topi.cuda.schedule_conv2d_nhwc_tensorcore)
}
def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride,
padding, dilation=1, add_bias=False, add_relu=False, devices='cuda'):
"""Test the conv2d with tensorcore for nhwc layout"""
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
padding_sum = pad_top + pad_left + pad_bottom + pad_right
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (
batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
in_height = in_width = in_size
A = te.placeholder((batch, in_height, in_width, in_channel), name='A')
W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W')
bias = te.placeholder((1, 1, 1, num_filter), name='bias')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d_nhwc.verify_conv2d_nhwc")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
if not nvcc.have_tensorcore(ctx.compute_version):
print("skip because gpu does not support Tensor Cores")
return
print("Running on target: %s" % device)
with tvm.target.create(device):
fcompute, fschedule = topi.testing.dispatch(device, _conv2d_nhwc_tensorcore_implement)
C = fcompute(A, W, stride, padding, dilation, 'float32')
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = fschedule([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, c)
rtol = 1e-3
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)
check_device(devices)
def test_conv2d_nhwc_tensorcore():
"""Test the conv2d with tensorcore for nhwc layout"""
verify_conv2d_nhwc(16, 16, 14, 16, 3, 1, 1)
verify_conv2d_nhwc(16, 128, 7, 128, 7, 1, 3)
verify_conv2d_nhwc(16, 160, 7, 160, 7, 1, 3)
verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_bias=True)
verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True)
verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True, add_bias=True)
verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, (3, 3, 2, 2))
verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, "SAME")
verify_conv2d_nhwc(16, 48, 35, 48, 5, 1, "VALID")
verify_conv2d_nhwc(16, 48, 56, 48, 3, 1, (1, 1, 1, 1))
verify_conv2d_nhwc(16, 64, 28, 64, 3, 1, (1, 1, 1, 1))
if __name__ == "__main__":
test_conv2d_nhwc_tensorcore()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument
"""Test code for dense tensorcore operator"""
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
from tvm import te
from tvm.contrib.pickle_memoize import memoize
from tvm.contrib import nvcc
_dense_implement = {
"gpu": [(topi.cuda.dense_tensorcore, topi.cuda.schedule_dense_tensorcore)]
}
def verify_dense(batch, in_dim, out_dim, use_bias=True):
"""Dense tensorcore verify function"""
A = te.placeholder((batch, in_dim), name='A')
B = te.placeholder((out_dim, in_dim), name='B')
C = te.placeholder((out_dim,), name='C')
dtype = A.dtype
# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_dense_tensorcore")
def get_ref_data():
a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
if use_bias:
d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
else:
d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
return (a_np, b_np, c_np, d_np)
# get the test data
a_np, b_np, c_np, d_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
if not nvcc.have_tensorcore(ctx.compute_version):
print("skip because gpu does not support Tensor Cores")
return
print("Running on target: %s" % device)
for fcompute, fschedule in topi.testing.dispatch(device, _dense_implement):
with tvm.target.create(device):
D = fcompute(A, B, C if use_bias else None)
D = topi.nn.relu(D)
s = fschedule([D])
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(c_np, ctx)
d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B, C, D], device, name="dense")
f(a, b, c, d)
tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-3)
for device in ['cuda']:
check_device(device)
def test_dense_tensorcore():
"""Test cases"""
verify_dense(8, 16, 32, use_bias=True)
verify_dense(16, 32, 16, use_bias=True)
verify_dense(256, 1024, 1024, use_bias=True)
verify_dense(1000, 1024, 1024, use_bias=False)
verify_dense(256, 2048, 1000, use_bias=False)
if __name__ == "__main__":
test_dense_tensorcore()