Commit 10ca0b1f by Yuwei Hu, committed by Tianqi Chen

[TOP] add conv2d_transpose (#217)

* [TOP] add conv2d_transpose

* update tvm

* fix pylint
parent ad7ffd35
......
@@ -126,6 +126,37 @@ def schedule_conv2d(attrs, outs, target):
reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# conv2d_transpose
@reg.register_compute("conv2d_transpose")
def compute_conv2d_transpose(attrs, inputs, _):
    """Compute definition of conv2d_transpose"""
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    groups = attrs.get_int("groups")
    layout = attrs["layout"]
    assert layout == "NCHW", "only NCHW layout is supported for now"
    assert dilation == (1, 1), "dilation is not supported for now"
    assert groups == 1, "only groups == 1 is supported for now"
    out = topi.nn.conv2d_transpose_nchw(inputs[0], inputs[1], strides, padding)
    if attrs.get_bool("use_bias"):
        bias = inputs[2]
        bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
        out = topi.broadcast_add(out, bias)
    output_padding = attrs.get_int_tuple("output_padding")
    out = topi.nn.pad(out, [0, 0, 0, 0],
                      [0, 0, output_padding[0], output_padding[1]])
    return out

@reg.register_schedule("conv2d_transpose")
def schedule_conv2d_transpose(attrs, outs, target):
    """Schedule definition of conv2d_transpose"""
    with tvm.target.create(target):
        return topi.generic.schedule_conv2d_transpose_nchw(outs)

reg.register_pattern("conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)


# max_pool2d
@reg.register_compute("max_pool2d")
def compute_max_pool2d(attrs, inputs, _):
......
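For intuition about the compute above: `topi.nn.conv2d_transpose_nchw` produces the usual transposed-convolution output size, and `output_padding` is then applied as extra zero rows/columns on the bottom/right edge via the trailing `topi.nn.pad`. A minimal sketch of the resulting shape arithmetic (plain Python; the helper name is illustrative, not part of the patch):

def conv2d_transpose_out_hw(in_hw, kernel, stride, pad, out_pad):
    """Expected (H, W) of the compute above: the standard transposed-convolution
    size (in - 1) * stride - 2 * pad + kernel, plus the output_padding zeros
    that topi.nn.pad appends on the bottom/right edge only."""
    return tuple((i - 1) * s - 2 * p + k + op
                 for i, k, s, p, op in zip(in_hw, kernel, stride, pad, out_pad))

# Shapes used by the test further below: 18x18 input, 3x3 kernel, stride 2,
# padding 1, output_padding 2 -> 37x37 output.
assert conv2d_transpose_out_hw((18, 18), (3, 3), (2, 2), (1, 1), (2, 2)) == (37, 37)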
......
@@ -150,10 +150,13 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(param.dilation.ndim(), 2U)
       << "incorrect dilate size: " << param.dilation;
-  TShape wshape({dshape_nchw[1],
-                 param.channels / param.groups,
-                 param.kernel_size[0], param.kernel_size[1]});
+  TShape wshape({param.channels / param.groups,
+                 dshape_nchw[1] / param.groups,
+                 param.kernel_size[0],
+                 param.kernel_size[1]});
   wshape = ConvertLayout(wshape, kNCHW, param.layout);
+  wshape[0] *= param.groups;
   NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape);
......
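The hunk above changes the inferred weight shape for conv2d_transpose from (in_channels, out_channels / groups, kH, kW) to (out_channels / groups, in_channels / groups, kH, kW), then scales the leading dimension back up by groups after the layout conversion. A minimal sketch of that arithmetic, assuming the NCHW case where ConvertLayout leaves the shape unchanged (plain Python; the helper is illustrative only, not part of the patch):

def infer_transpose_weight_shape(in_channels, channels, kernel_size, groups=1):
    """Mirror of the new C++ weight-shape inference for conv2d_transpose."""
    wshape = [channels // groups,        # output channels per group
              in_channels // groups,     # input channels per group
              kernel_size[0], kernel_size[1]]
    wshape[0] *= groups                  # restore the full output-channel count
    return tuple(wshape)

# Matches kshape in the test below: 3 input channels, channels=10, 3x3 kernel.
assert infer_transpose_weight_shape(3, 10, (3, 3)) == (10, 3, 3, 3)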
......
@@ -54,6 +54,31 @@ def test_grouped_conv2d():
        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)


def test_conv2d_transpose():
    x = sym.Variable("x")
    y = sym.conv2d_transpose(x, channels=10, kernel_size=(3, 3), strides=(2, 2),
                             name="y", padding=(1, 1), output_padding=(2, 2))
    dtype = "float32"
    dshape = (1, 3, 18, 18)
    kshape = (10, 3, 3, 3)
    oshape = (1, 10, 37, 37)
    shape_dict = {"x": dshape}
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        m = graph_runtime.create(graph, lib, ctx)
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
        bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
        m.run(x=data, y_weight=kernel, y_bias=bias)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        c_np = topi.testing.conv2d_transpose_nchw_python(
            data.asnumpy(), kernel.asnumpy(), 2, 1)
        c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
        d_np = np.zeros(shape=oshape)
        d_np[:, :, 0:c_np.shape[2], 0:c_np.shape[3]] = c_np
        np.testing.assert_allclose(out.asnumpy(), d_np, rtol=1e-5)


def test_max_pool2d():
    x = sym.Variable("x")
    y = sym.max_pool2d(x, pool_size=(2, 2), strides=(2, 2),
......
@@ -126,6 +151,7 @@ def test_global_avg_pool2d():
if __name__ == "__main__":
    test_conv2d()
    test_grouped_conv2d()
    test_conv2d_transpose()
    test_max_pool2d()
    test_avg_pool2d()
    test_global_max_pool2d()
......
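One note on the reference check in test_conv2d_transpose above: `topi.testing.conv2d_transpose_nchw_python` does not apply `output_padding`, so its result has shape (1, 10, 35, 35), and the test embeds it at the top-left of a zero-filled (1, 10, 37, 37) array, mirroring the trailing `topi.nn.pad` in the compute definition. A standalone illustration of that embedding (plain NumPy; shapes taken from the test, data values are placeholders):

import numpy as np

c_np = np.ones((1, 10, 35, 35))   # stand-in for the reference result (no output_padding)
d_np = np.zeros((1, 10, 37, 37))  # full expected output, zero-initialized
d_np[:, :, 0:c_np.shape[2], 0:c_np.shape[3]] = c_np  # zeros remain only on bottom/right
assert d_np[:, :, 35:, :].sum() == 0 and d_np[:, :, :, 35:].sum() == 0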