Unverified Commit 92a24278 by masahi Committed by GitHub

[Torch] Upsampling op support and enable registering a user defined op conversion map (#4961)

* add custom conversion map

* add roi align test using custom convert map

* refactor test

* add support for upsampling op and test on segmentation models

* remove redundant no_grad

* add upsampling test case

* make the default custom map None, instead of empty dict

* updated tests, remove packaging and drop PT 1.2 support

* add better support for aten::to and tests

* add a note on dilation in x86
parent 474c70d7
......@@ -19,7 +19,6 @@
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
"""PT: PyTorch frontend."""
import itertools
from packaging import version
import numpy as np
......@@ -31,6 +30,7 @@ from .. import expr as _expr
from .. import op as _op
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
__all__ = ["from_pytorch"]
......@@ -614,6 +614,61 @@ def _sqrt():
return _op.tensor.sqrt(data)
return _impl
def _floor():
def _impl(inputs, input_types):
data = inputs[0]
return _op.floor(data)
return _impl
def _to():
def _impl(inputs, input_types):
data = inputs[0]
if inputs[3] in ["cpu", "cuda"]:
return data
# special handling for aten::to(data, 6, _, _, _) case
# 6 means dtype = float
# this happens when converting upsampling with scale factor
cast_func = {
6: float,
3: int,
}
cast_func_expr = {
6: lambda x: _op.cast(x, "float32"),
3: lambda x: _op.cast(x, "int32"),
}
if inputs[1] in cast_func and not isinstance(data, _expr.Expr):
return cast_func[inputs[1]](data)
elif inputs[1] in cast_func and isinstance(data, _expr.Expr):
return cast_func_expr[inputs[1]](data)
return data
return _impl
def _upsample(method):
def _impl(inputs, input_types):
if isinstance(inputs[1], _expr.Var):
out_size = _infer_shape(inputs[1])
elif isinstance(inputs[1], list):
infer_res = [_infer_value(size, {}) for size in inputs[1]]
out_size = [np.asscalar(res.asnumpy().astype(np.int))
for res in infer_res]
data = inputs[0]
if len(inputs) > 2:
align_corners = inputs[2]
else:
align_corners = False
if align_corners:
coord_trans = "align_corners"
else:
coord_trans = "half_pixel"
return _op.image.resize(data, out_size, "NCHW", method, coord_trans)
return _impl
# Helper functions for operator implementation
def _convert_data_type(input_type):
......@@ -686,7 +741,7 @@ _convert_map = {
"aten::div_" : _elemwise("divide"),
"aten::ones" : _ones(),
"aten::zeros" : _zeros(),
"aten::to" : _identity(),
"aten::to" : _to(),
"aten::unsqueeze" : _unsqueeze(),
"aten::cat" : _concatenate(),
"aten::slice" : _slice(),
......@@ -729,15 +784,18 @@ _convert_map = {
"aten::permute" : _transpose(),
"aten::sum" : _reduce("sum"),
"aten::prod" : _reduce("prod"),
"aten::sqrt" : _sqrt()
"aten::sqrt" : _sqrt(),
'aten::floor' : _floor(),
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
}
def _run_jit_passes(graph):
""" The inline pass is necessary to unwrap prim::CallMethod """
import torch
if version.parse(torch.__version__) >= version.parse("1.4.0"):
torch._C._jit_pass_inline(graph)
torch._C._jit_pass_inline(graph)
def _is_int_seq(seq):
......@@ -985,8 +1043,7 @@ def parse_operators(operators, outputs, output_index_map, ret_name):
def get_all_op_names(graph):
""" Return all operator names in the input graph """
nodes = list(graph.nodes())
return set(node.kind() for node in nodes)
return set(node.kind() for node in graph.nodes())
def get_graph_input_names(script_module):
......@@ -997,7 +1054,7 @@ def get_graph_input_names(script_module):
return ir_inputs[1:] # remove self at the 0th arg
def from_pytorch(script_module, input_shapes):
def from_pytorch(script_module, input_shapes, custom_convert_map=None):
""" Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.
......@@ -1011,6 +1068,9 @@ def from_pytorch(script_module, input_shapes):
Graph level input shape dictionary
The keys should be the same one returned by get_graph_input_names(...) above
custom_convert_map: Dictionary of str to Relay op
A custom op conversion map in the same format as _convert_map above
Returns
-------
mod : tvm.relay.Module
......@@ -1021,6 +1081,10 @@ def from_pytorch(script_module, input_shapes):
"""
graph = script_module.graph.copy()
_run_jit_passes(graph)
if custom_convert_map:
_convert_map.update(custom_convert_map)
op_names = get_all_op_names(graph)
_report_missing_conversion(op_names)
......
......@@ -37,7 +37,7 @@ https://pytorch.org/get-started/locally/
PyTorch versions should be backwards compatible but should be used
with the proper TorchVision version.
Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may
Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
be unstable.
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment