Unverified Commit 7ccb4363 by masahi Committed by GitHub

[Relay, Torch] Clean up and refactor PyTorch frontend (#4944)

* The initial import of refactored implementation, all tests passed

* enable mobilenet v2 test

* minor cleanup

* reorg

* fix lint

* use input names that come with torch IR

* fix typo

* introduce parse_operators

* fix lint

* add _ prefix
parent a6fae5ed
...@@ -18,6 +18,9 @@ ...@@ -18,6 +18,9 @@
# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except # pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
"""PT: PyTorch frontend.""" """PT: PyTorch frontend."""
import itertools
from packaging import version
import numpy as np import numpy as np
import tvm import tvm
...@@ -396,9 +399,11 @@ def _dense(): ...@@ -396,9 +399,11 @@ def _dense():
def _size(): def _size():
def _impl(inputs, input_types): def _impl(inputs, input_types):
axis = int(inputs[1])
shape = _infer_shape(inputs[0]) shape = _infer_shape(inputs[0])
return shape[axis] if len(inputs) > 1:
axis = int(inputs[1])
return shape[axis]
return shape
return _impl return _impl
def _numtotensor(): def _numtotensor():
...@@ -484,10 +489,19 @@ def _reduce(name): ...@@ -484,10 +489,19 @@ def _reduce(name):
def _mean(): def _mean():
def _impl(inputs, input_types): def _impl(inputs, input_types):
data = inputs[0] data = inputs[0]
axis = _infer_shape(inputs[1])
keepdims = int(inputs[2]) if inputs[1]:
exclude = int(inputs[3]) axis = _infer_shape(inputs[1])
else:
axis = None
if len(inputs) > 2 and inputs[2]:
keepdims = int(inputs[2])
else:
keepdims = False
if len(inputs) > 3 and inputs[3]:
exclude = int(inputs[3])
else:
exclude = False
return _op.mean(data, axis, keepdims, exclude) return _op.mean(data, axis, keepdims, exclude)
return _impl return _impl
...@@ -651,7 +665,7 @@ def _convert_elemwise_input(data, input_type): ...@@ -651,7 +665,7 @@ def _convert_elemwise_input(data, input_type):
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return _expr.const(data.item(), dtype=_convert_data_type(input_type)) return _expr.const(data.item(), dtype=_convert_data_type(input_type))
elif not isinstance(data, _expr.Expr): elif not isinstance(data, _expr.Expr):
return _expr.const(int(data), dtype=_convert_data_type(input_type)) return _expr.const(data, dtype=_convert_data_type(input_type))
else: else:
return data return data
...@@ -718,293 +732,270 @@ _convert_map = { ...@@ -718,293 +732,270 @@ _convert_map = {
"aten::sqrt" : _sqrt() "aten::sqrt" : _sqrt()
} }
# Internal graph for parsing
class Graph(object): def _run_jit_passes(graph):
""" A helper class for parsing PyTorch model to Relay graph.""" """ The inline pass is necessary to unwrap prim::CallMethod """
import torch
if version.parse(torch.__version__) >= version.parse("1.4.0"):
torch._C._jit_pass_inline(graph)
def __init__(self, script_module, input_shapes):
self._script_module = script_module def _is_int_seq(seq):
self._graph = script_module.graph.copy() return len(seq) > 0 and all([isinstance(i, int) for i in seq])
# TODO: Temporary fix to remove prim::CallMethod node introduced in PT 1.4
import torch def _get_tensor_and_var(torch_tensor, name):
from packaging import version tensor = tvm.nd.array(torch_tensor.cpu().numpy())
if version.parse(torch.__version__) >= version.parse("1.4.0"): var = _expr.var(name, shape=tensor.shape)
torch._C._jit_pass_inline(self._graph) return tensor, var
self._inputs_r = {}
self._params = {} def _get_output_name(node):
self._param_tensors = {} assert node.outputsSize() == 1
self._consts = {} return node.output().debugName()
self._ops = {}
self._op_inputs_r = {}
self._op_inputs_types = {} def _get_output_names(node):
self._input_shapes = input_shapes if input_shapes else {} return [output.debugName() for output in node.outputs()]
self._parsed_node_names = {}
def from_pytorch(self): def _get_input_names(node_or_graph):
""" Construct relay nodes from PyTorch graph return [inp.debugName() for inp in node_or_graph.inputs()]
Currently only supports traced PyTorch format which means no control flow.
User must perform torch.jit.trace on a model and pass this in. def _get_op_inputs(op_node, outputs, output_index_map):
Future support should include support scripted models (torch.jit.script) which input_names = [output_index_map[name]
preserves control flow. for name in _get_input_names(op_node)]
return [outputs[name] for name in input_names]
Returns
-------
mod : tvm.relay.Module def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
The module that optimizations will be performed on. for output_name, output in name_output_pairs:
output_index_map[output_name] = len(outputs)
params : dict of str to tvm.runtime outputs.append(output)
Dict of converted parameters stored in tvm.runtime format
"""
# Check for missing ops def _report_missing_conversion(op_names):
missing_operators = self._parse_import_prerequisites() """ Check if all ops in an input graph are supported by TVM """
known_ops = ["prim::Constant", "prim::GetAttr",
if missing_operators: "prim::ListConstruct", "prim::ListUnpack",
raise tvm.error.OpNotImplemented( \ "prim::TupleConstruct", "prim::TupleUnpack"]
"The following operators are not implemented: {}".format(missing_operators)) known_ops += list(_convert_map.keys())
# Translate PyTorch graph to by decorating Graph with state dict and inputs into each op missing = [op_name for op_name in op_names
self._parse_inputs() if op_name not in known_ops]
self._parse_params()
self._parse_ops() if missing:
msg = "The following operators are not implemented: {}".format(missing)
outputs = [] raise NotImplementedError(msg)
nid = 0
for op_name, op_node in self._ops.items(): def _getattr_attr_name(node):
if op_node.kind() == "prim::ListConstruct": attribute_names = node.attributeNames()
if any(inp.debugName() in self._parsed_node_names.keys() \ assert len(attribute_names) == 1
for inp in op_node.inputs()): attr_name = node.s(attribute_names[0])
list_constr = [] return attr_name
for i in op_node.inputs():
if i.debugName() in self._parsed_node_names.keys():
list_constr.append( \ def _getattr_full_name(getattrs):
outputs[self._parsed_node_names[i.debugName()]]) return ".".join([_getattr_attr_name(node) for node in getattrs])
elif i.node().kind() == "prim::Constant":
list_constr.append(int(self._consts[i.debugName()]))
elif i.debugName() in self._inputs_r.keys(): def _get_input_types(op_node):
list_constr.append(int(self._inputs_r[i.debugName()])) """ Returns a torch type for each input nodes """
input_list_types = []
# Unwrap for tensors for input_node in op_node.inputs():
if len(list_constr) == 1: in_ty = input_node.type()
list_constr = list_constr[0] input_node_kind = in_ty.kind()
if input_node_kind == 'TensorType':
outputs.append(list_constr) if in_ty.scalarType() is None:
self._parsed_node_names[op_name] = nid input_list_types.append(None)
nid = nid+1
elif op_node.kind() != "prim::Constant":
for i in op_node.inputs():
if i.debugName() in self._parsed_node_names.keys():
for cnt in range(0, len(self._op_inputs_r[op_name])):
if isinstance(self._op_inputs_r[op_name][cnt], str):
if "call/var" in self._op_inputs_r[op_name][cnt]:
self._op_inputs_r[op_name][cnt] = \
outputs[self._parsed_node_names[i.debugName()]]
break
call = _convert_map[op_node.kind()](self._op_inputs_r[op_name],
self._op_inputs_types[op_name])
outputs.append(call)
self._parsed_node_names[op_name] = nid
nid = nid+1
func = tvm.relay.Function(_analysis.free_vars(outputs[-1]), outputs[-1])
param = {k: tvm.nd.array(v) for k, v in self._param_tensors.items()}
return _module.IRModule.from_expr(func), param
def _parse_inputs(self):
""" Map inputs to parser and inputs to graph. """
# Get names and objects of inputs for IR
ir_inputs = [i for i in self._graph.inputs()]
# Create corresponding shape and add to input
for input_name, ir_input in zip(self._input_shapes, ir_inputs[1:]):
input_shape = self._input_shapes[input_name]
ir_input.setDebugName(input_name)
ir_dtype = _convert_data_type(ir_input.type().scalarType().lower())
self._inputs_r[input_name] = _expr.var(input_name,
shape=self._input_shapes[input_name],
dtype=ir_dtype)
# Add self (first input of a PyTorch graph) to inputs, the value doesn't matter here
input_name = ir_inputs[0].debugName()
self._inputs_r[input_name] = "self"
def _parse_params(self):
""" Map state dictionary values to corresponding prim::GetAttr op node. """
# Grab weights, biases, etc. from graph
state_dict = self._script_module.state_dict()
param_names = []
for key, value in state_dict.items():
param_str = str(key)
param_name = param_str.split(".")[-1]
param_names.append(param_name)
# Get names of all inputs
input_names = [i for i in self._inputs_r.keys()]
# Iterate through graph for getAttr nodes and match full state_dict name to nodes
node_weight_map = {}
for node in self._graph.nodes():
if node.kind() == "prim::GetAttr":
attribute_names = node.attributeNames()
assert len(attribute_names) == 1
node_getattr_name = node.s(attribute_names[0])
node_arg = node.input().debugName()
if node.outputsSize() == 1:
node_name = node.output().debugName()
else:
node_name = [output.debugName() for output in node.outputs()][0]
if node_arg in input_names:
node_weight_map[node_name] = node_getattr_name
else:
previous_map = node_weight_map[node_arg[:]]
node_weight_map[node_name] = previous_map+"."+node_getattr_name
if node_getattr_name in param_names:
value = state_dict[node_weight_map[node_name]]
tensor = tvm.nd.array(value.cpu().numpy())
shape = tensor.shape
self._param_tensors[node_name] = tensor
self._params[node_name] = _expr.var(node_name,
shape=shape,
dtype=_convert_data_type(str(value.dtype)))
def _parse_ops(self):
""" Iterate through nodes and decorate graph with constants, operators,
and the inputs to each operator. """
# Traverse nodes and add to graph
for node in self._graph.nodes():
if node.outputsSize() == 1:
node_name = node.output().debugName()
else:
node_name = [output.debugName() for output in node.outputs()][0]
if node.kind() == "prim::Constant":
if node.hasAttributes():
attribute_names = node.attributeNames()
attr_name = attribute_names[0]
ty = node.output().type().kind()
if ty in ["IntType", "BoolType"]:
self._consts[node_name] = node.i(attr_name)
elif ty in ["FloatType", "LongType"]:
self._consts[node_name] = node.f(attr_name)
elif ty in ["TensorType", "CompleteTensorType"]:
self._consts[node_name] = node.output().toIValue()
else:
self._consts[node_name] = "0"
else:
self._consts[node_name] = "0"
elif node.kind() == "prim::ListConstruct":
list_shape = []
for input_node in node.inputs():
if input_node.debugName() in self._inputs_r.keys():
c = self._inputs_r[input_node.debugName()]
assert isinstance(c, int)
list_shape.append(c)
elif input_node.debugName() in self._consts.keys():
c = self._consts[input_node.debugName()]
assert isinstance(c, int)
list_shape.append(c)
self._inputs_r[node_name] = _expr.var(node_name, shape=list_shape)
if node.kind() != "prim::GetAttr":
self._add_op(node_name, node)
# Graph Helper Functions
def _add_op(self, node_id, op_node):
""" Add an operator and its operators inputs to the graph and insert placeholders
where an input is a call node.
Parameters
----------
node_id : string
The ID of the op node
op_node : PyTorch Node object
The full Node object for the op node
"""
self._ops[(node_id)] = op_node
input_list_r = []
input_list_types = []
for input_value in op_node.inputs():
inode_id = input_value.debugName()
inode = input_value.node()
if inode_id in self._inputs_r.keys():
input_list_r.append(self._inputs_r[inode_id])
elif inode_id in self._params.keys():
input_list_r.append(self._params[inode_id])
elif inode.kind() == "prim::Constant":
input_list_r.append(self._consts[inode_id])
else: else:
input_list_r.append("call/var."+inode_id) input_list_types.append(in_ty.scalarType().lower())
elif input_node_kind == 'ListType':
# If the inputs of a ListConstruct op is a call or var, remove it from inputs input_list_types.append(str(in_ty.getElementType()).lower())
if op_node.kind() == "prim::ListConstruct": elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
if node_id in self._inputs_r.keys(): 'StringType', 'OptionalType']:
self._inputs_r.pop(node_id) input_list_types.append(str(in_ty).lower())
else:
try: input_list_types.append('UnsupportedType')
input_value_kind = input_value.type().kind()
if input_value_kind in ["TensorType", "CompleteTensorType"]: if op_node.kind() in ['aten::ones', 'aten::zeros']:
if input_value.type().scalarType() is None: node_type = op_node.output().type()
input_list_types.append("float") scalar_type = node_type.scalarType()
else: if scalar_type:
input_list_types.append(input_value.type().scalarType().lower()) input_list_types[0] = scalar_type.lower()
elif input_value_kind == "ListType":
input_list_types.append(str(input_value.type().getElementType()).lower()) return input_list_types
elif input_value_kind in ["IntType", "FloatType", "BoolType", "StringType",
"OptionalType"]:
input_list_types.append(str(input_value.type()).lower()) def _get_constant(node):
else: """ Retrieve a constant associated with this prim::Constant node """
input_list_types.append("UnsupportedType") attribute_names = node.attributeNames()
print("UnsupportedType "+str(input_value.type())+" and "+str(input_value_kind)) num_attributes = len(attribute_names)
except Exception as e:
print("Internal PyTorch error. Failed to grab type.") if num_attributes == 1:
attr_name = attribute_names[0]
if op_node.kind() in ["aten::ones", "aten::zeros"]: ty = node.output().type().kind()
node_type = op_node.output().type().scalarType()
input_list_types[0] = node_type.lower() if ty in ["IntType", "BoolType"]:
return node.i(attr_name)
self._op_inputs_r[node_id] = input_list_r elif ty in ["FloatType", "LongType"]:
self._op_inputs_types[node_id] = input_list_types return node.f(attr_name)
elif ty in ["TensorType", "CompleteTensorType"]:
def _parse_import_prerequisites(self): tensor = node.t(attr_name)
""" Calculate the named preconditions from PyTorch graph. if len(tensor.shape) == 0: # tensor(0.1)
return float(tensor)
Returns return tensor
------- elif ty == "DeviceObjType":
missing_operators : set object return node.s(attr_name)
Set of operator names which don't have their mapping in TVM elif ty == "FunctionType":
i.e. which are not supported return None
else:
""" raise NotImplementedError("Unsupported type: %s" % ty)
missing_operators = set() else:
for node in self._graph.nodes(): assert num_attributes == 0
if not node.kind() in ["prim::Constant", "prim::ListConstruct", "prim::GetAttr"] \ return None
and not node.kind() in _convert_map:
missing_operators.add(node.kind())
def _get_operator_nodes(nodes):
return missing_operators """ Returns torch IR nodes that need conversion to Relay """
ops = {}
# Traverse nodes and add to graph
for node in nodes:
if node.outputsSize() > 1:
node_name = "_".join(_get_output_names(node))
else:
node_name = _get_output_name(node)
if node.kind() != "prim::GetAttr":
ops[node_name] = node
return ops
def parse_inputs(graph_inputs, input_shapes):
""" Return Relay vars from torch input vars """
ir_inputs = list(graph_inputs)
input_vars = {}
for input_name, ir_input in zip(input_shapes, ir_inputs[1:]):
input_vars[input_name] = _expr.var(input_name,
shape=input_shapes[input_name])
return input_vars
def get_use_chains(root_node, terminate=lambda _: False):
"""
Track a chain of users of this node forward, returning a list of chains
See get_attr_chains below for its usage
"""
def concat_lists(lists):
return itertools.chain.from_iterable(lists)
def inner(current, accum):
users = []
for output in current.outputs():
users += [use.user for use in output.uses()]
if not users or terminate(users):
return [accum]
return concat_lists([inner(nxt, accum + [nxt]) for nxt in users])
return inner(root_node, [root_node])
def get_attr_chains(root_getattr_node):
""" Returns chains of attribute access starting from root_getattr_node
For example, given attribute "block", as in "self.block" when "self" points
to the top level torch.nn.Module, it returns lists of attribute "chains",
e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']
These sets of attributes form full attribute accessors. For example,
"self.block.1", "self.block.2" will return the second and third submodule,
and "self.block.0._packed_params" will return the parameters of the first
submodule.
"""
def terminate(users):
next_attrs = [user for user in users if user.kind() == "prim::GetAttr"]
return len(next_attrs) == 0
return get_use_chains(root_getattr_node, terminate)
def parse_params(graph, state_dict):
"""
Return Relay vars and TVM NDArrays for input parameters
A chain of prim::GetAttr nodes is processed one at a time
"""
getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
params = {}
param_tensors = {}
seen = set()
for node in getattr_nodes:
if _get_output_name(node) in seen:
continue
for getattrs in get_attr_chains(node):
seen.update(map(_get_output_name, getattrs))
full_attr = _getattr_full_name(getattrs)
full_attr_node_name = _get_output_name(getattrs[-1])
if full_attr in state_dict:
torch_tensor = state_dict[full_attr]
tensor, var = _get_tensor_and_var(torch_tensor,
full_attr_node_name)
param_tensors[full_attr_node_name] = tensor
params[full_attr_node_name] = var
return params, param_tensors
def parse_operators(operators, outputs, output_index_map, ret_name):
""" Convert each Torch IR operators to Relay equivalent """
for node_name, op_node in operators.items():
operator = op_node.kind()
inputs = _get_op_inputs(op_node, outputs, output_index_map)
if operator == "prim::Constant":
output_index_map[node_name] = len(outputs)
outputs.append(_get_constant(op_node))
elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
output_index_map[node_name] = len(outputs)
outputs.append(_expr.var(node_name, shape=inputs))
elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
output_index_map[node_name] = len(outputs)
outputs.append(inputs)
elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
assert len(inputs) == 1
unpacked_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
outputs, output_index_map)
else:
output_index_map[node_name] = len(outputs)
relay_op = _convert_map[operator]
outputs.append(relay_op(inputs, _get_input_types(op_node)))
return outputs[output_index_map[ret_name]]
def get_all_op_names(graph):
""" Return all operator names in the input graph """
nodes = list(graph.nodes())
return set(node.kind() for node in nodes)
def get_graph_input_names(script_module):
""" Use this function to set the keys for input_shapes"""
# It seems variable names could change the first time a copy is made
# Use the copy of the graph here to prevent troubles later
ir_inputs = _get_input_names(script_module.graph.copy())
return ir_inputs[1:] # remove self at the 0th arg
def from_pytorch(script_module, input_shapes): def from_pytorch(script_module, input_shapes):
""" Load PyTorch model in the form of a scripted PyTorch model and convert into relay. """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
...@@ -1016,17 +1007,35 @@ def from_pytorch(script_module, input_shapes): ...@@ -1016,17 +1007,35 @@ def from_pytorch(script_module, input_shapes):
TorchScripted PyTorch graph TorchScripted PyTorch graph
Note: We currently only support traces (ie: torch.jit.trace(model, input)) Note: We currently only support traces (ie: torch.jit.trace(model, input))
shape : Dictionary of input dimensions input_shapes : Dictionary of input dimensions
Graph level input shape dictionary Graph level input shape dictionary
The keys should be the same one returned by get_graph_input_names(...) above
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.relay.Module
The module that optimizations will be performed on. The module that optimizations will be performed on.
params : dict of str to tvm.runtime params : dict of str to tvm.runtime.NDArray
Dict of converted parameters stored in tvm.runtime format Dict of converted parameters stored in tvm.runtime.ndarray format
""" """
g = Graph(script_module, input_shapes) graph = script_module.graph.copy()
mod, params = g.from_pytorch() _run_jit_passes(graph)
return mod, params op_names = get_all_op_names(graph)
_report_missing_conversion(op_names)
params = script_module.state_dict()
input_vars = parse_inputs(graph.inputs(), input_shapes)
param_vars, tensors = parse_params(graph, params)
input_vars.update(param_vars)
outputs = list(input_vars.values())
output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
ret_name = _get_input_names(graph.return_node())[0]
body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)
func = tvm.relay.Function(_analysis.free_vars(body), body)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
return _module.IRModule.from_expr(func), tvm_params
...@@ -31,6 +31,8 @@ import torchvision ...@@ -31,6 +31,8 @@ import torchvision
from tvm import relay from tvm import relay
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list from tvm.relay.testing.config import ctx_list
from tvm.relay.frontend.pytorch import get_graph_input_names
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
...@@ -94,6 +96,7 @@ def load_model(model_name): ...@@ -94,6 +96,7 @@ def load_model(model_name):
if hasattr(torchvision.models, model_name): if hasattr(torchvision.models, model_name):
return load_torchvision(model_name) return load_torchvision(model_name)
try: try:
import pretrainedmodels
if hasattr(pretrainedmodels, model_name): if hasattr(pretrainedmodels, model_name):
return load_pretrainedmodels(model_name) return load_pretrainedmodels(model_name)
except ModuleNotFoundError: except ModuleNotFoundError:
...@@ -167,16 +170,15 @@ def verify_model(model_name, input_data=[]): ...@@ -167,16 +170,15 @@ def verify_model(model_name, input_data=[]):
baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs) baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
else: else:
baseline_outputs = (baseline_outputs.float().cpu().numpy(),) baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
output_shapes = [out.shape for out in baseline_outputs]
dtype = "float32"
input_name = "input0"
input_shapes = {input_name: list(baseline_input.shape)}
trace = torch.jit.trace(baseline_model, baseline_input).float().eval() trace = torch.jit.trace(baseline_model, baseline_input).float().eval()
if torch.cuda.is_available(): if torch.cuda.is_available():
trace = trace.cuda() trace = trace.cuda()
else: else:
trace = trace.cpu() trace = trace.cpu()
input_name = get_graph_input_names(trace)[0] # only one input
input_shapes = {input_name: list(baseline_input.shape)}
mod, params = relay.frontend.from_pytorch(trace, input_shapes) mod, params = relay.frontend.from_pytorch(trace, input_shapes)
compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())} compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())}
...@@ -276,7 +278,7 @@ def test_forward_multiply(): ...@@ -276,7 +278,7 @@ def test_forward_multiply():
class Multiply2(Module): class Multiply2(Module):
def forward(self, *args): def forward(self, *args):
return args[0] * 1 return args[0] * 1.0
class Multiply3(Module): class Multiply3(Module):
def forward(self, *args): def forward(self, *args):
...@@ -507,7 +509,7 @@ def test_forward_size(): ...@@ -507,7 +509,7 @@ def test_forward_size():
class Size1(Module): class Size1(Module):
def forward(self, *args): def forward(self, *args):
return args[0].size(0) * args[0] return float(args[0].size(0)) * args[0]
with torch.no_grad(): with torch.no_grad():
input_data = torch.rand(input_shape).float() input_data = torch.rand(input_shape).float()
...@@ -708,6 +710,10 @@ def test_mnasnet0_5(): ...@@ -708,6 +710,10 @@ def test_mnasnet0_5():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
verify_model("mnasnet0_5") verify_model("mnasnet0_5")
def test_mobilenet_v2():
torch.set_grad_enabled(False)
verify_model("mobilenet_v2")
""" """
#TODO: Fix VGG and AlexNet issues (probably due to pooling) #TODO: Fix VGG and AlexNet issues (probably due to pooling)
def test_alexnet(): def test_alexnet():
...@@ -721,13 +727,9 @@ def test_vgg11(): ...@@ -721,13 +727,9 @@ def test_vgg11():
def test_vgg11_bn(): def test_vgg11_bn():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
verify_model("vgg11_bn") verify_model("vgg11_bn")
#TODO: Need to update schedule in tophub file after PR #4787 updated workloads
def test_mobilenet_v2():
torch.set_grad_enabled(False)
verify_model("mobilenet_v2")
""" """
if __name__ == "__main__": if __name__ == "__main__":
# Single operator tests # Single operator tests
test_forward_add() test_forward_add()
...@@ -767,3 +769,4 @@ if __name__ == "__main__": ...@@ -767,3 +769,4 @@ if __name__ == "__main__":
test_inception_v3() test_inception_v3()
test_googlenet() test_googlenet()
test_mnasnet0_5() test_mnasnet0_5()
test_mobilenet_v2()
...@@ -41,14 +41,13 @@ Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may ...@@ -41,14 +41,13 @@ Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may
be unstable. be unstable.
""" """
# tvm, relay
import tvm import tvm
from tvm import relay from tvm import relay
# numpy, packaging
import numpy as np import numpy as np
from packaging import version
from tvm.contrib.download import download_testdata from tvm.contrib.download import download_testdata
from tvm.relay.frontend.pytorch import get_graph_input_names
# PyTorch imports # PyTorch imports
import torch import torch
...@@ -91,7 +90,8 @@ img = np.expand_dims(img, 0) ...@@ -91,7 +90,8 @@ img = np.expand_dims(img, 0)
# Import the graph to Relay # Import the graph to Relay
# ------------------------- # -------------------------
# Convert PyTorch graph to Relay graph. # Convert PyTorch graph to Relay graph.
shape_dict = {'img': img.shape} input_name = get_graph_input_names(scripted_model)[0] # only one input
shape_dict = {input_name: img.shape}
mod, params = relay.frontend.from_pytorch(scripted_model, mod, params = relay.frontend.from_pytorch(scripted_model,
shape_dict) shape_dict)
...@@ -116,12 +116,12 @@ from tvm.contrib import graph_runtime ...@@ -116,12 +116,12 @@ from tvm.contrib import graph_runtime
dtype = 'float32' dtype = 'float32'
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
# Set inputs # Set inputs
m.set_input('img', tvm.nd.array(img.astype(dtype))) m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
m.set_input(**params) m.set_input(**params)
# Execute # Execute
m.run() m.run()
# Get outputs # Get outputs
tvm_output = m.get_output(0, tvm.nd.empty(((1, 1000)), 'float32')) tvm_output = m.get_output(0)
##################################################################### #####################################################################
# Look up synset name # Look up synset name
...@@ -163,4 +163,4 @@ with torch.no_grad(): ...@@ -163,4 +163,4 @@ with torch.no_grad():
torch_class_key = class_id_to_key[top1_torch] torch_class_key = class_id_to_key[top1_torch]
print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, key_to_classname[tvm_class_key])) print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, key_to_classname[tvm_class_key]))
print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key])) print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key]))
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment