Commit c1008ec4 by Yuwei Hu Committed by Tianqi Chen

[TOPI] fix weight layout in conv2d_transpose (#616)

parent 8214d6ca
......@@ -18,7 +18,7 @@ def conv2d_transpose_nchw(Input, Filter, strides, padding):
4-D with shape [batch, in_channel, in_height, in_width]
Filter : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width]
4-D with shape [in_channel, num_filter, filter_height, filter_width]
strides : tuple of two ints
The spatial stride along height and width
......@@ -32,7 +32,7 @@ def conv2d_transpose_nchw(Input, Filter, strides, padding):
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_c, in_h, in_w = Input.shape
out_c, _, filter_h, filter_w = Filter.shape
_, out_c, filter_h, filter_w = Filter.shape
stride_h, stride_w = strides
# dilate stage
DilatedInput = dilate(Input, [1, 1, stride_h, stride_w], name='DilatedInput')
......@@ -57,7 +57,7 @@ def conv2d_transpose_nchw(Input, Filter, strides, padding):
Output = tvm.compute(
(batch, out_c, out_h, out_w),
lambda b, c, h, w: tvm.sum(
PaddedInput[b, dc, h+dh, w+dw] * Filter[c, dc, filter_h-1-dh, filter_w-1-dw],
PaddedInput[b, dc, h+dh, w+dw] * Filter[dc, c, filter_h-1-dh, filter_w-1-dw],
axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
return Output
# pylint: disable=unused-variable
"""Transposed convolution in python"""
import numpy as np
import scipy
import topi
from topi.nn.util import get_pad_tuple
......@@ -14,7 +15,7 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
4-D with shape [batch, in_channel, in_height, in_width]
w_np : numpy.ndarray
4-D with shape [num_filter, in_channel, filter_height, filter_width]
4-D with shape [in_channel, num_filter, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
......@@ -28,7 +29,7 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_c, in_h, in_w = a_np.shape
out_c, _, filter_h, filter_w = w_np.shape
_, out_c, filter_h, filter_w = w_np.shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
......@@ -46,6 +47,13 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
padded_a_np[:, :, bpad_top:dilated_a_np.shape[2]+bpad_top, \
bpad_left:dilated_a_np.shape[3]+bpad_left] = dilated_a_np
# convolution stage
rotated_w_np = np.rot90(w_np, k=2, axes=(2, 3))
b_np = topi.testing.conv2d_nchw_python(padded_a_np, rotated_w_np, stride=1, padding='VALID')
out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
b_np = np.zeros((batch, out_c, out_h, out_w))
for n in range(batch):
for f in range(out_c):
for c in range(in_c):
out = scipy.signal.convolve2d(
padded_a_np[n, c], w_np[c, f], mode='valid')
b_np[n, f] += out
return b_np
......@@ -10,7 +10,7 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W')
B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], padding)
C = topi.nn.relu(B)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment