Commit cdb2f873 by masahi Committed by Tianqi Chen

[TOPI] add extern schedule for cudnn and miopen (#724)

* add extern schedule for miopen

* fix comment

* optionally dispatch to miopen from topi

* fix lint

* check if current target is None

* use generic dispatch for rocm conv2d

* fix lint

* fix workspace bug

* remove blank line

* remove blank line

* remove blank line
parent f9acdd77
......@@ -105,6 +105,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
const int request_algo_count = 4;
const bool exhaustive_search = false;
void* workspace = entry_ptr->conv_entry.workspace;
if (workspace_size == 0) workspace = nullptr;
int returned_algo_count = 0;
miopenConvAlgoPerf_t perfs[4];
......@@ -119,7 +121,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
request_algo_count,
&returned_algo_count,
perfs,
entry_ptr->conv_entry.workspace,
workspace,
workspace_size,
exhaustive_search));
......
......@@ -4,8 +4,8 @@ import numpy as np
def test_conv2d():
in_channel = 64
out_channel = 128
in_channel = 3
out_channel = 64
filter_h = 3
filter_w = 3
pad_h = 1
......@@ -15,7 +15,7 @@ def test_conv2d():
dilation_h = 1
dilation_w = 1
xshape = [1, in_channel, 64, 64]
xshape = [1, in_channel, 128, 128]
if not tvm.module.enabled("rocm"):
print("skip because rocm is not enabled...")
return
......@@ -37,7 +37,9 @@ def test_conv2d():
conv_mode=0)
yshape = [x.value for x in Y.shape]
s = tvm.create_schedule(Y.op)
import topi
with tvm.target.create("rocm -libs=miopen"):
s = topi.generic.schedule_extern(Y)
def verify():
ctx = tvm.rocm(0)
......@@ -47,7 +49,6 @@ def test_conv2d():
y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
f(x, w, y)
import topi
Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w))
with tvm.target.rocm():
s_ref = topi.generic.schedule_conv2d_nchw([Y_ref])
......
......@@ -19,3 +19,4 @@ from . import cuda
from . import rasp
from . import testing
from . import util
from . import rocm
......@@ -13,3 +13,4 @@ from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_global_pool
from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
from .extern import schedule_extern
# pylint: disable=invalid-name, unused-variable,
"""Schedule for cudnn and miopen extern op"""
import tvm
from .. import generic
def _schedule_output(op, sch):
x = op.output(0)
fused = sch[x].fuse(*sch[x].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
bx, tx = sch[x].split(fused, factor=num_thread)
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
return sch
@generic.schedule_extern.register(["cuda", "gpu"])
def schedule_extern(outs):
"""Schedule for an extern op followed by injective operations.
For example, cudnn kernel + bias add + relu.
Parameters
----------
outs: Array of Tensor
The computation graph description of extern plus injective ops in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
for out in outs:
if isinstance(out.op, tvm.tensor.ExternOp):
continue
_schedule_output(out.op, s)
return s
......@@ -17,3 +17,4 @@ from __future__ import absolute_import as _abs
from .nn import *
from .injective import *
from .extern import *
# pylint: disable=invalid-name
"""generic declaration and schedules."""
from __future__ import absolute_import as _abs
import tvm
@tvm.target.generic_func
def schedule_extern(outs):
"""Schedule for an extern op followed by injective operations.
Parameters
----------
outs: Array of Tensor
The computation graph description of extern plus injective ops in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target(allow_none=False)
if target.target_name != "llvm":
raise RuntimeError("schedule_injective not registered for '%s'" % target)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
return tvm.create_schedule([x.op for x in outs])
# pylint: disable=redefined-builtin, wildcard-import
"""rocm specific declaration and schedules."""
from __future__ import absolute_import as _abs
from .conv2d import *
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long
"""Schedule for rocm conv2d_nchw with auto fusion"""
import tvm
from tvm.contrib import miopen
import topi
from .. import generic
from ..nn.conv2d import conv2d
@conv2d.register("rocm")
def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
"""Conv2D operator for rocm backend.
Parameters
----------
input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
filter : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
layout : str
layout of data
Returns
-------
output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
assert layout == 'NCHW', "Only NCHW layout is supported."
assert isinstance(stride, int) or len(stride) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
if isinstance(padding, int):
pad_h = pad_w = padding
else:
pad_h, pad_w = padding
target = tvm.target.current_target()
if "miopen" in target.libs:
return miopen.conv2d_forward(data,
kernel,
stride_h,
stride_w,
pad_h,
pad_w,
1, # dilation_h
1, # dilation_w
conv_mode=0)
return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
@generic.schedule_conv2d_nchw.register(["rocm"])
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d_nchw with rocm backend.
Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_nchw
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv2d_nchw.
"""
target = tvm.target.current_target()
if target and "miopen" in target.libs:
return topi.generic.schedule_extern(outs)
return topi.cuda.schedule_conv2d_nchw(outs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment