Commit 85e4058c by masahi, committed by Tianqi Chen

[TOPI] CUDNN integration (#730)

* add target.libs to target str representation

* integrate cudnn into topi cuda

* append target.libs to target.options
parent cdb2f873
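
In short: naming a library in the target string now routes TOPI's conv2d to cuDNN. A minimal usage sketch (TVM 0.x-era API as used in this commit; the shapes are illustrative and the cuDNN path requires a TVM build with cuDNN enabled):

    import tvm
    import topi

    data = tvm.placeholder((1, 64, 56, 56), name='data')
    kernel = tvm.placeholder((64, 64, 3, 3), name='kernel')
    # -libs=cudnn puts "cudnn" into target.libs, which the cuda-registered
    # conv2d checks before falling back to the direct CUDA compute.
    with tvm.target.create("cuda -libs=cudnn"):
        out = topi.nn.conv2d(data, kernel, stride=1, padding=1, layout='NCHW')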
@@ -88,17 +88,15 @@ class Target(object):
                  target_name,
                  options=None):
         self.target_name = target_name
-        self.options = []
+        self.options = _merge_opts([], options)
         self.device_name = ""
         self.libs = []
         # Parse device option
-        for item in _merge_opts([], options):
+        for item in self.options:
             if item.startswith("-libs="):
                 self.libs.append(item.split("=")[1])
-                continue
             elif item.startswith("-device="):
                 self.device_name = item.split("=")[1]
-            self.options.append(item)
         # Target query searchs device name first
         if self.device_name:
             self.keys = (self.device_name,)
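
A behavior sketch for the reworked parser above (the target string is illustrative): options are merged once up front, and -libs=/-device= are read out of self.options rather than re-appended item by item, so per the first commit bullet they also survive in the target's string representation.

    import tvm

    tgt = tvm.target.create("cuda -libs=cudnn -device=gpu")
    assert "cudnn" in tgt.libs           # parsed from -libs=
    assert tgt.device_name == "gpu"      # parsed from -device=
    assert "-libs=cudnn" in tgt.options  # retained, so str(tgt) carries it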
......
@@ -82,7 +82,7 @@ GetLLVMTargetMachine(const std::string& target_str,
       } else {
         LOG(FATAL) << "invalid -mfloat-abi option " << value;
       }
-    } else if (key == "-device") {
+    } else if (key == "-device" || key == "-libs") {
       // pass
     } else {
       LOG(FATAL) << "unknown option " << key;
......
@@ -41,6 +41,7 @@ def test_conv2d():
                          tensor_format=0,
                          algo=1)
     yshape = [x.value for x in Y.shape]
+    with tvm.target.create("cuda -libs=cudnn"):
         s = tvm.create_schedule(Y.op)
     def verify():
......
@@ -2,6 +2,7 @@
 """CUDA specific declaration and schedules."""
 from __future__ import absolute_import as _abs
+from .conv2d import conv2d_cuda
 from .conv2d_nchw import schedule_conv2d_nchw
 from .conv2d_hwcn import schedule_conv2d_hwcn
 from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
......
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long
"""Compute definition for conv2d with cuda backend"""
import tvm
from tvm.contrib import cudnn
import topi
from ..nn.conv2d import conv2d

@conv2d.register("cuda")
def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
    """Conv2D operator for cuda backend.

    Parameters
    ----------
    data : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]
    kernel : tvm.Tensor
        4-D with shape [num_filter, in_channel, filter_height, filter_width]
    stride : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]
    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]
    layout : str
        layout of data

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    assert isinstance(stride, int) or len(stride) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    if isinstance(padding, int):
        pad_h = pad_w = padding
    else:
        pad_h, pad_w = padding
    target = tvm.target.current_target()
    if "cudnn" in target.libs:
        assert layout != 'HWCN', "HWCN layout not supported with CUDNN."
        tensor_format = 0  # CUDNN_TENSOR_NCHW
        if layout == 'NHWC':
            tensor_format = 1  # CUDNN_TENSOR_NHWC
        return cudnn.conv2d_forward(data,
                                    kernel,
                                    stride_h,
                                    stride_w,
                                    pad_h,
                                    pad_w,
                                    1,  # dilation_h
                                    1,  # dilation_w
                                    conv_mode=1,
                                    tensor_format=tensor_format,
                                    algo=0)
    elif layout == 'NCHW':
        return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
    elif layout == 'HWCN':
        return topi.nn.conv2d_hwcn(data, kernel, stride, padding, out_dtype)
    else:
        raise ValueError("layout {} is not supported yet".format(layout))
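
A dispatch sketch for the compute above (hypothetical shapes; the second branch requires a TVM build with cuDNN): the same topi.nn.conv2d call resolves differently depending on the active target.

    import tvm
    import topi

    data = tvm.placeholder((1, 3, 224, 224), name='data')
    kernel = tvm.placeholder((32, 3, 3, 3), name='kernel')

    with tvm.target.create("cuda"):
        # target.libs is empty, so conv2d_cuda falls through to conv2d_nchw.
        out_direct = topi.nn.conv2d(data, kernel, stride=2, padding=1)

    with tvm.target.create("cuda -libs=cudnn"):
        # "cudnn" in target.libs, so this returns a cudnn.conv2d_forward extern op.
        out_cudnn = topi.nn.conv2d(data, kernel, stride=2, padding=1)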
......
 # pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long
 """Schedule for conv2d_nchw with auto fusion"""
 import tvm
+import topi
 from .. import util
 from .. import tag
 from .. import generic
@@ -516,6 +517,10 @@ def schedule_conv2d_nchw(outs):
     s: Schedule
         The computation schedule for conv2d_nchw.
     """
+    target = tvm.target.current_target()
+    if target.target_name == "cuda" and "cudnn" in target.libs:
+        return topi.generic.schedule_extern(outs)
+
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     batch_size = util.get_const_int(outs[0].op.output(0).shape[0])
     if batch_size > 1:
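
On the schedule side, a matching sketch (same hypothetical tensors; requires cuDNN): when conv2d was taken over by cuDNN the result is an extern op, so the early exit above hands it to the generic extern scheduler instead of the CUDA auto-fusion path.

    import tvm
    import topi

    data = tvm.placeholder((1, 3, 224, 224), name='data')
    kernel = tvm.placeholder((32, 3, 3, 3), name='kernel')
    with tvm.target.create("cuda -libs=cudnn"):
        out = topi.nn.conv2d(data, kernel, stride=2, padding=1)
        # Hits the new early exit and returns topi.generic.schedule_extern(out).
        s = topi.cuda.schedule_conv2d_nchw(out)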
......
 # pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long
-"""Schedule for rocm conv2d_nchw with auto fusion"""
+"""Compute and schedule for rocm conv2d_nchw with auto fusion"""
 import tvm
 from tvm.contrib import miopen
 import topi
......