Commit 7211c277 by Cody Hao Yu Committed by Leyuan Wang

[TOPI] Fix bug in Winograd on CUDA (#4260)

* fix winograd

* move get padding after kernel transform
parent ddaa9530
......@@ -17,6 +17,7 @@
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Winograd template for cuda backend"""
import logging
import tvm
from tvm import autotvm
......@@ -27,6 +28,8 @@ from ..generic import schedule_conv2d_winograd_without_weight_transform
from ..nn.winograd_util import winograd_transform_matrices
logger = logging.getLogger('conv2d_winograd')
def _infer_tile_size(data, kernel):
N, CI, H, W = get_const_tuple(data.shape)
......@@ -42,26 +45,25 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
N, CI, H, W = get_const_tuple(data.shape)
if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
if not pre_computed: # kernel tensor is raw tensor, do strict check
if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
if dilation_h != 1 or dilation_w != 1:
kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
kernel = dilation(kernel, (1, 1, dilation_h, dilation_w))
CO, CI, KH, KW = get_const_tuple(kernel.shape)
HPAD, WPAD, _, _ = nn.get_pad_tuple(padding, kernel)
HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
assert HSTR == 1 and WSTR == 1 and KH == KW
else: # kernel tensor is pre-transfomred. this op is created by
# alter op layout, do not check
else:
# kernel tensor is pre-transfomred. this op is created by alter op layout.
# dilation is not supported
HSTR = WSTR = 1
HPAD = WPAD = 1
KH = KW = 3
_, _, CI, CO = get_const_tuple(kernel.shape)
KH = KW = 3
assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
HPAD, WPAD, _, _ = nn.get_pad_tuple(padding, kernel)
data_pad = nn.pad(data, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad")
r = KW
......@@ -384,7 +386,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
return F.nn.conv2d(*copy_inputs, **new_attrs)
if attrs.get_int_tuple("dilation") != (1, 1):
warnings.warn("Does not support weight pre-transform for dilated convolution.")
logger.warning("Does not support weight pre-transform for dilated convolution.")
return None
# pre-compute weight transformation in winograd
......
......@@ -40,4 +40,5 @@ class Int8Fallback(autotvm.FallbackContext):
cfg = FallbackConfigEntity()
cfg.template_key = 'int8'
self.memory[key] = cfg
cfg.is_fallback = False
return cfg
......@@ -99,6 +99,7 @@ class WinogradFallback(autotvm.FallbackContext):
cfg = FallbackConfigEntity()
cfg.template_key = 'winograd'
self.memory[key] = cfg
cfg.is_fallback = False
return cfg
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment