Unverified Commit 84121966 by Thomas Viehmann Committed by GitHub

rocm: fix miopen convolutions (#5179)

* fix miopen convolutions

* fix overly long lines
parent b776ff39
......@@ -56,8 +56,7 @@ def test_conv2d():
yshape = [x.value for x in Y.shape]
import topi
with tvm.target.create("rocm -libs=miopen"):
s = topi.generic.schedule_extern(Y)
s = te.create_schedule(Y.op)
def verify():
ctx = tvm.rocm(0)
......@@ -67,10 +66,10 @@ def test_conv2d():
y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
f(x, w, y)
Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w))
with tvm.target.rocm():
s_ref = topi.generic.schedule_conv2d_nchw([Y_ref])
f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm")
Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w),
(dilation_h, dilation_w))
s_ref = te.create_schedule(Y_ref.op)
f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm", target_host="llvm")
y_ref = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
f_ref(x, w, y_ref)
print("Max abs diff:", np.max(np.abs(y.asnumpy() - y_ref.asnumpy())))
......
......@@ -24,7 +24,8 @@ from ..util import get_const_tuple
from ..nn.util import get_pad_tuple
@autotvm.register_topi_compute("conv2d_nchw_miopen.rocm")
def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation,
layout='NCHW', out_dtype='float32'):
"""Conv2D operator for rocm backend.
Parameters
......@@ -58,6 +59,8 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, out_dtype=
CO, CI, KH, KW = get_const_tuple(kernel.shape)
N, _, H, W = get_const_tuple(data.shape)
assert layout == 'NCHW'
# handle dilation
stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment