Unverified Commit 84121966 by Thomas Viehmann Committed by GitHub

rocm: fix miopen convolutions (#5179)

* fix miopen convolutions

* fix overly long lines
parent b776ff39
...@@ -56,8 +56,7 @@ def test_conv2d(): ...@@ -56,8 +56,7 @@ def test_conv2d():
yshape = [x.value for x in Y.shape] yshape = [x.value for x in Y.shape]
import topi import topi
with tvm.target.create("rocm -libs=miopen"): s = te.create_schedule(Y.op)
s = topi.generic.schedule_extern(Y)
def verify(): def verify():
ctx = tvm.rocm(0) ctx = tvm.rocm(0)
...@@ -67,10 +66,10 @@ def test_conv2d(): ...@@ -67,10 +66,10 @@ def test_conv2d():
y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx) y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
f(x, w, y) f(x, w, y)
Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w)) Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w),
with tvm.target.rocm(): (dilation_h, dilation_w))
s_ref = topi.generic.schedule_conv2d_nchw([Y_ref]) s_ref = te.create_schedule(Y_ref.op)
f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm") f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm", target_host="llvm")
y_ref = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx) y_ref = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
f_ref(x, w, y_ref) f_ref(x, w, y_ref)
print("Max abs diff:", np.max(np.abs(y.asnumpy() - y_ref.asnumpy()))) print("Max abs diff:", np.max(np.abs(y.asnumpy() - y_ref.asnumpy())))
......
...@@ -24,7 +24,8 @@ from ..util import get_const_tuple ...@@ -24,7 +24,8 @@ from ..util import get_const_tuple
from ..nn.util import get_pad_tuple from ..nn.util import get_pad_tuple
@autotvm.register_topi_compute("conv2d_nchw_miopen.rocm") @autotvm.register_topi_compute("conv2d_nchw_miopen.rocm")
def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'): def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation,
layout='NCHW', out_dtype='float32'):
"""Conv2D operator for rocm backend. """Conv2D operator for rocm backend.
Parameters Parameters
...@@ -58,6 +59,8 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, out_dtype= ...@@ -58,6 +59,8 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, out_dtype=
CO, CI, KH, KW = get_const_tuple(kernel.shape) CO, CI, KH, KW = get_const_tuple(kernel.shape)
N, _, H, W = get_const_tuple(data.shape) N, _, H, W = get_const_tuple(data.shape)
assert layout == 'NCHW'
# handle dilation # handle dilation
stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment