Unverified Commit fcf8420a by Thomas Viehmann Committed by GitHub

fix ROCm strategy for winograd conv selection (#5001)

parent de346493
...@@ -48,12 +48,13 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): ...@@ -48,12 +48,13 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw), wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
name="conv2d_nchw.cuda") name="conv2d_nchw.cuda")
_, _, kh, kw = get_const_tuple(kernel.shape) _, _, kh, kw = get_const_tuple(kernel.shape)
if kh <= 7 and kw <= 7 and kh == kw and stride_h == 1 and stride_w == 1: if 2 < kh < 8 and 2 < kw < 8 and kh == kw and stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
strategy.add_implementation( strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd), wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd), wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.cuda", name="conv2d_nchw_winograd.cuda",
plevel=15) plevel=5)
elif layout == "HWCN": elif layout == "HWCN":
assert kernel_layout == "HWIO" assert kernel_layout == "HWIO"
strategy.add_implementation( strategy.add_implementation(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment