Unverified commit 3df8d560 by Josh Fromm, committed by GitHub

[Topi] Tensorcore support for Conv3D (#5284)

* One weird trick.

* Added schedule knob for different workloads.

* Initial conv3d tensorcore working.

* Added conv3d tensorcore strategy.

* Added layout conversion to tensorcore friendly format for conv2d and conv3d.

* Add target name check.

* Fixed bad names and depthwise check.

* Removed duplicated attribute assignment.
parent 0d48361a
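The layout-conversion changes below register converters that the Relay ConvertLayout pass calls into, so conv2d/conv3d graphs can be rewritten into the NHWC/NDHWC layouts the tensorcore schedules expect. As a hedged usage sketch (not part of this diff; at the time of this commit the pass took a single layout string, while later TVM versions take a per-op dict):

import tvm
from tvm import relay

def convert_to_nhwc(mod):
    """Rewrite layout-registered ops (e.g. nn.conv2d) into NHWC."""
    seq = tvm.transform.Sequential([
        relay.transform.ConvertLayout("NHWC"),
        relay.transform.FoldConstant(),  # fold the layout transforms inserted on weights
    ])
    with tvm.transform.PassContext(opt_level=3):
        return seq(mod)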
@@ -27,6 +27,7 @@ from .. import op as reg
 from .. import strategy
 from ..op import OpPattern
 from .._tensor import elemwise_shape_func
+from ..strategy.generic import is_depthwise_conv2d

 # relu
 reg.register_broadcast_schedule("nn.relu")
@@ -139,13 +140,21 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
     # pylint: disable=import-outside-toplevel
     from tvm import relay
     data, weight = inputs
-    assert desired_layout == 'NCHW', \
-        "Currently only transformation to NCHW layout is supported."
+    new_attrs = dict(attrs)
+    new_attrs['data_layout'] = desired_layout
     if desired_layout == 'NCHW':
-        new_attrs = dict(attrs)
-        new_attrs['data_layout'] = desired_layout
         new_attrs['kernel_layout'] = 'OIHW'
         return relay.nn.conv2d(data, weight, **new_attrs)
+    elif desired_layout == 'NHWC':
+        # Check for depthwise convolution.
+        if is_depthwise_conv2d(data.shape, attrs['data_layout'], weight.shape,
+                               attrs['kernel_layout'], attrs['groups']):
+            new_attrs['kernel_layout'] = 'HWOI'
+        else:
+            new_attrs['kernel_layout'] = 'HWIO'
+        return relay.nn.conv2d(data, weight, **new_attrs)
     else:
         assert "Layout %s is not yet supported." % (desired_layout)
     return None
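A caveat about the fallback branch (an observation, not a change made by the commit): `assert "Layout %s is not yet supported." % (desired_layout)` asserts on a non-empty string, which is always truthy, so an unsupported layout never trips the assert and the function silently returns None. A sketch of a stricter check, using a hypothetical helper name:

def _check_layout_supported(desired_layout, supported=('NCHW', 'NHWC')):
    # A formatted string literal is always truthy, so `assert "msg" % x`
    # can never fire; raising makes the unsupported case an explicit error.
    if desired_layout not in supported:
        raise ValueError("Layout %s is not yet supported." % desired_layout)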
@@ -183,6 +192,41 @@ def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type):
     """Alternate the layout of conv3d"""
     return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type)

+@reg.register_convert_op_layout("nn.conv3d")
+def convert_conv3d(attrs, inputs, tinfos, desired_layout):
+    """Convert Layout pass registration for conv3d op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layout : str
+        The desired layout
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm import relay
+    data, weight = inputs
+    new_attrs = dict(attrs)
+    new_attrs['data_layout'] = desired_layout
+    if desired_layout == 'NCDHW':
+        new_attrs['kernel_layout'] = 'OIDHW'
+        return relay.nn.conv3d(data, weight, **new_attrs)
+    elif desired_layout == "NDHWC":
+        new_attrs['kernel_layout'] = 'DHWIO'
+        return relay.nn.conv3d(data, weight, **new_attrs)
+    else:
+        assert "Layout %s is not yet supported" % desired_layout
+    return None
+
 # conv3d_winograd related operators
 reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform",
                       strategy.conv3d_winograd_without_weight_transfrom_strategy)
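As a hedged illustration of what the new converter does (shapes and names here are made up for the example), a conv3d written in NCDHW/OIDHW is rewritten to NDHWC/DHWIO when ConvertLayout requests NDHWC:

import tvm
from tvm import relay

x = relay.var("x", shape=(16, 16, 8, 56, 56))  # NCDHW activations
w = relay.var("w", shape=(16, 16, 3, 3, 3))    # OIDHW weights
y = relay.nn.conv3d(x, w, kernel_size=(3, 3, 3), channels=16,
                    padding=(1, 1, 1), data_layout="NCDHW",
                    kernel_layout="OIDHW")
mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
seq = tvm.transform.Sequential([relay.transform.InferType(),
                                relay.transform.ConvertLayout("NDHWC")])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)  # conv3d now carries data_layout="NDHWC", kernel_layout="DHWIO"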
@@ -138,15 +138,16 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                 name="conv2d_nhwc.cuda")
             N, _, _, _ = get_const_tuple(data.shape)
             _, _, CI, CO = get_const_tuple(kernel.shape)
-            if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
-                if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
-                        (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
-                        (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
-                    strategy.add_implementation(
-                        wrap_compute_conv2d(topi.cuda.conv2d_nhwc_tensorcore),
-                        wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
-                        name="conv2d_nhwc_tensorcore.cuda",
-                        plevel=20)
+            if target.target_name == "cuda":
+                if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+                    if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
+                            (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
+                            (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(topi.cuda.conv2d_nhwc_tensorcore),
+                            wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
+                            name="conv2d_nhwc_tensorcore.cuda",
+                            plevel=20)
         elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
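The divisibility guard mirrors the three WMMA fragment shapes that fp16 tensor cores support (m16n16k16, m32n8k16, m8n32k16): batch maps to M, output channels to N, and input channels to the K reduction of the implicit GEMM. Restated as a hypothetical helper (not in the commit):

def tensorcore_shape_ok(N, CI, CO):
    # One of the three WMMA tile shapes must evenly cover the GEMM the
    # convolution is lowered to: (N x CO) accumulated over CI.
    return ((N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or
            (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or
            (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0))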
@@ -170,7 +171,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
                 wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
-                name="dpethwise_nchw.cuda")
+                name="depthwise_conv2d_nchw.cuda")
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
             strategy.add_implementation(
@@ -249,7 +250,7 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
 def conv3d_strategy_cuda(attrs, inputs, out_type, target):
     """conv3d cuda strategy"""
     strategy = _op.OpStrategy()
-    _, kernel = inputs
+    data, kernel = inputs
     layout = attrs.data_layout
     _, stride_h, stride_w = attrs.get_int_tuple("strides")
     _, dilation_h, dilation_w = attrs.get_int_tuple("dilation")
@@ -268,11 +269,25 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw_winograd),
                 name="conv3d_ncdhw_winograd.cuda",
                 plevel=5)
-    else: # layout == "NDHWC":
-        strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
-                                    wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
-                                    name="conv3d_ndhwc.cuda",
-                                    plevel=10)
+    else:  # layout == "NDHWC":
+        strategy.add_implementation(
+            wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
+            wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
+            name="conv3d_ndhwc.cuda",
+            plevel=10)
+        N, _, _, _, _ = get_const_tuple(data.shape)
+        _, _, _, CI, CO = get_const_tuple(kernel.shape)
+        if target.target_name == "cuda":
+            if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+                if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
+                        (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
+                        (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
+                    strategy.add_implementation(
+                        wrap_compute_conv3d(topi.cuda.conv3d_ndhwc_tensorcore),
+                        wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc_tensorcore),
+                        name="conv3d_ndhwc_tensorcore.cuda",
+                        plevel=20)
     if target.target_name == "cuda" and "cudnn" in target.libs:
         strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True),
                                     wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn),
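Note that both the direct NDHWC implementation (plevel=10) and the tensorcore one (plevel=20) can end up registered for the same workload; the strategy machinery then prefers the higher plevel. A conceptual sketch of that selection rule (not TVM source):

def pick_implementation(impls):
    # OpStrategy keeps every applicable implementation; the one with the
    # highest priority level wins, so the tensorcore schedule shadows the
    # direct one whenever its hardware and shape checks pass.
    return max(impls, key=lambda impl: impl.plevel)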
@@ -46,4 +46,5 @@ from .nms import get_valid_counts, non_max_suppression
 from .rcnn import *
 from .sort import *
 from .conv2d_nhwc_tensorcore import *
+from .conv3d_ndhwc_tensorcore import *
 from .dense_tensorcore import *
@@ -70,7 +70,7 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype):
     # convert data type of input feature maps and weights
     TransPaddedInput = te.compute(
         PaddedInput.shape,
-        lambda h, w, i, o: PaddedInput[h, w, i, o].astype('float16'))
+        lambda n, h, w, c: PaddedInput[n, h, w, c].astype('float16'))
     TransFilter = te.compute(
         Filter.shape, lambda h, w, i, o: Filter[h, w, i, o].astype('float16'))
     Output = te.compute(
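The fix above only renames the lambda variables from (h, w, i, o) to (n, h, w, c) so they match the NHWC axis order of the padded input; the generated computation is unchanged. For reference, a minimal standalone version of the same cast pattern (the placeholder shape is illustrative):

import tvm
from tvm import te

# Elementwise cast of an NHWC tensor to float16: tensorcore MMA consumes
# fp16 operands, while accumulation happens at the output dtype.
A = te.placeholder((2, 8, 8, 16), name="A", dtype="float32")
A_fp16 = te.compute(A.shape,
                    lambda n, h, w, c: A[n, h, w, c].astype("float16"),
                    name="A_fp16")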
@@ -493,13 +493,17 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
     BB = s.cache_read(B0, 'shared', [OL])

     b = s[bgemm].fuse(b1, b2)
-    y = s[bgemm].fuse(z, y)
+
+    # Allow two different tiling strategies as both seem
+    # to work best in different cases.
+    cfg.define_knob("unroll_axis", [0, 1])

     # tile and bind spatial axes
     bgemm_scope, b = s[bgemm].split(b, nparts=1)
     bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
-    by, vy, ty, yi = cfg["tile_y"].apply(s, C, y)
-    bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, C, z)
+    if cfg['unroll_axis'].val:
+        bx, vx, tx, xi = cfg["tile_x"].apply(s, C, y)
+    else:
+        bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
     s[C].bind(bz, te.thread_axis("blockIdx.z"))
     s[C].bind(by, te.thread_axis("blockIdx.y"))
     s[C].bind(bx, te.thread_axis("blockIdx.x"))
@@ -510,6 +514,10 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
     s[C].bind(ty, te.thread_axis("threadIdx.y"))
     s[C].bind(tx, te.thread_axis("threadIdx.x"))
     s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)
+    if cfg['unroll_axis'].val:
+        s[C].unroll(x)
+    else:
+        s[C].unroll(y)

     # tile reduction axes
     s[OL].compute_at(s[C], tx)
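The unroll_axis knob adds a binary choice to the schedule's search space: either y is tiled and x unrolled, or the reverse, and AutoTVM keeps whichever measures faster per workload. A minimal illustrative template showing the same knob pattern (the task name and workload are hypothetical; decorator syntax per the TVM of this era):

import tvm
from tvm import te, autotvm

@autotvm.template("example_unroll_knob")  # hypothetical task name
def example_unroll_knob(n):
    A = te.placeholder((n, n), name="A")
    B = te.compute((n, n), lambda i, j: A[i, j] * 2.0, name="B")
    s = te.create_schedule(B.op)
    cfg = autotvm.get_config()
    cfg.define_knob("unroll_axis", [0, 1])  # same pattern as in the diff
    i, j = s[B].op.axis
    # The tuned value decides which axis gets unrolled.
    s[B].unroll(j if cfg["unroll_axis"].val else i)
    return s, [A, B]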
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Example code to do convolution."""

import numpy as np
import tvm
import topi
import topi.testing
from tvm import te
from tvm.contrib.pickle_memoize import memoize
from tvm.contrib import nvcc
from topi.nn.util import get_pad_tuple3d
from topi.util import get_const_tuple


_conv3d_ndhwc_tensorcore_implement = {
    "cuda": (topi.cuda.conv3d_ndhwc_tensorcore, topi.cuda.schedule_conv3d_ndhwc_tensorcore)
}
def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
                        padding, dilation=1, add_bias=False, add_relu=False,
                        devices='cuda'):
    """Test the conv3d with tensorcore for ndhwc layout"""
    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(
        padding, (kernel, kernel, kernel))
    padding_sum = pad_front + pad_top + pad_left + pad_back + pad_bottom + pad_right
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))

    in_depth = in_height = in_width = in_size

    A = te.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A')
    W = te.placeholder((kernel, kernel, kernel, in_channel, num_filter), name='W')
    bias = te.placeholder((1, 1, 1, 1, num_filter), name='bias')

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    bias_shape = get_const_tuple(bias.shape)
    dtype = A.dtype

    @memoize("topi.tests.test_topi_conv3d_ndhwc.verify_conv3d_ndhwc")
    def get_ref_data():
        a_np = np.random.uniform(size=a_shape).astype(dtype)
        w_np = np.random.uniform(size=w_shape).astype(dtype)
        b_np = np.random.uniform(size=bias_shape).astype(dtype)
        dw_np = topi.testing.dilate_python(w_np, (1, 1, 1, dilation, dilation))
        c_np = topi.testing.conv3d_ndhwc_python(a_np, dw_np, stride, padding)
        if add_bias:
            b_np = np.random.uniform(size=bias_shape).astype(dtype)
            c_np += b_np
        if add_relu:
            c_np = np.maximum(c_np, 0)
        return a_np, w_np, b_np, c_np

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        if not nvcc.have_tensorcore(ctx.compute_version):
            print("skip because gpu does not support Tensor Cores")
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            fcompute, fschedule = topi.testing.dispatch(
                device, _conv3d_ndhwc_tensorcore_implement)
            C = fcompute(A, W, stride, padding, dilation, 'float32')
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = fschedule([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
        if add_bias:
            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
            func(a, w, b, c)
        else:
            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
            func(a, w, c)

        rtol = 1e-3
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)

    check_device(devices)
def test_conv3d_ndhwc_tensorcore():
    """Test the conv3d with tensorcore for ndhwc layout"""
    verify_conv3d_ndhwc(16, 16, 14, 16, 3, 1, 1)
    verify_conv3d_ndhwc(16, 64, 7, 64, 7, 1, 3)
    verify_conv3d_ndhwc(16, 32, 7, 32, 7, 1, 3)
    verify_conv3d_ndhwc(32, 16, 14, 16, 3, 1, 1, add_bias=True)
    verify_conv3d_ndhwc(32, 16, 14, 16, 3, 1, 1, add_relu=True)
    verify_conv3d_ndhwc(32, 16, 14, 16, 3, 1, 1, add_relu=True, add_bias=True)
    verify_conv3d_ndhwc(16, 16, 17, 16, 7, 1, (3, 3, 3, 2, 2, 2))
    verify_conv3d_ndhwc(16, 16, 17, 16, 7, 1, "SAME")
    verify_conv3d_ndhwc(8, 16, 35, 32, 5, 1, "VALID")
    verify_conv3d_ndhwc(16, 32, 16, 32, 3, 1, (1, 1, 1, 1, 1, 1))
    verify_conv3d_ndhwc(16, 16, 12, 16, 3, 1, (1, 1, 1, 1, 1, 1))


if __name__ == "__main__":
    test_conv3d_ndhwc_tensorcore()