Unverified Commit 9c806621 by Josh Fromm Committed by GitHub

[Relay][Frontend][Pytorch] Fixed ConvTranspose2D parsing (#5157)

* Fixed conv transpose parsing.

* small format change.

* Chage test module names.

* Simplified test syntax.
parent dada6761
......@@ -251,7 +251,7 @@ def _hardtanh():
def _convolution():
def _impl(inputs, input_types):
# Use transpose or normal
use_transpose = True if inputs[6] == "1" else False
use_transpose = True if inputs[6] == 1 else False
data = inputs[0]
weight = inputs[1]
......@@ -268,6 +268,10 @@ def _convolution():
else:
assert "data type {} could not be parsed in conv op" % (type(weight))
# Transposed convolutions have IOHW layout.
if use_transpose:
weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]
channels = weight_shape[0]
groups = int(inputs[8])
......
......@@ -448,6 +448,14 @@ def test_forward_conv():
input_data=torch.randn((1, 8, 16, 16)))
def test_forward_conv_transpose():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=input_data)
verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=input_data)
def test_forward_threshold():
torch.set_grad_enabled(False)
input_shape = [1, 3]
......@@ -1050,6 +1058,7 @@ if __name__ == "__main__":
test_forward_maxpool1d()
test_forward_hardtanh()
test_forward_conv()
test_forward_conv_transpose()
test_forward_threshold()
test_forward_contiguous()
test_forward_batchnorm()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment