Unverified Commit 86079479 by masahi Committed by GitHub

[Torch] Add initial 3D op support and test on Resnet 3D (#5075)

* fix minor lint issue

* add conv3d and adaptive avg pool3d conversion with test

* fix max pool handling

* add batch norm 3d test

* add resnet 3d test

* add more conv3d test

* clean up batch norm test

* add note on disabling inception v3 test

* add more tests

* add more tests

* fix names
parent a11391eb
...@@ -163,7 +163,7 @@ def _relu(): ...@@ -163,7 +163,7 @@ def _relu():
return _op.nn.relu(data) return _op.nn.relu(data)
return _impl return _impl
def _adaptive_avg_2d(): def _adaptive_avg_pool_2d():
def _impl(inputs, input_types): def _impl(inputs, input_types):
data = inputs[0] data = inputs[0]
output_size = _infer_shape(inputs[1]) output_size = _infer_shape(inputs[1])
...@@ -178,14 +178,32 @@ def _adaptive_avg_2d(): ...@@ -178,14 +178,32 @@ def _adaptive_avg_2d():
return _impl return _impl
def _adaptive_max_2d(): def _adaptive_max_pool_2d():
def _impl(inputs, input_types): def _impl(inputs, input_types):
data = inputs[0] data = inputs[0]
output_size = _infer_shape(inputs[1]) output_size = _infer_shape(inputs[1])
# returns dummy indices too
return _op.nn.adaptive_max_pool2d( return _op.nn.adaptive_max_pool2d(
data, data,
output_size=output_size) output_size=output_size), None
return _impl
def _adaptive_max_pool_3d():
    """ Build the converter used for aten::adaptive_max_pool3d """
    def _impl(inputs, input_types):
        # inputs[0] is the tensor to pool, inputs[1] carries the output size.
        tensor, size_arg = inputs[0], inputs[1]
        target_size = _infer_shape(size_arg)
        pooled = _op.nn.adaptive_max_pool3d(tensor, output_size=target_size)
        # The second tuple element is a dummy stand-in for the indices output.
        return pooled, None
    return _impl
def _adaptive_avg_pool_3d():
    """ Build the converter used for aten::adaptive_avg_pool3d """
    def _impl(inputs, input_types):
        # inputs[0] is the tensor to pool, inputs[1] carries the output size.
        tensor, size_arg = inputs[0], inputs[1]
        target_size = _infer_shape(size_arg)
        return _op.nn.adaptive_avg_pool3d(tensor, output_size=target_size)
    return _impl
def _maxpool_2d(): def _maxpool_2d():
...@@ -249,21 +267,19 @@ def _convolution(): ...@@ -249,21 +267,19 @@ def _convolution():
if isinstance(dilation, _expr.Expr): if isinstance(dilation, _expr.Expr):
dilation = _infer_shape(dilation) dilation = _infer_shape(dilation)
data_layout = "NCHW"
kernel_layout = "OIHW"
conv_op = _op.nn.conv2d
if use_transpose: if use_transpose:
conv_out = _op.nn.conv2d_transpose(data, assert len(kernel_size) == 2, "ConvTranspose 3D not supported"
weight, conv_op = _op.nn.conv2d_transpose
strides=strides, if len(kernel_size) == 3:
padding=padding, conv_op = _op.nn.conv3d
dilation=dilation, data_layout = "NCDHW"
groups=groups, kernel_layout = "OIDHW"
channels=channels,
kernel_size=kernel_size, conv_out = conv_op(data,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="",
out_dtype="")
else:
conv_out = _op.nn.conv2d(data,
weight, weight,
strides=strides, strides=strides,
padding=padding, padding=padding,
...@@ -271,11 +287,10 @@ def _convolution(): ...@@ -271,11 +287,10 @@ def _convolution():
groups=groups, groups=groups,
channels=channels, channels=channels,
kernel_size=kernel_size, kernel_size=kernel_size,
data_layout="NCHW", data_layout=data_layout,
kernel_layout="OIHW", kernel_layout=kernel_layout,
out_layout="", out_layout="",
out_dtype="") out_dtype="")
if use_bias: if use_bias:
return _op.nn.bias_add(conv_out, bias) return _op.nn.bias_add(conv_out, bias)
else: else:
...@@ -844,8 +859,8 @@ _convert_map = { ...@@ -844,8 +859,8 @@ _convert_map = {
"aten::select" : _select(), "aten::select" : _select(),
"aten::relu" : _relu(), "aten::relu" : _relu(),
"aten::relu_" : _relu(), "aten::relu_" : _relu(),
"aten::adaptive_avg_pool2d" : _adaptive_avg_2d(), "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(),
"aten::adaptive_max_pool2d" : _adaptive_max_2d(), "aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(),
"aten::max_pool2d" : _maxpool_2d(), "aten::max_pool2d" : _maxpool_2d(),
"aten::max_pool2d_with_indices" : _maxpool_2d(), "aten::max_pool2d_with_indices" : _maxpool_2d(),
"aten::hardtanh" : _hardtanh(), "aten::hardtanh" : _hardtanh(),
...@@ -895,6 +910,8 @@ _convert_map = { ...@@ -895,6 +910,8 @@ _convert_map = {
"aten::Float" : _Float(), "aten::Float" : _Float(),
"aten::neg" : _neg(), "aten::neg" : _neg(),
"aten::tanh" : _tanh(), "aten::tanh" : _tanh(),
"aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(),
"aten::adaptive_max_pool3d" : _adaptive_max_pool_3d()
} }
...@@ -955,6 +972,7 @@ def _report_missing_conversion(op_names): ...@@ -955,6 +972,7 @@ def _report_missing_conversion(op_names):
msg = "The following operators are not implemented: {}".format(missing) msg = "The following operators are not implemented: {}".format(missing)
raise NotImplementedError(msg) raise NotImplementedError(msg)
def _check_input_names(script_module, input_shapes): def _check_input_names(script_module, input_shapes):
""" Check the graph inputs match the inputs """ """ Check the graph inputs match the inputs """
ir_inputs = get_graph_input_names(script_module) ir_inputs = get_graph_input_names(script_module)
...@@ -1272,9 +1290,18 @@ def convert_operators(operators, outputs, output_index_map, ret_names): ...@@ -1272,9 +1290,18 @@ def convert_operators(operators, outputs, output_index_map, ret_names):
_update_outputs_from_pairs(zip(unpacked_names, loop_out), _update_outputs_from_pairs(zip(unpacked_names, loop_out),
outputs, output_index_map) outputs, output_index_map)
else: else:
output_index_map[node_name] = len(outputs)
relay_op = _convert_map[operator] relay_op = _convert_map[operator]
outputs.append(relay_op(inputs, _get_input_types(op_node))) relay_out = relay_op(inputs, _get_input_types(op_node))
if isinstance(relay_out, tuple):
# This is for torch operators that return multiple outputs
# See _adaptive_max_pool_2d above for example
out_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(out_names, relay_out),
outputs, output_index_map)
else:
output_index_map[node_name] = len(outputs)
outputs.append(relay_out)
return [_wrap_const(outputs[output_index_map[ret_name]]) return [_wrap_const(outputs[output_index_map[ret_name]])
for ret_name in ret_names] for ret_name in ret_names]
......
...@@ -358,8 +358,9 @@ def test_quantized_imagenet(): ...@@ -358,8 +358,9 @@ def test_quantized_imagenet():
qmodels += [ qmodels += [
("resnet18", qresnet.resnet18(pretrained=True), per_channel), ("resnet18", qresnet.resnet18(pretrained=True), per_channel),
("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel), ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
# disable inception test for now, since loading it takes ~5min on torchvision-0.5 # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug
#("inception_v3", qinception.inception_v3(pretrained=True), per_channel), # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756
# ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
("googlenet", qgooglenet(pretrained=True), per_channel), ("googlenet", qgooglenet(pretrained=True), per_channel),
] ]
......
...@@ -452,27 +452,20 @@ def test_forward_contiguous(): ...@@ -452,27 +452,20 @@ def test_forward_contiguous():
input_data = torch.rand(input_shape).float() input_data = torch.rand(input_shape).float()
verify_model(Contiguous1().float().eval(), input_data=input_data) verify_model(Contiguous1().float().eval(), input_data=input_data)
def test_forward_batchnorm(): def test_forward_batchnorm():
torch.set_grad_enabled(False) def init_weight(m):
input_shape = [1, 3, 10, 10] torch.nn.init.normal_(m.weight, 0, 0.01)
torch.nn.init.normal_(m.bias)
class BatchNorm1(Module): inp_2d = torch.rand((1, 16, 10, 10))
def __init__(self): inp_3d = torch.rand((1, 16, 10, 10, 10))
super(BatchNorm1, self).__init__()
self.batch_norm = torch.nn.BatchNorm2d(3, affine=True)
def forward(self, *args):
return self.batch_norm(args[0])
class BatchNorm2(Module): for bn, inp in [(torch.nn.BatchNorm2d(16), inp_2d),
def __init__(self): (torch.nn.BatchNorm3d(16), inp_3d)]:
super(BatchNorm2, self).__init__() init_weight(bn.eval())
self.batch_norm = torch.nn.BatchNorm2d(3, affine=False) verify_model(bn.eval(), input_data=inp)
def forward(self, *args):
return self.batch_norm(args[0])
input_data = torch.rand(input_shape).float()
verify_model(BatchNorm1().float().eval(), input_data=input_data)
verify_model(BatchNorm2().float().eval(), input_data=input_data)
def test_forward_transpose(): def test_forward_transpose():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -708,6 +701,37 @@ def test_to(): ...@@ -708,6 +701,37 @@ def test_to():
verify_model(ToInt().eval(), torch.tensor(2.0)) verify_model(ToInt().eval(), torch.tensor(2.0))
def test_adaptive_pool3d():
    """ Check adaptive max/avg 3D pooling conversion on several input shapes """
    # (module class, output size) pairs, verified in this order for each input.
    pool_configs = [
        (torch.nn.AdaptiveMaxPool3d, (1, 1, 1)),
        (torch.nn.AdaptiveMaxPool3d, (2, 2, 2)),
        (torch.nn.AdaptiveAvgPool3d, (1, 1, 1)),
        (torch.nn.AdaptiveAvgPool3d, (2, 2, 2)),
        (torch.nn.AdaptiveAvgPool3d, (4, 8, 8)),
        (torch.nn.AdaptiveMaxPool3d, (7, 8, 9)),
    ]
    input_shapes = [
        (1, 32, 16, 16, 16),
        (1, 32, 9, 15, 15),
        (1, 32, 13, 7, 7),
    ]

    for ishape in input_shapes:
        inp = torch.rand(ishape)
        for pool_cls, out_size in pool_configs:
            verify_model(pool_cls(out_size).eval(), inp)
def test_conv3d():
    """ Check Conv3d conversion for several kernel sizes, paddings and strides """
    for ishape in [(1, 32, 16, 16, 16),
                   (1, 32, 9, 15, 15),
                   (1, 32, 13, 7, 7)]:
        inp = torch.rand(ishape)
        # NOTE: the original calls ended with a stray trailing comma, which
        # turned each statement into a discarded one-element tuple.
        verify_model(torch.nn.Conv3d(32, 16, (3, 3, 3),
                                     padding=(1, 1, 1)).eval(),
                     inp)
        verify_model(torch.nn.Conv3d(32, 16, (5, 5, 5),
                                     padding=(2, 2, 2)).eval(),
                     inp)
        verify_model(torch.nn.Conv3d(32, 16, kernel_size=1).eval(),
                     inp)
        # downsample
        verify_model(torch.nn.Conv3d(32, 16, kernel_size=1, stride=2).eval(),
                     inp)
# Model tests # Model tests
def test_resnet18(): def test_resnet18():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -809,6 +833,12 @@ def test_segmentaton_models(): ...@@ -809,6 +833,12 @@ def test_segmentaton_models():
verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx]) verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx])
def test_3d_models():
    """ Verify conversion of torchvision's pretrained r3d_18 video model """
    model = torchvision.models.video.r3d_18(pretrained=True)
    model = model.eval()
    # The model is built first and the random input drawn afterwards.
    sample = torch.rand((1, 3, 4, 56, 56))
    verify_model(model, [sample])
