Unverified Commit 92a24278 by masahi Committed by GitHub

[Torch] Upsampling op support and enable registering a user defined op conversion map (#4961)

* add custom conversion map

* add roi align test using custom convert map

* refactor test

* add support for upsampling op and test on segmentation models

* remove redundant no_grad

* add upsampling test case

* make the default custom map None, instead of empty dict

* updated tests, remove packaging and drop PT 1.2 support

* add better support for aten::to and tests

* add a note on dilation in x86
parent 474c70d7
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
"""PT: PyTorch frontend.""" """PT: PyTorch frontend."""
import itertools import itertools
from packaging import version
import numpy as np import numpy as np
...@@ -31,6 +30,7 @@ from .. import expr as _expr ...@@ -31,6 +30,7 @@ from .. import expr as _expr
from .. import op as _op from .. import op as _op
from .common import get_relay_op from .common import get_relay_op
from .common import infer_shape as _infer_shape from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
__all__ = ["from_pytorch"] __all__ = ["from_pytorch"]
...@@ -614,6 +614,61 @@ def _sqrt(): ...@@ -614,6 +614,61 @@ def _sqrt():
return _op.tensor.sqrt(data) return _op.tensor.sqrt(data)
return _impl return _impl
def _floor():
    """Return a converter for aten::floor (elementwise floor)."""
    def _impl(inputs, input_types):
        return _op.floor(inputs[0])
    return _impl
def _to():
    """Return a converter for aten::to (device moves and dtype casts)."""
    def _impl(inputs, input_types):
        data = inputs[0]
        # A device move (aten::to(data, "cpu"/"cuda", ...)) is a no-op
        # from the conversion point of view.
        if inputs[3] in ["cpu", "cuda"]:
            return data
        # Otherwise this is aten::to(data, dtype, _, _, _).  The dtype is
        # a PyTorch scalar-type code: 6 == float32, 3 == int32.  This
        # shows up e.g. when converting upsampling with a scale factor.
        dtype_code = inputs[1]
        python_casts = {
            6: float,
            3: int,
        }
        relay_casts = {
            6: lambda e: _op.cast(e, "float32"),
            3: lambda e: _op.cast(e, "int32"),
        }
        if dtype_code in python_casts:
            # Relay expressions get a relay cast; plain Python scalars get
            # a Python-level conversion.
            if isinstance(data, _expr.Expr):
                return relay_casts[dtype_code](data)
            return python_casts[dtype_code](data)
        return data
    return _impl
def _upsample(method):
    """Return a converter for aten::upsample_bilinear2d / upsample_nearest2d.

    Parameters
    ----------
    method : str
        Resize method passed to relay ("bilinear" or "nearest_neighbor").
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        if isinstance(inputs[1], _expr.Var):
            out_size = _infer_shape(inputs[1])
        elif isinstance(inputs[1], list):
            # Output sizes arrive as (possibly symbolic) expressions;
            # evaluate each one to a concrete Python int.  Avoid the
            # deprecated np.asscalar/np.int (removed in modern NumPy).
            infer_res = [_infer_value(size, {}) for size in inputs[1]]
            out_size = [int(res.asnumpy()) for res in infer_res]
        else:
            # Previously this path fell through and raised an obscure
            # NameError on out_size; fail with a clear message instead.
            raise NotImplementedError(
                "Unsupported upsample output size input: %s" % type(inputs[1]))

        # Third argument, when present, is align_corners.
        align_corners = inputs[2] if len(inputs) > 2 else False
        coord_trans = "align_corners" if align_corners else "half_pixel"

        return _op.image.resize(data, out_size, "NCHW", method, coord_trans)
    return _impl
# Helper functions for operator implementation # Helper functions for operator implementation
def _convert_data_type(input_type): def _convert_data_type(input_type):
...@@ -686,7 +741,7 @@ _convert_map = { ...@@ -686,7 +741,7 @@ _convert_map = {
"aten::div_" : _elemwise("divide"), "aten::div_" : _elemwise("divide"),
"aten::ones" : _ones(), "aten::ones" : _ones(),
"aten::zeros" : _zeros(), "aten::zeros" : _zeros(),
"aten::to" : _identity(), "aten::to" : _to(),
"aten::unsqueeze" : _unsqueeze(), "aten::unsqueeze" : _unsqueeze(),
"aten::cat" : _concatenate(), "aten::cat" : _concatenate(),
"aten::slice" : _slice(), "aten::slice" : _slice(),
...@@ -729,15 +784,18 @@ _convert_map = { ...@@ -729,15 +784,18 @@ _convert_map = {
"aten::permute" : _transpose(), "aten::permute" : _transpose(),
"aten::sum" : _reduce("sum"), "aten::sum" : _reduce("sum"),
"aten::prod" : _reduce("prod"), "aten::prod" : _reduce("prod"),
"aten::sqrt" : _sqrt() "aten::sqrt" : _sqrt(),
'aten::floor' : _floor(),
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
} }
def _run_jit_passes(graph): def _run_jit_passes(graph):
""" The inline pass is necessary to unwrap prim::CallMethod """ """ The inline pass is necessary to unwrap prim::CallMethod """
import torch import torch
if version.parse(torch.__version__) >= version.parse("1.4.0"): torch._C._jit_pass_inline(graph)
torch._C._jit_pass_inline(graph)
def _is_int_seq(seq): def _is_int_seq(seq):
...@@ -985,8 +1043,7 @@ def parse_operators(operators, outputs, output_index_map, ret_name): ...@@ -985,8 +1043,7 @@ def parse_operators(operators, outputs, output_index_map, ret_name):
def get_all_op_names(graph): def get_all_op_names(graph):
""" Return all operator names in the input graph """ """ Return all operator names in the input graph """
nodes = list(graph.nodes()) return set(node.kind() for node in graph.nodes())
return set(node.kind() for node in nodes)
def get_graph_input_names(script_module): def get_graph_input_names(script_module):
...@@ -997,7 +1054,7 @@ def get_graph_input_names(script_module): ...@@ -997,7 +1054,7 @@ def get_graph_input_names(script_module):
return ir_inputs[1:] # remove self at the 0th arg return ir_inputs[1:] # remove self at the 0th arg
def from_pytorch(script_module, input_shapes): def from_pytorch(script_module, input_shapes, custom_convert_map=None):
""" Load PyTorch model in the form of a scripted PyTorch model and convert into relay. """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically. The companion parameters will be handled automatically.
...@@ -1011,6 +1068,9 @@ def from_pytorch(script_module, input_shapes): ...@@ -1011,6 +1068,9 @@ def from_pytorch(script_module, input_shapes):
Graph level input shape dictionary Graph level input shape dictionary
The keys should be the same one returned by get_graph_input_names(...) above The keys should be the same one returned by get_graph_input_names(...) above
custom_convert_map: Dictionary of str to Relay op
A custom op conversion map in the same format as _convert_map above
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.relay.Module
...@@ -1021,6 +1081,10 @@ def from_pytorch(script_module, input_shapes): ...@@ -1021,6 +1081,10 @@ def from_pytorch(script_module, input_shapes):
""" """
graph = script_module.graph.copy() graph = script_module.graph.copy()
_run_jit_passes(graph) _run_jit_passes(graph)
if custom_convert_map:
_convert_map.update(custom_convert_map)
op_names = get_all_op_names(graph) op_names = get_all_op_names(graph)
_report_missing_conversion(op_names) _report_missing_conversion(op_names)
......
...@@ -17,15 +17,12 @@ ...@@ -17,15 +17,12 @@
# pylint: disable=import-self, invalid-name, unused-argument # pylint: disable=import-self, invalid-name, unused-argument
"""Unit tests for various models and operators""" """Unit tests for various models and operators"""
from time import time from time import time
import os
import sys import sys
from tempfile import TemporaryDirectory
from scipy.stats import t as tdistr from scipy.stats import t as tdistr
import numpy as np import numpy as np
import torch import torch
from torch.nn import Module from torch.nn import Module
import tvm import tvm
from tvm import te
import torchvision import torchvision
from tvm import relay from tvm import relay
...@@ -36,22 +33,6 @@ from tvm.relay.frontend.pytorch import get_graph_input_names ...@@ -36,22 +33,6 @@ from tvm.relay.frontend.pytorch import get_graph_input_names
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
def _vectorize(ten):
return ten.reshape(-1)
def atol(tru, est):
def _atol_elt(tru, est):
return abs(tru - est)
tru = _vectorize(tru)
est = _vectorize(est)
return max([_atol_elt(x, y) for x, y in zip(tru, est)])
def rtol(tru, est):
def _rtol_elt(tru, est):
return abs(tru - est) / min(abs(tru), abs(est))
tru = _vectorize(tru)
est = _vectorize(est)
return max([_rtol_elt(x, y) for x, y in zip(tru, est)])
def assert_shapes_match(tru, est): def assert_shapes_match(tru, est):
if tru.shape != est.shape: if tru.shape != est.shape:
...@@ -77,7 +58,7 @@ def load_torchvision(model_name): ...@@ -77,7 +58,7 @@ def load_torchvision(model_name):
input_data[:, channel] /= std[channel] input_data[:, channel] /= std[channel]
model = getattr(torchvision.models, model_name)(pretrained=True) model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.float().eval() model = model.float().eval()
return model, input_data return model, [input_data]
def load_pretrainedmodels(model_name): def load_pretrainedmodels(model_name):
"""Given a model name, returns a pretrainedmodels.pytorch model in eval """Given a model name, returns a pretrainedmodels.pytorch model in eval
...@@ -89,7 +70,7 @@ def load_pretrainedmodels(model_name): ...@@ -89,7 +70,7 @@ def load_pretrainedmodels(model_name):
for channel in range(3): for channel in range(3):
input_data[:, channel] -= model.mean[channel] input_data[:, channel] -= model.mean[channel]
input_data[:, channel] /= model.std[channel] input_data[:, channel] /= model.std[channel]
return model, input_data return model, [input_data]
def load_model(model_name): def load_model(model_name):
"""Given a model name, returns a model as well as an example input.""" """Given a model name, returns a model as well as an example input."""
...@@ -116,7 +97,7 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40): ...@@ -116,7 +97,7 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
latencies = [] latencies = []
count = 0 count = 0
while True: while True:
if isinstance(model, torch.nn.Module): if isinstance(model, Module):
input_data = [torch.rand(shape).float() for shape in input_shapes] input_data = [torch.rand(shape).float() for shape in input_shapes]
if torch.cuda.is_available(): if torch.cuda.is_available():
input_data = list(map(lambda x: x.cuda(), input_data)) input_data = list(map(lambda x: x.cuda(), input_data))
...@@ -153,23 +134,34 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40): ...@@ -153,23 +134,34 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
if err < thresh: if err < thresh:
return est return est
def verify_model(model_name, input_data=[]): def verify_model(model_name, input_data=[],
custom_convert_map={},
ctx_list=ctx_list()):
"""Assert that the output of a compiled model matches with that of its """Assert that the output of a compiled model matches with that of its
baseline.""" baseline."""
if len(input_data) == 0: if isinstance(model_name, str):
baseline_model, baseline_input = load_model(model_name) baseline_model, baseline_input = load_model(model_name)
else: elif isinstance(input_data, list):
baseline_model = model_name baseline_model = model_name
baseline_input = input_data baseline_input = input_data
elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0:
baseline_model = model_name
baseline_input = [input_data]
else:
assert False, "Unexpected input format"
if torch.cuda.is_available(): if torch.cuda.is_available():
baseline_model = baseline_model.cuda() baseline_model = baseline_model.cuda()
baseline_input = baseline_input.cuda() baseline_input = [inp.cuda() for inp in baseline_input]
with torch.no_grad(): with torch.no_grad():
baseline_outputs = baseline_model(baseline_input) baseline_outputs = baseline_model(*baseline_input)
if isinstance(baseline_outputs, tuple): if isinstance(baseline_outputs, tuple):
baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs) baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
else: else:
baseline_outputs = (baseline_outputs.float().cpu().numpy(),) baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
trace = torch.jit.trace(baseline_model, baseline_input).float().eval() trace = torch.jit.trace(baseline_model, baseline_input).float().eval()
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -177,17 +169,21 @@ def verify_model(model_name, input_data=[]): ...@@ -177,17 +169,21 @@ def verify_model(model_name, input_data=[]):
else: else:
trace = trace.cpu() trace = trace.cpu()
input_name = get_graph_input_names(trace)[0] # only one input input_names = get_graph_input_names(trace)
input_shapes = {input_name: list(baseline_input.shape)} input_shapes = dict(zip(input_names,
mod, params = relay.frontend.from_pytorch(trace, input_shapes) [inp.shape for inp in baseline_input]))
compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())} mod, params = relay.frontend.from_pytorch(trace, input_shapes,
custom_convert_map)
compiled_input = dict(zip(input_names,
[inp.cpu().numpy() for inp in baseline_input]))
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
for target, ctx in ctx_list(): for target, ctx in ctx_list:
relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params) relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params)
relay_model = graph_runtime.create(relay_graph, relay_lib, ctx) relay_model = graph_runtime.create(relay_graph, relay_lib, ctx)
relay_model.set_input(**relay_params) relay_model.set_input(**relay_params)
relay_model.set_input(**compiled_input) for name, inp in compiled_input.items():
relay_model.set_input(name, inp)
relay_model.run() relay_model.run()
for i, baseline_output in enumerate(baseline_outputs): for i, baseline_output in enumerate(baseline_outputs):
...@@ -228,12 +224,11 @@ def test_forward_add(): ...@@ -228,12 +224,11 @@ def test_forward_add():
ones = ones.cuda() ones = ones.cuda()
return args[0] + ones return args[0] + ones
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Add1().float().eval(), input_data=input_data)
verify_model(Add1().float().eval(), input_data=input_data) verify_model(Add2().float().eval(), input_data=input_data)
verify_model(Add2().float().eval(), input_data=input_data) verify_model(Add3().float().eval(), input_data=input_data)
verify_model(Add3().float().eval(), input_data=input_data) verify_model(Add4().float().eval(), input_data=input_data)
verify_model(Add4().float().eval(), input_data=input_data)
def test_forward_subtract(): def test_forward_subtract():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -261,12 +256,11 @@ def test_forward_subtract(): ...@@ -261,12 +256,11 @@ def test_forward_subtract():
ones = ones.cuda() ones = ones.cuda()
return args[0] - ones return args[0] - ones
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Subtract1().float().eval(), input_data=input_data)
verify_model(Subtract1().float().eval(), input_data=input_data) verify_model(Subtract2().float().eval(), input_data=input_data)
verify_model(Subtract2().float().eval(), input_data=input_data) verify_model(Subtract3().float().eval(), input_data=input_data)
verify_model(Subtract3().float().eval(), input_data=input_data) verify_model(Subtract4().float().eval(), input_data=input_data)
verify_model(Subtract4().float().eval(), input_data=input_data)
def test_forward_multiply(): def test_forward_multiply():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -294,12 +288,11 @@ def test_forward_multiply(): ...@@ -294,12 +288,11 @@ def test_forward_multiply():
ones = ones.cuda() ones = ones.cuda()
return args[0] * ones return args[0] * ones
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Multiply1().float().eval(), input_data=input_data)
verify_model(Multiply1().float().eval(), input_data=input_data) verify_model(Multiply2().float().eval(), input_data=input_data)
verify_model(Multiply2().float().eval(), input_data=input_data) verify_model(Multiply3().float().eval(), input_data=input_data)
verify_model(Multiply3().float().eval(), input_data=input_data) verify_model(Multiply4().float().eval(), input_data=input_data)
verify_model(Multiply4().float().eval(), input_data=input_data)
def test_forward_unsqueeze(): def test_forward_unsqueeze():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -327,10 +320,9 @@ def test_forward_concatenate(): ...@@ -327,10 +320,9 @@ def test_forward_concatenate():
c = (args[0][:, :, 2] + 5) * 13 c = (args[0][:, :, 2] + 5) * 13
return torch.cat([t.unsqueeze(2) for t in [a, b, c]], 2) return torch.cat([t.unsqueeze(2) for t in [a, b, c]], 2)
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Concatenate1().float().eval(), input_data=input_data)
verify_model(Concatenate1().float().eval(), input_data=input_data) verify_model(Concatenate2().float().eval(), input_data=input_data)
verify_model(Concatenate2().float().eval(), input_data=input_data)
def test_forward_relu(): def test_forward_relu():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -340,9 +332,8 @@ def test_forward_relu(): ...@@ -340,9 +332,8 @@ def test_forward_relu():
def forward(self, *args): def forward(self, *args):
return torch.nn.ReLU()(args[0]) return torch.nn.ReLU()(args[0])
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(ReLU1().float().eval(), input_data=input_data)
verify_model(ReLU1().float().eval(), input_data=input_data)
def test_forward_adaptiveavgpool(): def test_forward_adaptiveavgpool():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -356,10 +347,9 @@ def test_forward_adaptiveavgpool(): ...@@ -356,10 +347,9 @@ def test_forward_adaptiveavgpool():
def forward(self, *args): def forward(self, *args):
return torch.nn.AdaptiveAvgPool2d([10, 10])(args[0]) return torch.nn.AdaptiveAvgPool2d([10, 10])(args[0])
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(AdaptiveAvgPool2D1().float().eval(), input_data=input_data)
verify_model(AdaptiveAvgPool2D1().float().eval(), input_data=input_data) verify_model(AdaptiveAvgPool2D2().float().eval(), input_data=input_data)
verify_model(AdaptiveAvgPool2D2().float().eval(), input_data=input_data)
def test_forward_maxpool(): def test_forward_maxpool():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -373,10 +363,9 @@ def test_forward_maxpool(): ...@@ -373,10 +363,9 @@ def test_forward_maxpool():
def forward(self, *args): def forward(self, *args):
return torch.nn.MaxPool2d(kernel_size=[10, 10])(args[0]) return torch.nn.MaxPool2d(kernel_size=[10, 10])(args[0])
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(MaxPool2D1().float().eval(), input_data=input_data)
verify_model(MaxPool2D1().float().eval(), input_data=input_data) verify_model(MaxPool2D2().float().eval(), input_data=input_data)
verify_model(MaxPool2D2().float().eval(), input_data=input_data)
def test_forward_avgpool(): def test_forward_avgpool():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -386,9 +375,8 @@ def test_forward_avgpool(): ...@@ -386,9 +375,8 @@ def test_forward_avgpool():
def forward(self, *args): def forward(self, *args):
return torch.nn.AvgPool2d(kernel_size=[10, 10])(args[0]) return torch.nn.AvgPool2d(kernel_size=[10, 10])(args[0])
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(AvgPool2D1().float().eval(), input_data=input_data)
verify_model(AvgPool2D1().float().eval(), input_data=input_data)
def test_forward_hardtanh(): def test_forward_hardtanh():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -398,9 +386,8 @@ def test_forward_hardtanh(): ...@@ -398,9 +386,8 @@ def test_forward_hardtanh():
def forward(self, *args): def forward(self, *args):
return torch.nn.Hardtanh()(args[0]) return torch.nn.Hardtanh()(args[0])
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(HardTanh1().float().eval(), input_data=input_data)
verify_model(HardTanh1().float().eval(), input_data=input_data)
def test_forward_conv(): def test_forward_conv():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -433,11 +420,10 @@ def test_forward_conv(): ...@@ -433,11 +420,10 @@ def test_forward_conv():
def forward(self, *args): def forward(self, *args):
return self.softmax(self.conv(args[0])) return self.softmax(self.conv(args[0]))
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Conv2D1().float().eval(), input_data=input_data)
verify_model(Conv2D1().float().eval(), input_data=input_data) verify_model(Conv2D2().float().eval(), input_data=input_data)
verify_model(Conv2D2().float().eval(), input_data=input_data) verify_model(Conv2D3().float().eval(), input_data=input_data)
verify_model(Conv2D3().float().eval(), input_data=input_data)
def test_forward_threshold(): def test_forward_threshold():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -447,9 +433,8 @@ def test_forward_threshold(): ...@@ -447,9 +433,8 @@ def test_forward_threshold():
def forward(self, *args): def forward(self, *args):
return torch.nn.Threshold(0, 0)(args[0]) return torch.nn.Threshold(0, 0)(args[0])
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Threshold1().float().eval(), input_data=input_data)
verify_model(Threshold1().float().eval(), input_data=input_data)
def test_forward_contiguous(): def test_forward_contiguous():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -459,9 +444,8 @@ def test_forward_contiguous(): ...@@ -459,9 +444,8 @@ def test_forward_contiguous():
def forward(self, *args): def forward(self, *args):
return args[0].contiguous() return args[0].contiguous()
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Contiguous1().float().eval(), input_data=input_data)
verify_model(Contiguous1().float().eval(), input_data=input_data)
def test_forward_batchnorm(): def test_forward_batchnorm():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -481,10 +465,9 @@ def test_forward_batchnorm(): ...@@ -481,10 +465,9 @@ def test_forward_batchnorm():
def forward(self, *args): def forward(self, *args):
return self.batch_norm(args[0]) return self.batch_norm(args[0])
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(BatchNorm1().float().eval(), input_data=input_data)
verify_model(BatchNorm1().float().eval(), input_data=input_data) verify_model(BatchNorm2().float().eval(), input_data=input_data)
verify_model(BatchNorm2().float().eval(), input_data=input_data)
def test_forward_transpose(): def test_forward_transpose():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -498,10 +481,9 @@ def test_forward_transpose(): ...@@ -498,10 +481,9 @@ def test_forward_transpose():
def forward(self, *args): def forward(self, *args):
return args[0].transpose(-2, -1) return args[0].transpose(-2, -1)
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Transpose1().float().eval(), input_data=input_data)
verify_model(Transpose1().float().eval(), input_data=input_data) verify_model(Transpose2().float().eval(), input_data=input_data)
verify_model(Transpose2().float().eval(), input_data=input_data)
def test_forward_size(): def test_forward_size():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -511,9 +493,8 @@ def test_forward_size(): ...@@ -511,9 +493,8 @@ def test_forward_size():
def forward(self, *args): def forward(self, *args):
return float(args[0].size(0)) * args[0] return float(args[0].size(0)) * args[0]
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Size1().float().eval(), input_data=input_data)
verify_model(Size1().float().eval(), input_data=input_data)
def test_forward_view(): def test_forward_view():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -527,10 +508,9 @@ def test_forward_view(): ...@@ -527,10 +508,9 @@ def test_forward_view():
def forward(self, *args): def forward(self, *args):
return args[0].view(args[0].shape[0], -1) return args[0].view(args[0].shape[0], -1)
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(View1().float().eval(), input_data=input_data)
verify_model(View1().float().eval(), input_data=input_data) verify_model(View2().float().eval(), input_data=input_data)
verify_model(View2().float().eval(), input_data=input_data)
def test_forward_select(): def test_forward_select():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -540,9 +520,8 @@ def test_forward_select(): ...@@ -540,9 +520,8 @@ def test_forward_select():
def forward(self, *args): def forward(self, *args):
return args[0].select(1, 1) return args[0].select(1, 1)
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Select1().float().eval(), input_data=input_data)
verify_model(Select1().float().eval(), input_data=input_data)
def test_forward_clone(): def test_forward_clone():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -552,9 +531,8 @@ def test_forward_clone(): ...@@ -552,9 +531,8 @@ def test_forward_clone():
def forward(self, *args): def forward(self, *args):
return args[0].clone() return args[0].clone()
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Clone1().float().eval(), input_data=input_data)
verify_model(Clone1().float().eval(), input_data=input_data)
def test_forward_logsoftmax(): def test_forward_logsoftmax():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -564,9 +542,8 @@ def test_forward_logsoftmax(): ...@@ -564,9 +542,8 @@ def test_forward_logsoftmax():
def forward(self, *args): def forward(self, *args):
return torch.nn.LogSoftmax(dim=1)(args[0][0, 0]) return torch.nn.LogSoftmax(dim=1)(args[0][0, 0])
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(LogSoftmax1().float().eval(), input_data=input_data)
verify_model(LogSoftmax1().float().eval(), input_data=input_data)
def test_forward_sigmoid(): def test_forward_sigmoid():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -576,9 +553,8 @@ def test_forward_sigmoid(): ...@@ -576,9 +553,8 @@ def test_forward_sigmoid():
def forward(self, *args): def forward(self, *args):
return torch.nn.Sigmoid()(args[0]) return torch.nn.Sigmoid()(args[0])
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Sigmoid1().float().eval(), input_data=input_data)
verify_model(Sigmoid1().float().eval(), input_data=input_data)
def test_forward_dense(): def test_forward_dense():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -598,10 +574,9 @@ def test_forward_dense(): ...@@ -598,10 +574,9 @@ def test_forward_dense():
def forward(self, *args): def forward(self, *args):
return self.linear(args[0][0, 0]) return self.linear(args[0][0, 0])
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Dense1().float().eval(), input_data=input_data)
verify_model(Dense1().float().eval(), input_data=input_data) verify_model(Dense2().float().eval(), input_data=input_data)
verify_model(Dense2().float().eval(), input_data=input_data)
def test_forward_dropout(): def test_forward_dropout():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -611,9 +586,8 @@ def test_forward_dropout(): ...@@ -611,9 +586,8 @@ def test_forward_dropout():
def forward(self, *args): def forward(self, *args):
return torch.nn.functional.dropout(args[0][0, 0], 0.5, False) return torch.nn.functional.dropout(args[0][0, 0], 0.5, False)
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Dropout1().float().eval(), input_data=input_data)
verify_model(Dropout1().float().eval(), input_data=input_data)
def test_forward_slice(): def test_forward_slice():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -627,10 +601,9 @@ def test_forward_slice(): ...@@ -627,10 +601,9 @@ def test_forward_slice():
def forward(self, *args): def forward(self, *args):
return args[0][0, :, :, :] return args[0][0, :, :, :]
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Slice1().float().eval(), input_data=input_data)
verify_model(Slice1().float().eval(), input_data=input_data) verify_model(Slice2().float().eval(), input_data=input_data)
verify_model(Slice2().float().eval(), input_data=input_data)
def test_forward_mean(): def test_forward_mean():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -640,9 +613,8 @@ def test_forward_mean(): ...@@ -640,9 +613,8 @@ def test_forward_mean():
def forward(self, *args): def forward(self, *args):
return args[0].mean(2) return args[0].mean(2)
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Mean1().float().eval(), input_data=input_data)
verify_model(Mean1().float().eval(), input_data=input_data)
def test_forward_expand(): def test_forward_expand():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -652,9 +624,8 @@ def test_forward_expand(): ...@@ -652,9 +624,8 @@ def test_forward_expand():
def forward(self, *args): def forward(self, *args):
return args[0].expand((3, -1, -1, -1)) return args[0].expand((3, -1, -1, -1))
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Expand1().float().eval(), input_data=input_data)
verify_model(Expand1().float().eval(), input_data=input_data)
def test_forward_pow(): def test_forward_pow():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -664,9 +635,8 @@ def test_forward_pow(): ...@@ -664,9 +635,8 @@ def test_forward_pow():
def forward(self, *args): def forward(self, *args):
return args[0] ** 2 return args[0] ** 2
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Pow1().float().eval(), input_data=input_data)
verify_model(Pow1().float().eval(), input_data=input_data)
def test_forward_chunk(): def test_forward_chunk():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -677,9 +647,61 @@ def test_forward_chunk(): ...@@ -677,9 +647,61 @@ def test_forward_chunk():
chunks = args[0].chunk(7, 2) chunks = args[0].chunk(7, 2)
return torch.cat(chunks, 2) return torch.cat(chunks, 2)
with torch.no_grad(): input_data = torch.rand(input_shape).float()
input_data = torch.rand(input_shape).float() verify_model(Chunk1().float().eval(), input_data=input_data)
verify_model(Chunk1().float().eval(), input_data=input_data)
def test_upsample():
    """Exercise aten::upsample_{nearest,bilinear}2d via F.interpolate."""
    class Upsample(Module):
        # Thin module wrapper so torch.nn.functional.interpolate can be
        # traced with a fixed configuration.
        def __init__(self, size=None, scale=None,
                     mode="nearest", align_corners=None):
            super().__init__()
            self.size = size
            self.scale = scale
            self.mode = mode
            self.align_corners = align_corners

        def forward(self, x):
            return torch.nn.functional.interpolate(
                x, size=self.size, scale_factor=self.scale,
                mode=self.mode, align_corners=self.align_corners)

    inp = torch.rand((1, 3, 32, 32))
    # Cover upscaling by size, upscaling by factor, and downscaling,
    # for both nearest and bilinear (with align_corners) modes.
    configs = [
        dict(size=(64, 64), mode="nearest"),
        dict(scale=2, mode="nearest"),
        dict(size=(50, 50), mode="nearest"),
        dict(size=(64, 64), mode="bilinear", align_corners=True),
        dict(scale=2, mode="bilinear", align_corners=True),
        dict(size=(50, 50), mode="bilinear", align_corners=True),
    ]
    for cfg in configs:
        verify_model(Upsample(**cfg), inp)
def test_to():
    """Exercise aten::to(...) for device moves and dtype casts."""
    class ToCPU(Module):
        def forward(self, x):
            return x.to("cpu")

    class ToFloat(Module):
        def forward(self, x):
            return x.float()

    class ToInt(Module):
        def forward(self, x):
            return x.int()

    # Device move is a no-op; casts are checked on both full tensors and
    # 0-dim (scalar) tensors, in both directions (int<->float).
    verify_model(ToCPU().eval(), torch.rand((1, 3, 32, 32)))
    verify_model(ToFloat().eval(), torch.zeros((1, 3, 32, 32), dtype=torch.int))
    verify_model(ToFloat().eval(), torch.tensor(2, dtype=torch.int))
    verify_model(ToInt().eval(), torch.zeros((1, 3, 32, 32)))
    verify_model(ToInt().eval(), torch.tensor(2.0))
# Model tests # Model tests
def test_resnet18(): def test_resnet18():
...@@ -730,6 +752,57 @@ def test_vgg11_bn(): ...@@ -730,6 +752,57 @@ def test_vgg11_bn():
""" """
def test_custom_conversion_map():
    """Check that a user-supplied convert map handles torchvision::roi_align."""
    pool_size = 5

    def _roi_align_model_and_inputs():
        # Feature map with 2 * pool_size**2 channels and four ROIs in
        # (batch_index, x1, y1, x2, y2) format.
        feat = torch.rand(2, 2 * (pool_size ** 2), 10, 10)
        rois = torch.tensor([[0, 0, 0, 9, 9],
                             [0, 0, 5, 4, 9],
                             [0, 5, 5, 9, 9],
                             [1, 0, 0, 9, 9]], dtype=torch.float)
        roi_align = torchvision.ops.RoIAlign(pool_size, spatial_scale=1,
                                             sampling_ratio=-1)
        return roi_align.eval(), [feat, rois]

    def convert_roi_align():
        # Custom converter: map torchvision::roi_align node inputs onto
        # relay's roi_align operator.
        def _impl(inputs, input_types):
            spatial_scale = inputs[2]
            pooled_size = (inputs[3], inputs[4])
            sampling_ratio = inputs[5]
            return relay.op.vision.roi_align(inputs[0], inputs[1],
                                             pooled_size, spatial_scale,
                                             sampling_ratio)
        return _impl

    model, inputs = _roi_align_model_and_inputs()
    verify_model(model, inputs,
                 custom_convert_map={'torchvision::roi_align':
                                     convert_roi_align()})
def test_segmentaton_models():
    """Run FCN and DeepLabV3 segmentation models end to end (CUDA only)."""
    class SegmentationModelWrapper(Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, inp):
            # torchvision segmentation models return a dict; keep only
            # the "out" tensor so the wrapper is traceable.
            return self.model(inp)["out"]

    fcn = torchvision.models.segmentation.fcn_resnet101(pretrained=True)
    deeplab = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)

    inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)]

    for model in (fcn, deeplab):
        # depthwise + dilated convolution not supported on x86
        # see https://github.com/apache/incubator-tvm/issues/4962
        verify_model(SegmentationModelWrapper(model.eval()), inp,
                     ctx_list=[("cuda", tvm.gpu(0))])
if __name__ == "__main__": if __name__ == "__main__":
# Single operator tests # Single operator tests
test_forward_add() test_forward_add()
...@@ -760,6 +833,8 @@ if __name__ == "__main__": ...@@ -760,6 +833,8 @@ if __name__ == "__main__":
test_forward_expand() test_forward_expand()
test_forward_pow() test_forward_pow()
test_forward_chunk() test_forward_chunk()
test_upsample()
test_to()
# Model tests # Model tests
test_resnet18() test_resnet18()
...@@ -770,3 +845,7 @@ if __name__ == "__main__": ...@@ -770,3 +845,7 @@ if __name__ == "__main__":
test_googlenet() test_googlenet()
test_mnasnet0_5() test_mnasnet0_5()
test_mobilenet_v2() test_mobilenet_v2()
test_custom_conversion_map()
test_segmentaton_models()
...@@ -37,7 +37,7 @@ https://pytorch.org/get-started/locally/ ...@@ -37,7 +37,7 @@ https://pytorch.org/get-started/locally/
PyTorch versions should be backwards compatible but should be used PyTorch versions should be backwards compatible but should be used
with the proper TorchVision version. with the proper TorchVision version.
Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
be unstable. be unstable.
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment