Unverified Commit 2dbe6261 by Thomas Viehmann Committed by GitHub

fix miopen pad (#5433)

parent 83930a3b
...@@ -36,6 +36,7 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): ...@@ -36,6 +36,7 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
layout = attrs.data_layout layout = attrs.data_layout
stride_h, stride_w = attrs.get_int_tuple("strides") stride_h, stride_w = attrs.get_int_tuple("strides")
kernel_layout = attrs.kernel_layout kernel_layout = attrs.kernel_layout
padding = attrs.get_int_tuple("padding")
if dilation_h < 1 or dilation_w < 1: if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value") raise ValueError("dilation should be positive value")
...@@ -77,7 +78,8 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): ...@@ -77,7 +78,8 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
else: else:
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout)) raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
# add miopen implementation # add miopen implementation
if "miopen" in target.libs and layout == "NCHW": if "miopen" in target.libs and layout == "NCHW" and padding[0] == padding[2] and \
padding[1] == padding[3]:
strategy.add_implementation( strategy.add_implementation(
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True), wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen), wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
......
...@@ -66,7 +66,7 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, ...@@ -66,7 +66,7 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation,
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
pad_h, pad_w = pt + pb, pl + pr pad_h, pad_w = pt + pb, pl + pr
dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
assert (pt == pb) and (pl == pr)
OH = (H + 2 * pad_h - KH) // stride_h + 1 OH = (H + 2 * pad_h - KH) // stride_h + 1
OW = (W + 2 * pad_w - KW) // stride_w + 1 OW = (W + 2 * pad_w - KW) // stride_w + 1
cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\ cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\
...@@ -76,8 +76,8 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, ...@@ -76,8 +76,8 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation,
kernel, kernel,
stride_h, stride_h,
stride_w, stride_w,
pad_h, pt,
pad_w, pl,
dilation_h, dilation_h,
dilation_w, dilation_w,
conv_mode=0, conv_mode=0,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment