Commit 11ebd8d3 by Tatsuya Nishiyama Committed by Tianqi Chen

support dilation in conv2d (#439)

parent d5744844
......@@ -85,11 +85,18 @@ def compute_conv2d(attrs, inputs, _):
channels = attrs.get_int("channels")
layout = attrs["layout"]
assert layout == "NCHW" or layout == "NHWC"
assert dilation == (1, 1), "not support dilate now"
(dilation_h, dilation_w) = dilation
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
elif layout == "NCHW":
kernel = topi.nn.dilate(inputs[1], [1, 1, dilation_h, dilation_w])
else: #layout == NHWC
kernel = topi.nn.dilate(inputs[1], [1, dilation_h, dilation_w, 1])
if groups == 1:
out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding, layout)
out = topi.nn.conv2d(inputs[0], kernel, strides, padding, layout)
elif groups == get_const_int(inputs[0].shape[1]) and groups == channels:
out = topi.nn.depthwise_conv2d_nchw(inputs[0], inputs[1], strides, padding)
out = topi.nn.depthwise_conv2d_nchw(inputs[0], kernel, strides, padding)
else:
raise ValueError("not support arbitrary group number for now")
if attrs.get_bool("use_bias"):
......
......@@ -32,6 +32,32 @@ def test_conv2d():
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
def test_dilated_conv2d():
dilation = 3
x = sym.Variable("x")
y = sym.conv2d(x, channels=10, kernel_size=(3, 3), dilation=(dilation, dilation),
name="y", padding=(1, 1))
dtype = "float32"
dshape = (1, 3, 18, 18)
kshape = (10, 3, 3, 3)
oshape = (1, 10, 14, 14)
shape_dict = {"x": dshape}
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
m = graph_runtime.create(graph, lib, ctx)
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
kernel_np = np.random.uniform(size=kshape).astype(dtype)
kernel = tvm.nd.array(kernel_np)
dkernel_np = topi.testing.dilate_python(kernel_np, (1, 1, dilation, dilation))
m.run(x=data, y_weight=kernel, y_bias=bias)
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
c_np = topi.testing.conv2d_nchw_python(
data.asnumpy(), dkernel_np, 1, 1)
c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
def test_grouped_conv2d():
x = sym.Variable("x")
y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32,
......@@ -170,6 +196,7 @@ def test_upsampling():
if __name__ == "__main__":
test_conv2d()
test_dilated_conv2d()
test_grouped_conv2d()
test_conv2d_transpose()
test_max_pool2d()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment