Commit 85e4058c by masahi, committed by Tianqi Chen

[TOPI] CUDNN integration (#730)

* add target.libs to target str representation

* integrate cudnn into topi cuda

* append target.libs to target.options
parent cdb2f873
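
Taken together, the three changes let a caller opt into cuDNN purely through the target string. A minimal end-to-end sketch of the intended usage, assuming a CUDA GPU and a cuDNN-enabled TVM build (shapes and tensor names are illustrative, not from this commit):

import tvm
import topi

# Illustrative NCHW placeholders.
data = tvm.placeholder((1, 3, 224, 224), name='data')
kernel = tvm.placeholder((32, 3, 3, 3), name='kernel')

# "-libs=cudnn" is parsed into target.libs, so conv2d_cuda dispatches to
# cudnn.conv2d_forward and the schedule falls back to schedule_extern.
with tvm.target.create("cuda -libs=cudnn"):
    out = topi.nn.conv2d(data, kernel, stride=1, padding=1)
    s = topi.cuda.schedule_conv2d_nchw([out])
    f = tvm.build(s, [data, kernel, out], "cuda")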
......@@ -88,17 +88,15 @@ class Target(object):
                 target_name,
                 options=None):
        self.target_name = target_name
-        self.options = []
+        self.options = _merge_opts([], options)
        self.device_name = ""
+        self.libs = []
        # Parse device option
-        for item in _merge_opts([], options):
+        for item in self.options:
+            if item.startswith("-libs="):
+                self.libs.append(item.split("=")[1])
+                continue
-            if item.startswith("-device="):
+            elif item.startswith("-device="):
                self.device_name = item.split("=")[1]
-            self.options.append(item)
        # Target query searchs device name first
        if self.device_name:
            self.keys = (self.device_name,)
......
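
With this rework, every option stays in self.options (so -libs=cudnn survives in the string representation, per the first commit bullet) while library names are also surfaced on the new libs field. A small sketch of the expected behaviour:

import tvm

tgt = tvm.target.create("cuda -libs=cudnn")
assert "cudnn" in tgt.libs            # parsed into the new libs field
assert "-libs=cudnn" in tgt.options   # and still retained as an option
print(tgt)                            # the str form now carries -libs=cudnn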
......@@ -82,7 +82,7 @@ GetLLVMTargetMachine(const std::string& target_str,
      } else {
        LOG(FATAL) << "invalid -mfloat-abi option " << value;
      }
-    } else if (key == "-device") {
+    } else if (key == "-device" || key == "-libs") {
      // pass
    } else {
      LOG(FATAL) << "unknown option " << key;
......
......@@ -41,7 +41,8 @@ def test_conv2d():
                                 tensor_format=0,
                                 algo=1)
    yshape = [x.value for x in Y.shape]
-    s = tvm.create_schedule(Y.op)
+    with tvm.target.create("cuda -libs=cudnn"):
+        s = tvm.create_schedule(Y.op)
    def verify():
        ctx = tvm.gpu(0)
......
......@@ -2,6 +2,7 @@
"""CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs
+from .conv2d import conv2d_cuda
from .conv2d_nchw import schedule_conv2d_nchw
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
......
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long
"""Compute definition for conv2d with cuda backend"""
import tvm
from tvm.contrib import cudnn
import topi
from ..nn.conv2d import conv2d

@conv2d.register("cuda")
def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
    """Conv2D operator for cuda backend.

    Parameters
    ----------
    input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    filter : tvm.Tensor
        4-D with shape [num_filter, in_channel, filter_height, filter_width]

    stride : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    layout : str
        layout of data

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    assert isinstance(stride, int) or len(stride) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    if isinstance(padding, int):
        pad_h = pad_w = padding
    else:
        pad_h, pad_w = padding
    target = tvm.target.current_target()
    if "cudnn" in target.libs:
        assert layout != 'HWCN', "HWCN layout not supported with CUDNN."
        tensor_format = 0 # CUDNN_TENSOR_NCHW
        if layout == 'NHWC':
            tensor_format = 1 # CUDNN_TENSOR_NHWC
        return cudnn.conv2d_forward(data,
                                    kernel,
                                    stride_h,
                                    stride_w,
                                    pad_h,
                                    pad_w,
                                    1, # dilation_h
                                    1, # dilation_w
                                    conv_mode=1,
                                    tensor_format=tensor_format,
                                    algo=0)
    elif layout == 'NCHW':
        return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
    elif layout == 'HWCN':
        return topi.nn.conv2d_hwcn(data, kernel, stride, padding, out_dtype)
    else:
        raise ValueError("not support this layout {} yet".format(layout))
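
The dispatch above is target-driven: the same topi.nn.conv2d call lowers either to a cuDNN extern op or to the hand-written TOPI kernels, depending on -libs. A hedged sketch of the contrast (placeholders are illustrative; the cuDNN branch needs a cuDNN-enabled build):

import tvm
import topi

data = tvm.placeholder((1, 16, 32, 32), name='data')
kernel = tvm.placeholder((32, 16, 3, 3), name='kernel')

with tvm.target.create("cuda"):
    # No -libs entry: falls through to topi.nn.conv2d_nchw.
    direct = topi.nn.conv2d(data, kernel, 1, 1)

with tvm.target.create("cuda -libs=cudnn"):
    # "cudnn" in target.libs: returns the cudnn.conv2d_forward extern op.
    externed = topi.nn.conv2d(data, kernel, 1, 1)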
#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long
"""Schedule for conv2d_nchw with auto fusion"""
import tvm
+import topi
from .. import util
from .. import tag
from .. import generic
......@@ -516,6 +517,10 @@ def schedule_conv2d_nchw(outs):
    s: Schedule
        The computation schedule for conv2d_nchw.
    """
+    target = tvm.target.current_target()
+    if target.target_name == "cuda" and "cudnn" in target.libs:
+        return topi.generic.schedule_extern(outs)
+
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    batch_size = util.get_const_int(outs[0].op.output(0).shape[0])
    if batch_size > 1:
......
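
The schedule side mirrors the compute dispatch: a cuDNN conv2d is an extern op whose body TVM cannot schedule, so schedule_conv2d_nchw hands the whole graph to topi.generic.schedule_extern. A short sketch, again with illustrative placeholders:

import tvm
import topi

data = tvm.placeholder((1, 16, 32, 32), name='data')
kernel = tvm.placeholder((32, 16, 3, 3), name='kernel')

with tvm.target.create("cuda -libs=cudnn"):
    out = topi.nn.conv2d(data, kernel, 1, 1)
    # target.libs contains "cudnn", so this returns
    # topi.generic.schedule_extern([out]) instead of the tiled schedule.
    s = topi.cuda.schedule_conv2d_nchw([out])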
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long
"""Schedule for rocm conv2d_nchw with auto fusion"""
"""Compute and schedule for rocm conv2d_nchw with auto fusion"""
import tvm
from tvm.contrib import miopen
import topi
......