def verify_script_model(pt_model, ishapes): def verify_script_model(pt_model, ishapes):
script_module = torch.jit.script(pt_model) script_module = torch.jit.script(pt_model)
input_names = get_graph_input_names(script_module) input_names = get_graph_input_names(script_module)
...@@ -1021,13 +1051,17 @@ if __name__ == "__main__": ...@@ -1021,13 +1051,17 @@ if __name__ == "__main__":
test_forward_chunk() test_forward_chunk()
test_upsample() test_upsample()
test_to() test_to()
test_adaptive_pool3d()
test_conv3d()
# Model tests # Model tests
test_resnet18() test_resnet18()
test_squeezenet1_0() test_squeezenet1_0()
test_squeezenet1_1() test_squeezenet1_1()
test_densenet121() test_densenet121()
test_inception_v3() # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug
# See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756
# test_inception_v3()
test_googlenet() test_googlenet()
test_mnasnet0_5() test_mnasnet0_5()
test_mobilenet_v2() test_mobilenet_v2()
...@@ -1035,6 +1069,7 @@ if __name__ == "__main__": ...@@ -1035,6 +1069,7 @@ if __name__ == "__main__":
test_custom_conversion_map() test_custom_conversion_map()
test_segmentaton_models() test_segmentaton_models()
test_3d_models()
# Quantization test # Quantization test
from qnn_test import test_quantized_imagenet, test_quantized_modules from qnn_test import test_quantized_imagenet, test_quantized_modules
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment