Commit 359e335f by Yuwei Hu, committed by Tianqi Chen

[TOPI] support dilated conv2d and depthwise_conv2d (#1129)

* support dilation in conv2d and depthwise_conv2d

* handle dilated conv in extern libs (cudnn, miopen)
parent 5af51280
@@ -4,6 +4,7 @@ import tvm
from tvm.contrib import cudnn
import topi
from ..nn.conv2d import conv2d
from ..util import get_const_int
@conv2d.register("cuda")
def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
@@ -40,6 +41,23 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
pad_h = pad_w = padding
else:
pad_h, pad_w = padding
# handle dilation
dilation_h = dilation_w = 1
kernel_tvm = kernel
kernel_cudnn = kernel
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
kernel_before_dilation = kernel.op.input_tensors[0]
kernel_cudnn = kernel_before_dilation
if layout == 'NCHW':
dilation_h = (get_const_int(kernel.shape[2]) + get_const_int(kernel_before_dilation.shape[2]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
dilation_w = (get_const_int(kernel.shape[3]) + get_const_int(kernel_before_dilation.shape[3]) - 1) \
// get_const_int(kernel_before_dilation.shape[3])
elif layout == 'NHWC':
dilation_h = (get_const_int(kernel.shape[1]) + get_const_int(kernel_before_dilation.shape[1]) - 1) \
// get_const_int(kernel_before_dilation.shape[1])
dilation_w = (get_const_int(kernel.shape[2]) + get_const_int(kernel_before_dilation.shape[2]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
target = tvm.target.current_target()
if "cudnn" in target.libs:
assert layout != 'HWCN', "HWCN layout not supported with CUDNN."
@@ -47,19 +65,19 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
if layout == 'NHWC':
tensor_format = 1 # CUDNN_TENSOR_NHWC
return cudnn.conv2d_forward(data,
- kernel,
+ kernel_cudnn,
stride_h,
stride_w,
pad_h,
pad_w,
- 1, # dilation_h
+ dilation_h,
- 1, # dilation_w
+ dilation_w,
conv_mode=1,
tensor_format=tensor_format,
algo=-1) # let CUDNN choose the best algo
elif layout == 'NCHW':
- return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
+ return topi.nn.conv2d_nchw(data, kernel_tvm, stride, padding, out_dtype)
elif layout == 'HWCN':
- return topi.nn.conv2d_hwcn(data, kernel, stride, padding, out_dtype)
+ return topi.nn.conv2d_hwcn(data, kernel_tvm, stride, padding, out_dtype)
else:
raise ValueError("not support this layout {} yet".format(layout))
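
The dilation factors passed to cuDNN above are not given explicitly; they are recovered by comparing the dilated kernel's shape with the shape of the kernel that fed the `dilate` stage, using the fact that dilating a size-k axis by factor d yields (k - 1) * d + 1 elements. A minimal sketch of that arithmetic (hypothetical values, not part of the patch):

```python
# Sketch (not from the patch) of how the dilation factor is inferred from shapes:
# dilating a size-k axis by factor d gives (k - 1) * d + 1 elements, and the
# floor division below returns d exactly whenever 1 <= d <= k.
def infer_dilation(dilated_size, original_size):
    return (dilated_size + original_size - 1) // original_size

k, d = 3, 2
dilated = (k - 1) * d + 1            # 5 taps for a 3-wide axis dilated by 2
assert infer_dilation(dilated, k) == 2
```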
@@ -110,6 +110,8 @@ def schedule_conv2d_hwcn(outs):
elif operator.tag == 'conv2d_hwcn':
Apad = operator.input_tensors[0]
W = operator.input_tensors[1]
if isinstance(W.op, tvm.tensor.ComputeOp) and 'dilate' in W.op.tag:
sch[W].compute_inline()
B = operator.output(0)
schedule(Apad, W, B)
else:
...
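
Every schedule touched by this patch handles a dilated kernel the same way: if the kernel tensor carries the new "dilate" tag, the zero-insertion stage is inlined into its consumer with `compute_inline`, so the dilated kernel is never materialized in memory. A self-contained sketch of that primitive on the old tvm 0.x schedule API (hypothetical tensors, not from the patch):

```python
# Self-contained sketch (hypothetical tensors, old tvm 0.x API) of compute_inline:
# the injective producer B is folded into its consumer C instead of being stored.
import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] * 2, name="B")   # stands in for the dilate stage
C = tvm.compute((n,), lambda i: B[i] + 1, name="C")   # stands in for the convolution

s = tvm.create_schedule(C.op)
s[B].compute_inline()            # C now computes A[i] * 2 + 1 directly, B is never stored
print(tvm.lower(s, [A, C], simple_mode=True))
```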
@@ -495,6 +495,8 @@ def schedule_conv2d_small_batch(outs):
if 'conv2d_nchw' in OP.tag:
temp = OP.input_tensors[0]
Filter = OP.input_tensors[1]
if isinstance(Filter.op, tvm.tensor.ComputeOp) and 'dilate' in Filter.op.tag:
s[Filter].compute_inline()
Output = OP.output(0)
schedule(temp, Filter, Output)
...
@@ -114,6 +114,8 @@ def schedule_depthwise_conv2d_nchw(outs):
if OP.tag == 'depthwise_conv2d_nchw':
PaddedInput = OP.input_tensors[0]
Filter = OP.input_tensors[1]
if isinstance(Filter.op, tvm.tensor.ComputeOp) and 'dilate' in Filter.op.tag:
s[Filter].compute_inline()
DepthwiseConv2d = OP.output(0)
_schedule(PaddedInput, Filter, DepthwiseConv2d)
@@ -191,6 +193,8 @@ def schedule_depthwise_conv2d_nhwc(outs):
if OP.tag == 'depthwise_conv2d_nhwc':
PaddedInput = OP.input_tensors[0]
Filter = OP.input_tensors[1]
if isinstance(Filter.op, tvm.tensor.ComputeOp) and 'dilate' in Filter.op.tag:
s[Filter].compute_inline()
DepthwiseConv2d = OP.output(0)
_schedule(PaddedInput, Filter, DepthwiseConv2d)
...
@@ -21,6 +21,6 @@ def schedule_extern(outs):
"""
target = tvm.target.current_target(allow_none=False)
if target.target_name != "llvm":
- raise RuntimeError("schedule_injective not registered for '%s'" % target)
+ raise RuntimeError("schedule_extern not registered for '%s'" % target)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
return tvm.create_schedule([x.op for x in outs])
@@ -289,6 +289,10 @@ def _schedule_spatialpack_conv2d(s, op):
if data.dtype == 'float16' and (util.get_const_int(conv.shape[1]) == 4 or output_height == 28):
num_thread //= 2
# schedule dilation
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
# schedule padding
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
@@ -431,6 +435,10 @@ def _schedule_im2col_conv2d(s, op):
num_thread1 = num_thread * 2
num_thread2 = num_thread // 2
# schedule dilation
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
# schedule padding
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
@@ -616,6 +624,10 @@ def _schedule_winograd(s, op):
data_pad = s[d].op.input_tensors[0]
data = s[data_pad].op.input_tensors[0]
# dilation
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
# padding
s[data_pad].compute_inline()
...
@@ -100,6 +100,8 @@ def schedule_depthwise_conv2d_nchw(outs):
if op.tag == 'depthwise_conv2d_nchw':
pad_data = op.input_tensors[0]
kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
s[kernel].compute_inline()
conv = op.output(0)
_schedule(pad_data, kernel, conv)
...
@@ -5,7 +5,7 @@ import tvm
from .. import util
from .. import tag
- @tvm.tag_scope(tag=tag.INJECTIVE)
+ @tvm.tag_scope(tag=tag.INJECTIVE+",dilate")
def dilate(data, strides, name="DilatedInput"):
"""Dilate data with zeros.
...
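
Tagging `dilate` with "dilate" is what lets the schedules above recognize and inline it. The operator itself inserts strides - 1 zeros between neighbouring elements along each axis; a rough numpy equivalent (a sketch, not the topi implementation, mirroring what topi.testing.dilate_python computes for the tests below):

```python
import numpy as np

def dilate_np(data, strides):
    # (k - 1) * s + 1 output elements per axis, with the originals on a stride-s grid
    out_shape = tuple((d - 1) * s + 1 for d, s in zip(data.shape, strides))
    out = np.zeros(out_shape, dtype=data.dtype)
    out[tuple(slice(None, None, s) for s in strides)] = data
    return out

w = np.arange(1, 10, dtype="float32").reshape(3, 3)
print(dilate_np(w, (2, 2)).shape)   # (5, 5): a 3x3 kernel becomes 5x5 with interleaved zeros
```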
@@ -6,7 +6,7 @@ from .. import tag
@tvm.tag_scope(tag=tag.INJECTIVE+",pad")
def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
- """Dilate Input with zeros.
+ """Pad Input with zeros.
Parameters
----------
...
@@ -43,6 +43,9 @@ def schedule_conv2d_nchw(outs):
elif OP.tag.startswith('conv2d_nchw'):
conv2d = OP.output(0)
data = OP.input_tensors[0]
kernel = OP.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
_schedule(conv2d, data)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
...
@@ -328,6 +328,8 @@ def schedule_conv2d_nchw(outs):
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel = kernel_vec.op.input_tensors[0]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0]
data_pad = None
...
@@ -194,6 +194,8 @@ def schedule_depthwise_conv2d_nchw(outs):
if op.tag == 'depthwise_conv2d_nchw':
output = op.output(0)
kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
...
@@ -5,6 +5,8 @@ from tvm.contrib import miopen
import topi
from .. import generic
from ..nn.conv2d import conv2d
from ..util import get_const_int
@conv2d.register("rocm") @conv2d.register("rocm")
def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'): def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
...@@ -42,18 +44,29 @@ def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32 ...@@ -42,18 +44,29 @@ def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
pad_h = pad_w = padding pad_h = pad_w = padding
else: else:
pad_h, pad_w = padding pad_h, pad_w = padding
# handle dilation
dilation_h = dilation_w = 1
kernel_tvm = kernel
kernel_cudnn = kernel
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
kernel_before_dilation = kernel.op.input_tensors[0]
kernel_cudnn = kernel_before_dilation
dilation_h = (get_const_int(kernel.shape[2]) + get_const_int(kernel_before_dilation.shape[2]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
dilation_w = (get_const_int(kernel.shape[3]) + get_const_int(kernel_before_dilation.shape[3]) - 1) \
// get_const_int(kernel_before_dilation.shape[3])
target = tvm.target.current_target()
if "miopen" in target.libs:
return miopen.conv2d_forward(data,
- kernel,
+ kernel_cudnn,
stride_h,
stride_w,
pad_h,
pad_w,
- 1, # dilation_h
+ dilation_h,
- 1, # dilation_w
+ dilation_w,
conv_mode=0)
- return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
+ return topi.nn.conv2d_nchw(data, kernel_tvm, stride, padding, out_dtype)
@generic.schedule_conv2d_nchw.register(["rocm"])
...
@@ -91,6 +91,8 @@ def schedule_conv2d(outs):
"""NCHW conv2d schedule for non imagenet workloads"""
conv = op.output(0)
kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
@@ -134,6 +136,8 @@ def schedule_conv2d(outs):
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel = kernel_vec.op.input_tensors[0]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0]
data_pad = None
@@ -184,6 +188,9 @@ def schedule_conv2d_nhwc(outs):
if 'conv2d_nhwc' in op.tag:
conv = op.output(0)
kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
...
@@ -8,12 +8,13 @@ from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
- def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
in_height = in_width = in_size
A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
- B = topi.nn.conv2d_hwcn(A, W, stride, padding)
+ dW = topi.nn.dilate(W, (dilation, dilation, 1, 1))
B = topi.nn.conv2d_hwcn(A, dW, stride, padding)
C = topi.nn.relu(B)
s1 = topi.cuda.schedule_conv2d_hwcn([B])
s2 = topi.cuda.schedule_conv2d_hwcn([C])
@@ -26,7 +27,8 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
- b_np = topi.testing.conv2d_hwcn_python(a_np, w_np, stride, padding)
+ dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
b_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
c_np = np.maximum(b_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
@@ -63,7 +65,8 @@ def test_conv2d_hwcn():
verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_hwcn(4, 128, 16, 128, 5, 2, "VALID")
verify_conv2d_hwcn(4, 128, 16, 256, 5, 2, "VALID")
# dilation = 2
verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "SAME", dilation=2)
if __name__ == "__main__":
test_conv2d_hwcn()
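
For the new dilation=2 cases the reference result is just an ordinary convolution with the zero-dilated kernel, so the output shape follows the usual formula with the effective kernel size. A quick sanity check for the case added above (not part of the test file):

```python
# kernel=3 dilated by 2 spans (3 - 1) * 2 + 1 = 5 taps; with stride 1 and "SAME"
# padding the 32x32 input keeps its spatial size.
kernel, dilation, in_size, stride = 3, 2, 32, 1
effective = (kernel - 1) * dilation + 1
pad = (effective - 1) // 2
out_size = (in_size + 2 * pad - effective) // stride + 1
print(effective, out_size)   # 5 32
```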
@@ -3,10 +3,11 @@ import os
import numpy as np
import tvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
- def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
@@ -16,11 +17,12 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
- @memoize("topi.tests.test_topi_conv2d.verify_con2d_nchw")
+ @memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
- b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
+ dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
c_np = np.maximum(b_np, 0)
return a_np, w_np, b_np, c_np
@@ -33,7 +35,8 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
return
print("Running on target: %s" % device)
with tvm.target.create(device):
- B = topi.nn.conv2d(A, W, stride, padding, layout='NCHW')
+ dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
B = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW')
C = topi.nn.relu(B)
s1 = topi.generic.schedule_conv2d_nchw([B])
s2 = topi.generic.schedule_conv2d_nchw([C])
@@ -43,8 +46,8 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
with tvm.build_config(auto_unroll_max_step=1400,
unroll_explicit=(device != "cuda")):
- func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
+ func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
- func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
+ func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func1(a, w, b)
func2(a, w, c)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
@@ -75,6 +78,8 @@ def test_conv2d_nchw():
verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1)
verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1)
# dilation = 2
verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1, dilation=2)
if __name__ == "__main__":
test_conv2d_nchw()
@@ -8,12 +8,13 @@ from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
- def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
in_height = in_width = in_size
A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
- B = topi.nn.conv2d_nhwc(A, W, stride, padding)
+ dW = topi.nn.dilate(W, (dilation, dilation, 1, 1))
B = topi.nn.conv2d_nhwc(A, dW, stride, padding)
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
@@ -23,7 +24,8 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
- b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding)
+ dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
b_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
return a_np, w_np, b_np
a_np, w_np, b_np = get_ref_data()
@@ -54,6 +56,8 @@ def test_conv2d_nhwc():
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "VALID")
verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "VALID")
# dilation = 2
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME", dilation=2)
if __name__ == "__main__":
...
@@ -8,21 +8,21 @@ from tvm.contrib.pickle_memoize import memoize
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nhwc
- def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding):
+ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
in_width = in_height
filter_channel = in_channel
filter_width = filter_height
# placeholder
Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter')
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# declare
- DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, stride=stride, padding=padding)
+ DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, stride=stride, padding=padding)
ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
@@ -52,11 +52,12 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
def get_ref_data():
input_np = np.random.uniform(size=input_shape).astype(dtype)
filter_np = np.random.uniform(size=filter_shape).astype(dtype)
dilated_filter_np = topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation))
scale_np = np.random.uniform(size=scale_shape).astype(dtype)
shift_np = np.random.uniform(size=shift_shape).astype(dtype)
# correctness with scipy
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
- input_np, filter_np, stride=stride, padding=padding)
+ input_np, dilated_filter_np, stride=stride, padding=padding)
scale_shift_scipy = np.zeros(shape=scale_shift_shape)
for c in range(in_channel * channel_multiplier):
scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
@@ -94,7 +95,7 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
check_device("vulkan")
- def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
+ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1):
in_width = in_height
filter_channel = in_channel
filter_width = filter_height
@@ -102,10 +103,11 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
# placeholder
Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input')
Filter = tvm.placeholder((filter_height, filter_width,filter_channel, channel_multiplier), name='Filter')
DilatedFilter = topi.nn.dilate(Filter, (dilation, dilation, 1, 1), name='DilatedFilter')
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# declare
- DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, stride=[stride_h, stride_w], padding=padding)
+ DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, stride=[stride_h, stride_w], padding=padding)
ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule
@@ -139,11 +141,12 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
def get_ref_data():
input_np = np.random.uniform(size=input_shape).astype(dtype)
filter_np = np.random.uniform(size=filter_shape).astype(dtype)
dilated_filter_np = topi.testing.dilate_python(filter_np, (dilation, dilation, 1, 1))
scale_np = np.random.uniform(size=scale_shape).astype(dtype)
shift_np = np.random.uniform(size=shift_shape).astype(dtype)
# correctness with scipy
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(
- input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
+ input_np, dilated_filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=scale_shift_shape)
for c in range(in_channel * channel_multiplier):
scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c]
@@ -192,6 +195,8 @@ def test_depthwise_conv2d():
depthwise_conv2d_with_workload_nchw(1, 728, 32, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_nchw(4, 256, 64, 2, 5, 2, "VALID")
depthwise_conv2d_with_workload_nchw(4, 256, 32, 2, 5, 2, "VALID")
# dilation = 2
depthwise_conv2d_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
print("testing nhwc") print("testing nhwc")
depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME") depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "SAME") depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "SAME")
...@@ -201,6 +206,8 @@ def test_depthwise_conv2d(): ...@@ -201,6 +206,8 @@ def test_depthwise_conv2d():
depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID") depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID") depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID")
depthwise_conv2d_with_workload_nhwc(4, 256, 32, 2, 5, 2, "VALID") depthwise_conv2d_with_workload_nhwc(4, 256, 32, 2, 5, 2, "VALID")
# dilation = 2
depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
if __name__ == "__main__":
test_depthwise_conv2d()
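
One detail worth keeping straight in these tests: the dilate strides must line up with the spatial axes of each filter layout, which is why the NCHW depthwise workloads dilate the last two axes while the HWIO-style filters dilate the first two. A small reminder (a sketch, not part of the test file):

```python
# Filter layouts used above and the matching dilate strides for dilation d
d = 2
nchw_filter_strides = (1, 1, d, d)   # (channel, multiplier, H, W): spatial axes last
hwio_filter_strides = (d, d, 1, 1)   # (H, W, in_channel, multiplier/out): spatial axes first
```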