Unverified Commit 2dbe6261 by Thomas Viehmann Committed by GitHub

fix miopen pad (#5433)

parent 83930a3b
......@@ -36,6 +36,7 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
layout = attrs.data_layout
stride_h, stride_w = attrs.get_int_tuple("strides")
kernel_layout = attrs.kernel_layout
padding = attrs.get_int_tuple("padding")
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
......@@ -77,7 +78,8 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
else:
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
# add miopen implementation
if "miopen" in target.libs and layout == "NCHW":
if "miopen" in target.libs and layout == "NCHW" and padding[0] == padding[2] and \
padding[1] == padding[3]:
strategy.add_implementation(
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
......
......@@ -66,7 +66,7 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation,
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
pad_h, pad_w = pt + pb, pl + pr
dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
assert (pt == pb) and (pl == pr)
OH = (H + 2 * pad_h - KH) // stride_h + 1
OW = (W + 2 * pad_w - KW) // stride_w + 1
cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\
......@@ -76,8 +76,8 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation,
kernel,
stride_h,
stride_w,
pad_h,
pad_w,
pt,
pl,
dilation_h,
dilation_w,
conv_mode=0,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment