Unverified Commit 92a24278 by masahi Committed by GitHub

[Torch] Upsampling op support and enable registering a user defined op conversion map (#4961)

* add custom conversion map

* add roi align test using custom convert map

* refactor test

* add support for upsampling op and test on segmentation models

* remove redundant no_grad

* add upsampling test case

* make the default custom map None, instead of empty dict

* updated tests, remove packaging and drop PT 1.2 support

* add better support for aten::to and tests

* add a note on dilation in x86
parent 474c70d7
......@@ -19,7 +19,6 @@
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
"""PT: PyTorch frontend."""
import itertools
from packaging import version
import numpy as np
......@@ -31,6 +30,7 @@ from .. import expr as _expr
from .. import op as _op
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
__all__ = ["from_pytorch"]
......@@ -614,6 +614,61 @@ def _sqrt():
return _op.tensor.sqrt(data)
return _impl
def _floor():
def _impl(inputs, input_types):
data = inputs[0]
return _op.floor(data)
return _impl
def _to():
def _impl(inputs, input_types):
data = inputs[0]
if inputs[3] in ["cpu", "cuda"]:
return data
# special handling for aten::to(data, 6, _, _, _) case
# 6 means dtype = float
# this happens when converting upsampling with scale factor
cast_func = {
6: float,
3: int,
}
cast_func_expr = {
6: lambda x: _op.cast(x, "float32"),
3: lambda x: _op.cast(x, "int32"),
}
if inputs[1] in cast_func and not isinstance(data, _expr.Expr):
return cast_func[inputs[1]](data)
elif inputs[1] in cast_func and isinstance(data, _expr.Expr):
return cast_func_expr[inputs[1]](data)
return data
return _impl
def _upsample(method):
def _impl(inputs, input_types):
if isinstance(inputs[1], _expr.Var):
out_size = _infer_shape(inputs[1])
elif isinstance(inputs[1], list):
infer_res = [_infer_value(size, {}) for size in inputs[1]]
out_size = [np.asscalar(res.asnumpy().astype(np.int))
for res in infer_res]
data = inputs[0]
if len(inputs) > 2:
align_corners = inputs[2]
else:
align_corners = False
if align_corners:
coord_trans = "align_corners"
else:
coord_trans = "half_pixel"
return _op.image.resize(data, out_size, "NCHW", method, coord_trans)
return _impl
# Helper functions for operator implementation
def _convert_data_type(input_type):
......@@ -686,7 +741,7 @@ _convert_map = {
"aten::div_" : _elemwise("divide"),
"aten::ones" : _ones(),
"aten::zeros" : _zeros(),
"aten::to" : _identity(),
"aten::to" : _to(),
"aten::unsqueeze" : _unsqueeze(),
"aten::cat" : _concatenate(),
"aten::slice" : _slice(),
......@@ -729,15 +784,18 @@ _convert_map = {
"aten::permute" : _transpose(),
"aten::sum" : _reduce("sum"),
"aten::prod" : _reduce("prod"),
"aten::sqrt" : _sqrt()
"aten::sqrt" : _sqrt(),
'aten::floor' : _floor(),
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
}
def _run_jit_passes(graph):
""" The inline pass is necessary to unwrap prim::CallMethod """
import torch
if version.parse(torch.__version__) >= version.parse("1.4.0"):
torch._C._jit_pass_inline(graph)
torch._C._jit_pass_inline(graph)
def _is_int_seq(seq):
......@@ -985,8 +1043,7 @@ def parse_operators(operators, outputs, output_index_map, ret_name):
def get_all_op_names(graph):
""" Return all operator names in the input graph """
nodes = list(graph.nodes())
return set(node.kind() for node in nodes)
return set(node.kind() for node in graph.nodes())
def get_graph_input_names(script_module):
......@@ -997,7 +1054,7 @@ def get_graph_input_names(script_module):
return ir_inputs[1:] # remove self at the 0th arg
def from_pytorch(script_module, input_shapes):
def from_pytorch(script_module, input_shapes, custom_convert_map=None):
""" Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.
......@@ -1011,6 +1068,9 @@ def from_pytorch(script_module, input_shapes):
Graph level input shape dictionary
The keys should be the same one returned by get_graph_input_names(...) above
custom_convert_map: Dictionary of str to Relay op
A custom op conversion map in the same format as _convert_map above
Returns
-------
mod : tvm.relay.Module
......@@ -1021,6 +1081,10 @@ def from_pytorch(script_module, input_shapes):
"""
graph = script_module.graph.copy()
_run_jit_passes(graph)
if custom_convert_map:
_convert_map.update(custom_convert_map)
op_names = get_all_op_names(graph)
_report_missing_conversion(op_names)
......
......@@ -17,15 +17,12 @@
# pylint: disable=import-self, invalid-name, unused-argument
"""Unit tests for various models and operators"""
from time import time
import os
import sys
from tempfile import TemporaryDirectory
from scipy.stats import t as tdistr
import numpy as np
import torch
from torch.nn import Module
import tvm
from tvm import te
import torchvision
from tvm import relay
......@@ -36,22 +33,6 @@ from tvm.relay.frontend.pytorch import get_graph_input_names
sys.setrecursionlimit(10000)
def _vectorize(ten):
return ten.reshape(-1)
def atol(tru, est):
def _atol_elt(tru, est):
return abs(tru - est)
tru = _vectorize(tru)
est = _vectorize(est)
return max([_atol_elt(x, y) for x, y in zip(tru, est)])
def rtol(tru, est):
def _rtol_elt(tru, est):
return abs(tru - est) / min(abs(tru), abs(est))
tru = _vectorize(tru)
est = _vectorize(est)
return max([_rtol_elt(x, y) for x, y in zip(tru, est)])
def assert_shapes_match(tru, est):
if tru.shape != est.shape:
......@@ -77,7 +58,7 @@ def load_torchvision(model_name):
input_data[:, channel] /= std[channel]
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.float().eval()
return model, input_data
return model, [input_data]
def load_pretrainedmodels(model_name):
"""Given a model name, returns a pretrainedmodels.pytorch model in eval
......@@ -89,7 +70,7 @@ def load_pretrainedmodels(model_name):
for channel in range(3):
input_data[:, channel] -= model.mean[channel]
input_data[:, channel] /= model.std[channel]
return model, input_data
return model, [input_data]
def load_model(model_name):
"""Given a model name, returns a model as well as an example input."""
......@@ -116,7 +97,7 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
latencies = []
count = 0
while True:
if isinstance(model, torch.nn.Module):
if isinstance(model, Module):
input_data = [torch.rand(shape).float() for shape in input_shapes]
if torch.cuda.is_available():
input_data = list(map(lambda x: x.cuda(), input_data))
......@@ -153,23 +134,34 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
if err < thresh:
return est
def verify_model(model_name, input_data=[]):
def verify_model(model_name, input_data=[],
custom_convert_map={},
ctx_list=ctx_list()):
"""Assert that the output of a compiled model matches with that of its
baseline."""
if len(input_data) == 0:
if isinstance(model_name, str):
baseline_model, baseline_input = load_model(model_name)
else:
elif isinstance(input_data, list):
baseline_model = model_name
baseline_input = input_data
elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0:
baseline_model = model_name
baseline_input = [input_data]
else:
assert False, "Unexpected input format"
if torch.cuda.is_available():
baseline_model = baseline_model.cuda()
baseline_input = baseline_input.cuda()
baseline_input = [inp.cuda() for inp in baseline_input]
with torch.no_grad():
baseline_outputs = baseline_model(baseline_input)
baseline_outputs = baseline_model(*baseline_input)
if isinstance(baseline_outputs, tuple):
baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
else:
baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
trace = torch.jit.trace(baseline_model, baseline_input).float().eval()
if torch.cuda.is_available():
......@@ -177,17 +169,21 @@ def verify_model(model_name, input_data=[]):
else:
trace = trace.cpu()
input_name = get_graph_input_names(trace)[0] # only one input
input_shapes = {input_name: list(baseline_input.shape)}
mod, params = relay.frontend.from_pytorch(trace, input_shapes)
compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())}
input_names = get_graph_input_names(trace)
input_shapes = dict(zip(input_names,
[inp.shape for inp in baseline_input]))
mod, params = relay.frontend.from_pytorch(trace, input_shapes,
custom_convert_map)
compiled_input = dict(zip(input_names,
[inp.cpu().numpy() for inp in baseline_input]))
with relay.build_config(opt_level=3):
for target, ctx in ctx_list():
for target, ctx in ctx_list:
relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params)
relay_model = graph_runtime.create(relay_graph, relay_lib, ctx)
relay_model.set_input(**relay_params)
relay_model.set_input(**compiled_input)
for name, inp in compiled_input.items():
relay_model.set_input(name, inp)
relay_model.run()
for i, baseline_output in enumerate(baseline_outputs):
......@@ -228,12 +224,11 @@ def test_forward_add():
ones = ones.cuda()
return args[0] + ones
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Add1().float().eval(), input_data=input_data)
verify_model(Add2().float().eval(), input_data=input_data)
verify_model(Add3().float().eval(), input_data=input_data)
verify_model(Add4().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Add1().float().eval(), input_data=input_data)
verify_model(Add2().float().eval(), input_data=input_data)
verify_model(Add3().float().eval(), input_data=input_data)
verify_model(Add4().float().eval(), input_data=input_data)
def test_forward_subtract():
torch.set_grad_enabled(False)
......@@ -261,12 +256,11 @@ def test_forward_subtract():
ones = ones.cuda()
return args[0] - ones
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Subtract1().float().eval(), input_data=input_data)
verify_model(Subtract2().float().eval(), input_data=input_data)
verify_model(Subtract3().float().eval(), input_data=input_data)
verify_model(Subtract4().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Subtract1().float().eval(), input_data=input_data)
verify_model(Subtract2().float().eval(), input_data=input_data)
verify_model(Subtract3().float().eval(), input_data=input_data)
verify_model(Subtract4().float().eval(), input_data=input_data)
def test_forward_multiply():
torch.set_grad_enabled(False)
......@@ -294,12 +288,11 @@ def test_forward_multiply():
ones = ones.cuda()
return args[0] * ones
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Multiply1().float().eval(), input_data=input_data)
verify_model(Multiply2().float().eval(), input_data=input_data)
verify_model(Multiply3().float().eval(), input_data=input_data)
verify_model(Multiply4().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Multiply1().float().eval(), input_data=input_data)
verify_model(Multiply2().float().eval(), input_data=input_data)
verify_model(Multiply3().float().eval(), input_data=input_data)
verify_model(Multiply4().float().eval(), input_data=input_data)
def test_forward_unsqueeze():
torch.set_grad_enabled(False)
......@@ -327,10 +320,9 @@ def test_forward_concatenate():
c = (args[0][:, :, 2] + 5) * 13
return torch.cat([t.unsqueeze(2) for t in [a, b, c]], 2)
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Concatenate1().float().eval(), input_data=input_data)
verify_model(Concatenate2().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Concatenate1().float().eval(), input_data=input_data)
verify_model(Concatenate2().float().eval(), input_data=input_data)
def test_forward_relu():
torch.set_grad_enabled(False)
......@@ -340,9 +332,8 @@ def test_forward_relu():
def forward(self, *args):
return torch.nn.ReLU()(args[0])
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(ReLU1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(ReLU1().float().eval(), input_data=input_data)
def test_forward_adaptiveavgpool():
torch.set_grad_enabled(False)
......@@ -356,10 +347,9 @@ def test_forward_adaptiveavgpool():
def forward(self, *args):
return torch.nn.AdaptiveAvgPool2d([10, 10])(args[0])
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(AdaptiveAvgPool2D1().float().eval(), input_data=input_data)
verify_model(AdaptiveAvgPool2D2().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(AdaptiveAvgPool2D1().float().eval(), input_data=input_data)
verify_model(AdaptiveAvgPool2D2().float().eval(), input_data=input_data)
def test_forward_maxpool():
torch.set_grad_enabled(False)
......@@ -373,10 +363,9 @@ def test_forward_maxpool():
def forward(self, *args):
return torch.nn.MaxPool2d(kernel_size=[10, 10])(args[0])
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(MaxPool2D1().float().eval(), input_data=input_data)
verify_model(MaxPool2D2().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(MaxPool2D1().float().eval(), input_data=input_data)
verify_model(MaxPool2D2().float().eval(), input_data=input_data)
def test_forward_avgpool():
torch.set_grad_enabled(False)
......@@ -386,9 +375,8 @@ def test_forward_avgpool():
def forward(self, *args):
return torch.nn.AvgPool2d(kernel_size=[10, 10])(args[0])
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(AvgPool2D1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(AvgPool2D1().float().eval(), input_data=input_data)
def test_forward_hardtanh():
torch.set_grad_enabled(False)
......@@ -398,9 +386,8 @@ def test_forward_hardtanh():
def forward(self, *args):
return torch.nn.Hardtanh()(args[0])
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(HardTanh1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(HardTanh1().float().eval(), input_data=input_data)
def test_forward_conv():
torch.set_grad_enabled(False)
......@@ -433,11 +420,10 @@ def test_forward_conv():
def forward(self, *args):
return self.softmax(self.conv(args[0]))
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Conv2D1().float().eval(), input_data=input_data)
verify_model(Conv2D2().float().eval(), input_data=input_data)
verify_model(Conv2D3().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Conv2D1().float().eval(), input_data=input_data)
verify_model(Conv2D2().float().eval(), input_data=input_data)
verify_model(Conv2D3().float().eval(), input_data=input_data)
def test_forward_threshold():
torch.set_grad_enabled(False)
......@@ -447,9 +433,8 @@ def test_forward_threshold():
def forward(self, *args):
return torch.nn.Threshold(0, 0)(args[0])
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Threshold1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Threshold1().float().eval(), input_data=input_data)
def test_forward_contiguous():
torch.set_grad_enabled(False)
......@@ -459,9 +444,8 @@ def test_forward_contiguous():
def forward(self, *args):
return args[0].contiguous()
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Contiguous1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Contiguous1().float().eval(), input_data=input_data)
def test_forward_batchnorm():
torch.set_grad_enabled(False)
......@@ -481,10 +465,9 @@ def test_forward_batchnorm():
def forward(self, *args):
return self.batch_norm(args[0])
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(BatchNorm1().float().eval(), input_data=input_data)
verify_model(BatchNorm2().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(BatchNorm1().float().eval(), input_data=input_data)
verify_model(BatchNorm2().float().eval(), input_data=input_data)
def test_forward_transpose():
torch.set_grad_enabled(False)
......@@ -498,10 +481,9 @@ def test_forward_transpose():
def forward(self, *args):
return args[0].transpose(-2, -1)
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Transpose1().float().eval(), input_data=input_data)
verify_model(Transpose2().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Transpose1().float().eval(), input_data=input_data)
verify_model(Transpose2().float().eval(), input_data=input_data)
def test_forward_size():
torch.set_grad_enabled(False)
......@@ -511,9 +493,8 @@ def test_forward_size():
def forward(self, *args):
return float(args[0].size(0)) * args[0]
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Size1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Size1().float().eval(), input_data=input_data)
def test_forward_view():
torch.set_grad_enabled(False)
......@@ -527,10 +508,9 @@ def test_forward_view():
def forward(self, *args):
return args[0].view(args[0].shape[0], -1)
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(View1().float().eval(), input_data=input_data)
verify_model(View2().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(View1().float().eval(), input_data=input_data)
verify_model(View2().float().eval(), input_data=input_data)
def test_forward_select():
torch.set_grad_enabled(False)
......@@ -540,9 +520,8 @@ def test_forward_select():
def forward(self, *args):
return args[0].select(1, 1)
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Select1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Select1().float().eval(), input_data=input_data)
def test_forward_clone():
torch.set_grad_enabled(False)
......@@ -552,9 +531,8 @@ def test_forward_clone():
def forward(self, *args):
return args[0].clone()
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Clone1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Clone1().float().eval(), input_data=input_data)
def test_forward_logsoftmax():
torch.set_grad_enabled(False)
......@@ -564,9 +542,8 @@ def test_forward_logsoftmax():
def forward(self, *args):
return torch.nn.LogSoftmax(dim=1)(args[0][0, 0])
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(LogSoftmax1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(LogSoftmax1().float().eval(), input_data=input_data)
def test_forward_sigmoid():
torch.set_grad_enabled(False)
......@@ -576,9 +553,8 @@ def test_forward_sigmoid():
def forward(self, *args):
return torch.nn.Sigmoid()(args[0])
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Sigmoid1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Sigmoid1().float().eval(), input_data=input_data)
def test_forward_dense():
torch.set_grad_enabled(False)
......@@ -598,10 +574,9 @@ def test_forward_dense():
def forward(self, *args):
return self.linear(args[0][0, 0])
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Dense1().float().eval(), input_data=input_data)
verify_model(Dense2().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Dense1().float().eval(), input_data=input_data)
verify_model(Dense2().float().eval(), input_data=input_data)
def test_forward_dropout():
torch.set_grad_enabled(False)
......@@ -611,9 +586,8 @@ def test_forward_dropout():
def forward(self, *args):
return torch.nn.functional.dropout(args[0][0, 0], 0.5, False)
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Dropout1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Dropout1().float().eval(), input_data=input_data)
def test_forward_slice():
torch.set_grad_enabled(False)
......@@ -627,10 +601,9 @@ def test_forward_slice():
def forward(self, *args):
return args[0][0, :, :, :]
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Slice1().float().eval(), input_data=input_data)
verify_model(Slice2().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Slice1().float().eval(), input_data=input_data)
verify_model(Slice2().float().eval(), input_data=input_data)
def test_forward_mean():
torch.set_grad_enabled(False)
......@@ -640,9 +613,8 @@ def test_forward_mean():
def forward(self, *args):
return args[0].mean(2)
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Mean1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Mean1().float().eval(), input_data=input_data)
def test_forward_expand():
torch.set_grad_enabled(False)
......@@ -652,9 +624,8 @@ def test_forward_expand():
def forward(self, *args):
return args[0].expand((3, -1, -1, -1))
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Expand1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Expand1().float().eval(), input_data=input_data)
def test_forward_pow():
torch.set_grad_enabled(False)
......@@ -664,9 +635,8 @@ def test_forward_pow():
def forward(self, *args):
return args[0] ** 2
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Pow1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Pow1().float().eval(), input_data=input_data)
def test_forward_chunk():
torch.set_grad_enabled(False)
......@@ -677,9 +647,61 @@ def test_forward_chunk():
chunks = args[0].chunk(7, 2)
return torch.cat(chunks, 2)
with torch.no_grad():
input_data = torch.rand(input_shape).float()
verify_model(Chunk1().float().eval(), input_data=input_data)
input_data = torch.rand(input_shape).float()
verify_model(Chunk1().float().eval(), input_data=input_data)
def test_upsample():
class Upsample(Module):
def __init__(self, size=None, scale=None,
mode="nearest", align_corners=None):
super().__init__()
self.size = size
self.scale = scale
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
return torch.nn.functional.interpolate(x, size=self.size,
scale_factor=self.scale,
mode=self.mode,
align_corners=self.align_corners)
inp = torch.rand((1, 3, 32, 32))
verify_model(Upsample(size=(64, 64), mode="nearest"), inp)
verify_model(Upsample(scale=2, mode="nearest"), inp)
verify_model(Upsample(size=(50, 50), mode="nearest"), inp)
verify_model(Upsample(size=(64, 64), mode="bilinear", align_corners=True), inp)
verify_model(Upsample(scale=2, mode="bilinear", align_corners=True), inp)
verify_model(Upsample(size=(50, 50), mode="bilinear", align_corners=True), inp)
def test_to():
""" test for aten::to(...) """
class ToCPU(Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.to("cpu")
class ToFloat(Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.float()
class ToInt(Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.int()
verify_model(ToCPU().eval(), torch.rand((1, 3, 32, 32)))
verify_model(ToFloat().eval(), torch.zeros((1, 3, 32, 32), dtype=torch.int))
verify_model(ToFloat().eval(), torch.tensor(2, dtype=torch.int))
verify_model(ToInt().eval(), torch.zeros((1, 3, 32, 32)))
verify_model(ToInt().eval(), torch.tensor(2.0))
# Model tests
def test_resnet18():
......@@ -730,6 +752,57 @@ def test_vgg11_bn():
"""
def test_custom_conversion_map():
def get_roi_align():
pool_size = 5
n_channels = 2 * (pool_size ** 2)
x = torch.rand(2, n_channels, 10, 10)
rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]], dtype=torch.float)
roi_align = torchvision.ops.RoIAlign(pool_size, spatial_scale=1,
sampling_ratio=-1)
return roi_align.eval(), [x, rois]
def convert_roi_align():
def _impl(inputs, input_types):
spatial_scale = inputs[2]
pooled_size = (inputs[3], inputs[4])
sampling_ratio = inputs[5]
return relay.op.vision.roi_align(inputs[0], inputs[1],
pooled_size, spatial_scale,
sampling_ratio)
return _impl
custom_map = {'torchvision::roi_align': convert_roi_align()}
model, inputs = get_roi_align()
verify_model(model, inputs, custom_map)
def test_segmentaton_models():
class SegmentationModelWrapper(Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, inp):
out = self.model(inp)
return out["out"]
fcn = torchvision.models.segmentation.fcn_resnet101(pretrained=True)
deeplab = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)]
for model in [fcn, deeplab]:
# depthwise + dilated covolution not supported on x86
# see https://github.com/apache/incubator-tvm/issues/4962
verify_model(SegmentationModelWrapper(model.eval()), inp,
ctx_list=[("cuda", tvm.gpu(0))])
if __name__ == "__main__":
# Single operator tests
test_forward_add()
......@@ -760,6 +833,8 @@ if __name__ == "__main__":
test_forward_expand()
test_forward_pow()
test_forward_chunk()
test_upsample()
test_to()
# Model tests
test_resnet18()
......@@ -770,3 +845,7 @@ if __name__ == "__main__":
test_googlenet()
test_mnasnet0_5()
test_mobilenet_v2()
test_custom_conversion_map()
test_segmentaton_models()
......@@ -37,7 +37,7 @@ https://pytorch.org/get-started/locally/
PyTorch versions should be backwards compatible but should be used
with the proper TorchVision version.
Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may
Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
be unstable.
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment