Unverified Commit 686911e2 by masahi Committed by GitHub

[Torch] Fix conv2d conversion for group conv (group > 1 but != in channels) (#5132)

* Fix conv2d conversion for group conv

* add more comment for clarification
parent 0a0e58bf
......@@ -250,7 +250,12 @@ def _convolution():
channels = weight_shape[0]
groups = int(inputs[8])
if groups > 1:
# Check if this is depth wise convolution
# We need to reshape weight so that Relay could recognize this is depth wise
# weight_shape[1] is always in_channels // groups
# For depthwise, in_channels == groups, so weight_shape[1] == 1
# If groups > 1 but weight_shape[1] != 1, this is group convolution
if groups > 1 and weight_shape[1] == 1:
channel_multiplier = channels // groups
new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3])
weight = _op.transform.reshape(weight, new_weight_shape)
......
......@@ -428,7 +428,13 @@ def test_forward_conv():
input_data = torch.rand(input_shape).float()
verify_model(Conv2D1().float().eval(), input_data=input_data)
verify_model(Conv2D2().float().eval(), input_data=input_data)
# depth wise conv with channel mult 2
verify_model(Conv2D3().float().eval(), input_data=input_data)
# group conv
verify_model(torch.nn.Conv2d(8, 8, kernel_size=(3, 3),
stride=(1, 1), groups=2).eval(),
input_data=torch.randn((1, 8, 16, 16)))
def test_forward_threshold():
torch.set_grad_enabled(False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment