Unverified commit 86079479 by masahi, committed by GitHub

[Torch] Add initial 3D op support and test on Resnet 3D (#5075)

* fix minor lint issue

* add conv3d and adaptive avg pool3d conversion with test

* fix max pool handling

* add batch norm 3d test

* add resnet 3d test

* add more conv3d test

* clean up batch norm test

* add note on disabling inception v3 test

* add more tests

* add more tests

* fix names
parent a11391eb
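
The patch below adds 3D variants of convolution and adaptive pooling to the Relay PyTorch frontend. As a quick orientation, here is a minimal usage sketch (not part of the patch) of how the new support is exercised end to end; it assumes the frontend API as of this commit, with `from_pytorch` taking a dict of input names to shapes and the `get_graph_input_names` helper visible in the diff.

    # Minimal sketch (assumption: frontend API as of this commit).
    import torch
    from tvm import relay
    from tvm.relay.frontend.pytorch import get_graph_input_names

    model = torch.nn.Conv3d(32, 16, kernel_size=3, padding=1).eval()
    inp = torch.rand(1, 32, 16, 16, 16)                  # NCDHW input
    script_module = torch.jit.trace(model, inp).eval()   # trace to TorchScript

    input_name = get_graph_input_names(script_module)[0]
    # aten::_convolution on a 5-D input is now dispatched to Relay conv3d
    mod, params = relay.frontend.from_pytorch(script_module,
                                              {input_name: inp.shape})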
@@ -163,7 +163,7 @@ def _relu():
         return _op.nn.relu(data)
     return _impl

-def _adaptive_avg_2d():
+def _adaptive_avg_pool_2d():
     def _impl(inputs, input_types):
         data = inputs[0]
         output_size = _infer_shape(inputs[1])
@@ -178,14 +178,32 @@ def _adaptive_avg_2d():
     return _impl

-def _adaptive_max_2d():
+def _adaptive_max_pool_2d():
     def _impl(inputs, input_types):
         data = inputs[0]
         output_size = _infer_shape(inputs[1])

+        # returns dummy indices too
         return _op.nn.adaptive_max_pool2d(
             data,
-            output_size=output_size)
+            output_size=output_size), None
     return _impl

+def _adaptive_max_pool_3d():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        output_size = _infer_shape(inputs[1])
+        # returns dummy indices too
+        return _op.nn.adaptive_max_pool3d(data, output_size=output_size), None
+    return _impl
+
+def _adaptive_avg_pool_3d():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        output_size = _infer_shape(inputs[1])
+        return _op.nn.adaptive_avg_pool3d(data, output_size=output_size)
+    return _impl
+
 def _maxpool_2d():
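
A note on the `None` second return value: in the Torch graph, adaptive max pooling produces a (values, indices) pair, while Relay's `adaptive_max_pool2d`/`adaptive_max_pool3d` produce only the values. The converter therefore emits `None` as a dummy for the indices output, which the tuple handling added to `convert_operators` below binds to the second Torch output name. A small sketch of the Torch-side behavior:

    # PyTorch side: adaptive max pooling can return indices as a second output.
    import torch

    pool = torch.nn.AdaptiveMaxPool3d((2, 2, 2), return_indices=True)
    values, indices = pool(torch.rand(1, 4, 8, 8, 8))
    print(values.shape, indices.shape)  # both (1, 4, 2, 2, 2)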
@@ -249,33 +267,30 @@ def _convolution():
     if isinstance(dilation, _expr.Expr):
         dilation = _infer_shape(dilation)

-    if use_transpose:
-        conv_out = _op.nn.conv2d_transpose(data,
-                                           weight,
-                                           strides=strides,
-                                           padding=padding,
-                                           dilation=dilation,
-                                           groups=groups,
-                                           channels=channels,
-                                           kernel_size=kernel_size,
-                                           data_layout="NCHW",
-                                           kernel_layout="OIHW",
-                                           out_layout="",
-                                           out_dtype="")
-    else:
-        conv_out = _op.nn.conv2d(data,
-                                 weight,
-                                 strides=strides,
-                                 padding=padding,
-                                 dilation=dilation,
-                                 groups=groups,
-                                 channels=channels,
-                                 kernel_size=kernel_size,
-                                 data_layout="NCHW",
-                                 kernel_layout="OIHW",
-                                 out_layout="",
-                                 out_dtype="")
+    data_layout = "NCHW"
+    kernel_layout = "OIHW"
+    conv_op = _op.nn.conv2d
+
+    if use_transpose:
+        assert len(kernel_size) == 2, "ConvTranspose 3D not supported"
+        conv_op = _op.nn.conv2d_transpose
+    if len(kernel_size) == 3:
+        conv_op = _op.nn.conv3d
+        data_layout = "NCDHW"
+        kernel_layout = "OIDHW"
+
+    conv_out = conv_op(data,
+                       weight,
+                       strides=strides,
+                       padding=padding,
+                       dilation=dilation,
+                       groups=groups,
+                       channels=channels,
+                       kernel_size=kernel_size,
+                       data_layout=data_layout,
+                       kernel_layout=kernel_layout,
+                       out_layout="",
+                       out_dtype="")

     if use_bias:
         return _op.nn.bias_add(conv_out, bias)
     else:
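
The refactor collapses the duplicated 2D call sites into a single parameterized call: the rank of `kernel_size` selects the operator and the matching layouts, with NCDHW/OIDHW used for the 3D case. The layouts line up with PyTorch's own conventions, which is why no transpose is needed; a quick illustrative check (not part of the patch):

    # Torch Conv3d weights are already laid out as OIDHW, matching
    # the kernel_layout chosen above.
    import torch

    conv = torch.nn.Conv3d(32, 16, kernel_size=(3, 3, 3))
    print(conv.weight.shape)  # torch.Size([16, 32, 3, 3, 3]) -> (O, I, D, H, W)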
@@ -844,8 +859,8 @@ _convert_map = {
     "aten::select"                  : _select(),
     "aten::relu"                    : _relu(),
     "aten::relu_"                   : _relu(),
-    "aten::adaptive_avg_pool2d"     : _adaptive_avg_2d(),
-    "aten::adaptive_max_pool2d"     : _adaptive_max_2d(),
+    "aten::adaptive_avg_pool2d"     : _adaptive_avg_pool_2d(),
+    "aten::adaptive_max_pool2d"     : _adaptive_max_pool_2d(),
     "aten::max_pool2d"              : _maxpool_2d(),
     "aten::max_pool2d_with_indices" : _maxpool_2d(),
     "aten::hardtanh"                : _hardtanh(),
@@ -895,6 +910,8 @@ _convert_map = {
     "aten::Float"                   : _Float(),
     "aten::neg"                     : _neg(),
     "aten::tanh"                    : _tanh(),
+    "aten::adaptive_avg_pool3d"     : _adaptive_avg_pool_3d(),
+    "aten::adaptive_max_pool3d"     : _adaptive_max_pool_3d()
 }
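
`_convert_map` keys Torch `aten::` operator names to converter factories, so registering the 3D ops is the same one-line pattern as the existing 2D entries. Out-of-tree conversions can be supplied through the frontend's custom map hook, exercised by `test_custom_conversion_map` in the test diff below; a hedged sketch, assuming a `custom_convert_map` argument on `from_pytorch`:

    # Sketch: plugging a user-supplied converter into the same map mechanism.
    # The `custom_convert_map` argument is an assumption based on
    # test_custom_conversion_map in the test suite.
    def _leaky_relu():
        def _impl(inputs, input_types):
            return _op.nn.leaky_relu(inputs[0], alpha=0.2)
        return _impl

    custom_map = {"aten::leaky_relu": _leaky_relu()}
    # mod, params = relay.frontend.from_pytorch(script_module, input_shapes,
    #                                           custom_convert_map=custom_map)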
@@ -955,6 +972,7 @@ def _report_missing_conversion(op_names):
         msg = "The following operators are not implemented: {}".format(missing)
         raise NotImplementedError(msg)

+
 def _check_input_names(script_module, input_shapes):
     """ Check the graph inputs match the inputs """
     ir_inputs = get_graph_input_names(script_module)
@@ -1272,9 +1290,18 @@ def convert_operators(operators, outputs, output_index_map, ret_names):
             _update_outputs_from_pairs(zip(unpacked_names, loop_out),
                                        outputs, output_index_map)
         else:
-            output_index_map[node_name] = len(outputs)
             relay_op = _convert_map[operator]
-            outputs.append(relay_op(inputs, _get_input_types(op_node)))
+            relay_out = relay_op(inputs, _get_input_types(op_node))
+
+            if isinstance(relay_out, tuple):
+                # This is for torch operators that return multiple outputs
+                # See _adaptive_max_pool_2d above for an example
+                out_names = _get_output_names(op_node)
+                _update_outputs_from_pairs(zip(out_names, relay_out),
+                                           outputs, output_index_map)
+            else:
+                output_index_map[node_name] = len(outputs)
+                outputs.append(relay_out)

     return [_wrap_const(outputs[output_index_map[ret_name]])
             for ret_name in ret_names]
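
The key change: a converter may now return a Python tuple, in which case each element is bound to the corresponding Torch output name instead of a single `node_name` slot. A toy illustration of the bookkeeping (the body of `_update_outputs_from_pairs` is assumed here, based on how it is called):

    # Toy model of the flat outputs list + name-to-index map used above.
    def update_outputs_from_pairs(pairs, outputs, output_index_map):
        for name, out in pairs:
            output_index_map[name] = len(outputs)
            outputs.append(out)

    outputs, index_map = [], {}
    # e.g. aten::adaptive_max_pool3d -> (pooled values, dummy None indices)
    update_outputs_from_pairs([("out1", "pooled"), ("out2", None)],
                              outputs, index_map)
    assert outputs[index_map["out2"]] is None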
......
@@ -358,8 +358,9 @@ def test_quantized_imagenet():
     qmodels += [
         ("resnet18", qresnet.resnet18(pretrained=True), per_channel),
         ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
-        # disable inception test for now, since loading it takes ~5min on torchvision-0.5
-        #("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
+        # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to a scipy bug
+        # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756
+        # ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
         ("googlenet", qgooglenet(pretrained=True), per_channel),
     ]
......
@@ -452,27 +452,20 @@ def test_forward_contiguous():
     input_data = torch.rand(input_shape).float()
     verify_model(Contiguous1().float().eval(), input_data=input_data)

 def test_forward_batchnorm():
-    torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
+    def init_weight(m):
+        torch.nn.init.normal_(m.weight, 0, 0.01)
+        torch.nn.init.normal_(m.bias)

-    class BatchNorm1(Module):
-        def __init__(self):
-            super(BatchNorm1, self).__init__()
-            self.batch_norm = torch.nn.BatchNorm2d(3, affine=True)
-        def forward(self, *args):
-            return self.batch_norm(args[0])
+    inp_2d = torch.rand((1, 16, 10, 10))
+    inp_3d = torch.rand((1, 16, 10, 10, 10))

-    class BatchNorm2(Module):
-        def __init__(self):
-            super(BatchNorm2, self).__init__()
-            self.batch_norm = torch.nn.BatchNorm2d(3, affine=False)
-        def forward(self, *args):
-            return self.batch_norm(args[0])
+    for bn, inp in [(torch.nn.BatchNorm2d(16), inp_2d),
+                    (torch.nn.BatchNorm3d(16), inp_3d)]:
+        init_weight(bn.eval())
+        verify_model(bn.eval(), input_data=inp)

-    input_data = torch.rand(input_shape).float()
-    verify_model(BatchNorm1().float().eval(), input_data=input_data)
-    verify_model(BatchNorm2().float().eval(), input_data=input_data)

 def test_forward_transpose():
     torch.set_grad_enabled(False)
@@ -708,6 +701,37 @@ def test_to():
     verify_model(ToInt().eval(), torch.tensor(2.0))

+def test_adaptive_pool3d():
+    for ishape in [(1, 32, 16, 16, 16),
+                   (1, 32, 9, 15, 15),
+                   (1, 32, 13, 7, 7)]:
+        inp = torch.rand(ishape)
+        verify_model(torch.nn.AdaptiveMaxPool3d((1, 1, 1)).eval(), inp)
+        verify_model(torch.nn.AdaptiveMaxPool3d((2, 2, 2)).eval(), inp)
+        verify_model(torch.nn.AdaptiveAvgPool3d((1, 1, 1)).eval(), inp)
+        verify_model(torch.nn.AdaptiveAvgPool3d((2, 2, 2)).eval(), inp)
+        verify_model(torch.nn.AdaptiveAvgPool3d((4, 8, 8)).eval(), inp)
+        verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp)
+
+
+def test_conv3d():
+    for ishape in [(1, 32, 16, 16, 16),
+                   (1, 32, 9, 15, 15),
+                   (1, 32, 13, 7, 7)]:
+        inp = torch.rand(ishape)
+        verify_model(torch.nn.Conv3d(32, 16, (3, 3, 3),
+                                     padding=(1, 1, 1)).eval(),
+                     inp)
+        verify_model(torch.nn.Conv3d(32, 16, (5, 5, 5),
+                                     padding=(2, 2, 2)).eval(),
+                     inp)
+        verify_model(torch.nn.Conv3d(32, 16, kernel_size=1).eval(),
+                     inp)
+        # downsample
+        verify_model(torch.nn.Conv3d(32, 16, kernel_size=1, stride=2).eval(),
+                     inp)
+
+
 # Model tests
 def test_resnet18():
     torch.set_grad_enabled(False)
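
The "downsample" case is worth a quick shape check: with kernel 1, no padding, and stride 2 on a 16^3 volume, each spatial dim becomes floor((16 + 2*0 - 1)/2) + 1 = 8, so the converter must carry the stride through to Relay's conv3d. A standalone Torch-side check:

    # Output-shape check for the downsample test above:
    # out = floor((in + 2*pad - kernel) / stride) + 1
    import torch

    conv = torch.nn.Conv3d(32, 16, kernel_size=1, stride=2)
    out = conv(torch.rand(1, 32, 16, 16, 16))
    assert out.shape == (1, 16, 8, 8, 8)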
@@ -809,6 +833,12 @@ def test_segmentaton_models():
     verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx])

+def test_3d_models():
+    input_shape = (1, 3, 4, 56, 56)
+    resnet3d = torchvision.models.video.r3d_18(pretrained=True).eval()
+    verify_model(resnet3d, [torch.rand(input_shape)])
+
+
 def verify_script_model(pt_model, ishapes):
     script_module = torch.jit.script(pt_model)
     input_names = get_graph_input_names(script_module)
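
For reference, torchvision's `r3d_18` video model takes (N, C, T, H, W) input, so the shape above is a 4-frame clip of 56x56 images; its head ends in an `AdaptiveAvgPool3d`, which is why the new adaptive 3D pooling conversion is needed for this model. A quick Torch-side shape check (assumes torchvision 0.5 with the video models):

    # Shape sanity check for the video ResNet used above (Kinetics-400 head).
    import torch
    import torchvision

    model = torchvision.models.video.r3d_18(pretrained=False).eval()
    out = model(torch.rand(1, 3, 4, 56, 56))
    assert out.shape == (1, 400)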
@@ -1021,13 +1051,17 @@ if __name__ == "__main__":
     test_forward_chunk()
     test_upsample()
     test_to()
+    test_adaptive_pool3d()
+    test_conv3d()

     # Model tests
     test_resnet18()
     test_squeezenet1_0()
     test_squeezenet1_1()
     test_densenet121()
-    test_inception_v3()
+    # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to a scipy bug
+    # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756
+    # test_inception_v3()
     test_googlenet()
     test_mnasnet0_5()
     test_mobilenet_v2()
@@ -1035,6 +1069,7 @@ if __name__ == "__main__":
     test_custom_conversion_map()
     test_segmentaton_models()
+    test_3d_models()

     # Quantization test
     from qnn_test import test_quantized_imagenet, test_quantized_modules
......