Unverified commit 7ccb4363 by masahi (committed by GitHub)

[Relay, Torch] Clean up and refactor PyTorch frontend (#4944)

* Initial import of the refactored implementation; all tests passed

* enable mobilenet v2 test

* minor cleanup

* reorg

* fix lint

* use input names that come with torch IR

* fix typo

* introduce parse_operators

* fix lint

* add _ prefix
parent a6fae5ed
python/tvm/relay/frontend/pytorch.py
@@ -18,6 +18,9 @@
 # pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
 # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
 """PT: PyTorch frontend."""
+import itertools
+from packaging import version
+
 import numpy as np
 import tvm
@@ -396,9 +399,11 @@ def _dense():

 def _size():
     def _impl(inputs, input_types):
-        axis = int(inputs[1])
         shape = _infer_shape(inputs[0])
-        return shape[axis]
+        if len(inputs) > 1:
+            axis = int(inputs[1])
+            return shape[axis]
+        return shape
     return _impl

 def _numtotensor():
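For reference, the new _size handles both call patterns of Tensor.size; a minimal standalone sketch in plain PyTorch (independent of the frontend code above):

import torch

x = torch.rand(2, 3)
print(x.size())    # torch.Size([2, 3]) -- one input: the whole shape is returned
print(x.size(0))   # 2 -- a second (axis) input: shape[axis] is returned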
@@ -484,10 +489,19 @@ def _reduce(name):

 def _mean():
     def _impl(inputs, input_types):
         data = inputs[0]
-        axis = _infer_shape(inputs[1])
-        keepdims = int(inputs[2])
-        exclude = int(inputs[3])
+        if inputs[1]:
+            axis = _infer_shape(inputs[1])
+        else:
+            axis = None
+
+        if len(inputs) > 2 and inputs[2]:
+            keepdims = int(inputs[2])
+        else:
+            keepdims = False
+
+        if len(inputs) > 3 and inputs[3]:
+            exclude = int(inputs[3])
+        else:
+            exclude = False
+
         return _op.mean(data, axis, keepdims, exclude)
     return _impl
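The optional arguments mirror how torch.mean can be invoked; a hedged sketch of the variants the converter now has to accept:

import torch

x = torch.rand(2, 3)
print(torch.mean(x))                        # no dim: reduce over all elements (axis=None)
print(torch.mean(x, dim=1))                 # dim given: axis comes from inputs[1]
print(torch.mean(x, dim=1, keepdim=True))   # keepdim maps to keepdims in _op.mean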
@@ -651,7 +665,7 @@ def _convert_elemwise_input(data, input_type):
     if isinstance(data, torch.Tensor):
         return _expr.const(data.item(), dtype=_convert_data_type(input_type))
     elif not isinstance(data, _expr.Expr):
-        return _expr.const(int(data), dtype=_convert_data_type(input_type))
+        return _expr.const(data, dtype=_convert_data_type(input_type))
     else:
         return data
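Dropping the int() cast matters for float scalar operands; a minimal sketch of the difference using the public relay.const API (values chosen purely for illustration):

from tvm import relay

print(relay.const(int(0.5), dtype="float32"))  # 0.0 -- the old cast truncated the scalar
print(relay.const(0.5, dtype="float32"))       # 0.5 -- the new code keeps the value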
@@ -718,293 +732,270 @@ _convert_map = {
     "aten::sqrt" : _sqrt()
 }

-# Internal graph for parsing
-class Graph(object):
-    """ A helper class for parsing PyTorch model to Relay graph."""
-
-    def __init__(self, script_module, input_shapes):
-
-        self._script_module = script_module
-        self._graph = script_module.graph.copy()
-
-        # TODO: Temporary fix to remove prim::CallMethod node introduced in PT 1.4
-        import torch
-        from packaging import version
-        if version.parse(torch.__version__) >= version.parse("1.4.0"):
-            torch._C._jit_pass_inline(self._graph)
-
-        self._inputs_r = {}
-        self._params = {}
-        self._param_tensors = {}
-        self._consts = {}
-        self._ops = {}
-        self._op_inputs_r = {}
-        self._op_inputs_types = {}
-        self._input_shapes = input_shapes if input_shapes else {}
-        self._parsed_node_names = {}
-
-    def from_pytorch(self):
-        """ Construct relay nodes from PyTorch graph
-
-        Currently only supports traced PyTorch format which means no control flow.
-        User must perform torch.jit.trace on a model and pass this in.
-        Future support should include support scripted models (torch.jit.script) which
-        preserves control flow.
-
-        Returns
-        -------
-        mod : tvm.relay.Module
-            The module that optimizations will be performed on.
-
-        params : dict of str to tvm.runtime
-            Dict of converted parameters stored in tvm.runtime format
-        """
-        # Check for missing ops
-        missing_operators = self._parse_import_prerequisites()
-
-        if missing_operators:
-            raise tvm.error.OpNotImplemented( \
-                "The following operators are not implemented: {}".format(missing_operators))
-
-        # Translate PyTorch graph to by decorating Graph with state dict and inputs into each op
-        self._parse_inputs()
-        self._parse_params()
-        self._parse_ops()
-
-        outputs = []
-        nid = 0
-
-        for op_name, op_node in self._ops.items():
-            if op_node.kind() == "prim::ListConstruct":
-                if any(inp.debugName() in self._parsed_node_names.keys() \
-                       for inp in op_node.inputs()):
-                    list_constr = []
-                    for i in op_node.inputs():
-                        if i.debugName() in self._parsed_node_names.keys():
-                            list_constr.append( \
-                                outputs[self._parsed_node_names[i.debugName()]])
-                        elif i.node().kind() == "prim::Constant":
-                            list_constr.append(int(self._consts[i.debugName()]))
-                        elif i.debugName() in self._inputs_r.keys():
-                            list_constr.append(int(self._inputs_r[i.debugName()]))
-
-                    # Unwrap for tensors
-                    if len(list_constr) == 1:
-                        list_constr = list_constr[0]
-
-                    outputs.append(list_constr)
-                    self._parsed_node_names[op_name] = nid
-                    nid = nid+1
-            elif op_node.kind() != "prim::Constant":
-                for i in op_node.inputs():
-                    if i.debugName() in self._parsed_node_names.keys():
-                        for cnt in range(0, len(self._op_inputs_r[op_name])):
-                            if isinstance(self._op_inputs_r[op_name][cnt], str):
-                                if "call/var" in self._op_inputs_r[op_name][cnt]:
-                                    self._op_inputs_r[op_name][cnt] = \
-                                        outputs[self._parsed_node_names[i.debugName()]]
-                                    break
-
-                call = _convert_map[op_node.kind()](self._op_inputs_r[op_name],
-                                                    self._op_inputs_types[op_name])
-
-                outputs.append(call)
-                self._parsed_node_names[op_name] = nid
-                nid = nid+1
-
-        func = tvm.relay.Function(_analysis.free_vars(outputs[-1]), outputs[-1])
-
-        param = {k: tvm.nd.array(v) for k, v in self._param_tensors.items()}
-
-        return _module.IRModule.from_expr(func), param
-
-    def _parse_inputs(self):
-        """ Map inputs to parser and inputs to graph. """
-        # Get names and objects of inputs for IR
-        ir_inputs = [i for i in self._graph.inputs()]
-
-        # Create corresponding shape and add to input
-        for input_name, ir_input in zip(self._input_shapes, ir_inputs[1:]):
-            input_shape = self._input_shapes[input_name]
-            ir_input.setDebugName(input_name)
-
-            ir_dtype = _convert_data_type(ir_input.type().scalarType().lower())
-            self._inputs_r[input_name] = _expr.var(input_name,
-                                                   shape=self._input_shapes[input_name],
-                                                   dtype=ir_dtype)
-
-        # Add self (first input of a PyTorch graph) to inputs, the value doesn't matter here
-        input_name = ir_inputs[0].debugName()
-        self._inputs_r[input_name] = "self"
-
-    def _parse_params(self):
-        """ Map state dictionary values to corresponding prim::GetAttr op node. """
-        # Grab weights, biases, etc. from graph
-        state_dict = self._script_module.state_dict()
-        param_names = []
-        for key, value in state_dict.items():
-            param_str = str(key)
-            param_name = param_str.split(".")[-1]
-            param_names.append(param_name)
-
-        # Get names of all inputs
-        input_names = [i for i in self._inputs_r.keys()]
-
-        # Iterate through graph for getAttr nodes and match full state_dict name to nodes
-        node_weight_map = {}
-        for node in self._graph.nodes():
-            if node.kind() == "prim::GetAttr":
-
-                attribute_names = node.attributeNames()
-                assert len(attribute_names) == 1
-                node_getattr_name = node.s(attribute_names[0])
-                node_arg = node.input().debugName()
-
-                if node.outputsSize() == 1:
-                    node_name = node.output().debugName()
-                else:
-                    node_name = [output.debugName() for output in node.outputs()][0]
-
-                if node_arg in input_names:
-                    node_weight_map[node_name] = node_getattr_name
-                else:
-                    previous_map = node_weight_map[node_arg[:]]
-                    node_weight_map[node_name] = previous_map+"."+node_getattr_name
-
-                if node_getattr_name in param_names:
-                    value = state_dict[node_weight_map[node_name]]
-                    tensor = tvm.nd.array(value.cpu().numpy())
-                    shape = tensor.shape
-                    self._param_tensors[node_name] = tensor
-
-                    self._params[node_name] = _expr.var(node_name,
-                                                        shape=shape,
-                                                        dtype=_convert_data_type(str(value.dtype)))
-
-    def _parse_ops(self):
-        """ Iterate through nodes and decorate graph with constants, operators,
-        and the inputs to each operator. """
-        # Traverse nodes and add to graph
-        for node in self._graph.nodes():
-
-            if node.outputsSize() == 1:
-                node_name = node.output().debugName()
-            else:
-                node_name = [output.debugName() for output in node.outputs()][0]
-
-            if node.kind() == "prim::Constant":
-                if node.hasAttributes():
-                    attribute_names = node.attributeNames()
-                    num_attributes = len(attribute_names)
-
-                    if num_attributes == 1:
-                        attr_name = attribute_names[0]
-                        ty = node.output().type().kind()
-
-                        if ty in ["IntType", "BoolType"]:
-                            self._consts[node_name] = node.i(attr_name)
-                        elif ty in ["FloatType", "LongType"]:
-                            self._consts[node_name] = node.f(attr_name)
-                        elif ty in ["TensorType", "CompleteTensorType"]:
-                            self._consts[node_name] = node.output().toIValue()
-                        else:
-                            self._consts[node_name] = "0"
-                    else:
-                        assert num_attributes == 0
-                else:
-                    self._consts[node_name] = "0"
-            elif node.kind() == "prim::ListConstruct":
-                list_shape = []
-                for input_node in node.inputs():
-                    if input_node.debugName() in self._inputs_r.keys():
-                        c = self._inputs_r[input_node.debugName()]
-                        assert isinstance(c, int)
-                        list_shape.append(c)
-                    elif input_node.debugName() in self._consts.keys():
-                        c = self._consts[input_node.debugName()]
-                        assert isinstance(c, int)
-                        list_shape.append(c)
-                self._inputs_r[node_name] = _expr.var(node_name, shape=list_shape)
-
-            if node.kind() != "prim::GetAttr":
-                self._add_op(node_name, node)
-
-    # Graph Helper Functions
-    def _add_op(self, node_id, op_node):
-        """ Add an operator and its operators inputs to the graph and insert placeholders
-        where an input is a call node.
-
-        Parameters
-        ----------
-        node_id : string
-            The ID of the op node
-
-        op_node : PyTorch Node object
-            The full Node object for the op node
-
-        """
-        self._ops[(node_id)] = op_node
-        input_list_r = []
-        input_list_types = []
-        for input_value in op_node.inputs():
-
-            inode_id = input_value.debugName()
-            inode = input_value.node()
-
-            if inode_id in self._inputs_r.keys():
-                input_list_r.append(self._inputs_r[inode_id])
-            elif inode_id in self._params.keys():
-                input_list_r.append(self._params[inode_id])
-            elif inode.kind() == "prim::Constant":
-                input_list_r.append(self._consts[inode_id])
-            else:
-                input_list_r.append("call/var."+inode_id)
-
-                # If the inputs of a ListConstruct op is a call or var, remove it from inputs
-                if op_node.kind() == "prim::ListConstruct":
-                    if node_id in self._inputs_r.keys():
-                        self._inputs_r.pop(node_id)
-
-            try:
-                input_value_kind = input_value.type().kind()
-                if input_value_kind in ["TensorType", "CompleteTensorType"]:
-                    if input_value.type().scalarType() is None:
-                        input_list_types.append("float")
-                    else:
-                        input_list_types.append(input_value.type().scalarType().lower())
-                elif input_value_kind == "ListType":
-                    input_list_types.append(str(input_value.type().getElementType()).lower())
-                elif input_value_kind in ["IntType", "FloatType", "BoolType", "StringType",
-                                          "OptionalType"]:
-                    input_list_types.append(str(input_value.type()).lower())
-                else:
-                    input_list_types.append("UnsupportedType")
-                    print("UnsupportedType "+str(input_value.type())+" and "+str(input_value_kind))
-            except Exception as e:
-                print("Internal PyTorch error. Failed to grab type.")
-
-        if op_node.kind() in ["aten::ones", "aten::zeros"]:
-            node_type = op_node.output().type().scalarType()
-            input_list_types[0] = node_type.lower()
-
-        self._op_inputs_r[node_id] = input_list_r
-        self._op_inputs_types[node_id] = input_list_types
-
-    def _parse_import_prerequisites(self):
-        """ Calculate the named preconditions from PyTorch graph.
-
-        Returns
-        -------
-        missing_operators : set object
-            Set of operator names which don't have their mapping in TVM
-            i.e. which are not supported
-
-        """
-        missing_operators = set()
-        for node in self._graph.nodes():
-            if not node.kind() in ["prim::Constant", "prim::ListConstruct", "prim::GetAttr"] \
-            and not node.kind() in _convert_map:
-                missing_operators.add(node.kind())
-
-        return missing_operators
+def _run_jit_passes(graph):
+    """ The inline pass is necessary to unwrap prim::CallMethod """
+    import torch
+    if version.parse(torch.__version__) >= version.parse("1.4.0"):
+        torch._C._jit_pass_inline(graph)
+
+
+def _is_int_seq(seq):
+    return len(seq) > 0 and all([isinstance(i, int) for i in seq])
+
+
+def _get_tensor_and_var(torch_tensor, name):
+    tensor = tvm.nd.array(torch_tensor.cpu().numpy())
+    var = _expr.var(name, shape=tensor.shape)
+    return tensor, var
+
+
+def _get_output_name(node):
+    assert node.outputsSize() == 1
+    return node.output().debugName()
+
+
+def _get_output_names(node):
+    return [output.debugName() for output in node.outputs()]
+
+
+def _get_input_names(node_or_graph):
+    return [inp.debugName() for inp in node_or_graph.inputs()]
+
+
+def _get_op_inputs(op_node, outputs, output_index_map):
+    input_names = [output_index_map[name]
+                   for name in _get_input_names(op_node)]
+    return [outputs[name] for name in input_names]
+
+
+def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
+    for output_name, output in name_output_pairs:
+        output_index_map[output_name] = len(outputs)
+        outputs.append(output)
+
+
+def _report_missing_conversion(op_names):
+    """ Check if all ops in an input graph are supported by TVM """
+    known_ops = ["prim::Constant", "prim::GetAttr",
+                 "prim::ListConstruct", "prim::ListUnpack",
+                 "prim::TupleConstruct", "prim::TupleUnpack"]
+    known_ops += list(_convert_map.keys())
+
+    missing = [op_name for op_name in op_names
+               if op_name not in known_ops]
+
+    if missing:
+        msg = "The following operators are not implemented: {}".format(missing)
+        raise NotImplementedError(msg)
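The helpers above thread a flat outputs list plus a name-to-index map through conversion. A self-contained sketch of that bookkeeping (the names and values here are illustrative, not from the source):

outputs = []            # converted values, in creation order
output_index_map = {}   # torch IR value name -> index into outputs

def record(name, value):
    output_index_map[name] = len(outputs)
    outputs.append(value)

record("input0", "some relay var")
record("7", 3.14)       # e.g. the value retrieved from a prim::Constant
assert outputs[output_index_map["7"]] == 3.14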
+
+def _getattr_attr_name(node):
+    attribute_names = node.attributeNames()
+    assert len(attribute_names) == 1
+    attr_name = node.s(attribute_names[0])
+    return attr_name
+
+
+def _getattr_full_name(getattrs):
+    return ".".join([_getattr_attr_name(node) for node in getattrs])
+
+
+def _get_input_types(op_node):
+    """ Returns a torch type for each input node """
+    input_list_types = []
+    for input_node in op_node.inputs():
+        in_ty = input_node.type()
+        input_node_kind = in_ty.kind()
+        if input_node_kind == 'TensorType':
+            if in_ty.scalarType() is None:
+                input_list_types.append(None)
+            else:
+                input_list_types.append(in_ty.scalarType().lower())
+        elif input_node_kind == 'ListType':
+            input_list_types.append(str(in_ty.getElementType()).lower())
+        elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
+                                 'StringType', 'OptionalType']:
+            input_list_types.append(str(in_ty).lower())
+        else:
+            input_list_types.append('UnsupportedType')
+
+    if op_node.kind() in ['aten::ones', 'aten::zeros']:
+        node_type = op_node.output().type()
+        scalar_type = node_type.scalarType()
+        if scalar_type:
+            input_list_types[0] = scalar_type.lower()
+
+    return input_list_types
""" Iterate through nodes and decorate graph with constants, operators,
and the inputs to each operator. """
# Traverse nodes and add to graph
for node in self._graph.nodes():
if node.outputsSize() == 1:
node_name = node.output().debugName()
else:
node_name = [output.debugName() for output in node.outputs()][0]
if node.kind() == "prim::Constant": def _get_constant(node):
if node.hasAttributes(): """ Retrieve a constant associated with this prim::Constant node """
attribute_names = node.attributeNames() attribute_names = node.attributeNames()
num_attributes = len(attribute_names)
if num_attributes == 1:
attr_name = attribute_names[0] attr_name = attribute_names[0]
ty = node.output().type().kind() ty = node.output().type().kind()
if ty in ["IntType", "BoolType"]: if ty in ["IntType", "BoolType"]:
self._consts[node_name] = node.i(attr_name) return node.i(attr_name)
elif ty in ["FloatType", "LongType"]: elif ty in ["FloatType", "LongType"]:
self._consts[node_name] = node.f(attr_name) return node.f(attr_name)
elif ty in ["TensorType", "CompleteTensorType"]: elif ty in ["TensorType", "CompleteTensorType"]:
self._consts[node_name] = node.output().toIValue() tensor = node.t(attr_name)
if len(tensor.shape) == 0: # tensor(0.1)
return float(tensor)
return tensor
elif ty == "DeviceObjType":
return node.s(attr_name)
elif ty == "FunctionType":
return None
else: else:
self._consts[node_name] = "0" raise NotImplementedError("Unsupported type: %s" % ty)
else:
assert num_attributes == 0
return None
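The zero-dim tensor branch above corresponds to PyTorch scalar tensors; a quick sketch of the behavior it relies on:

import torch

t = torch.tensor(0.1)
print(len(t.shape) == 0)  # True: tensor(0.1) is zero-dimensional
print(float(t))           # 0.1 -- stored as a plain Python float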
+
+def _get_operator_nodes(nodes):
+    """ Returns torch IR nodes that need conversion to Relay """
+    ops = {}
+    # Traverse nodes and add to graph
+    for node in nodes:
+        if node.outputsSize() > 1:
+            node_name = "_".join(_get_output_names(node))
+        else:
+            node_name = _get_output_name(node)
+
+        if node.kind() != "prim::GetAttr":
+            ops[node_name] = node
+
+    return ops
+
+
+def parse_inputs(graph_inputs, input_shapes):
+    """ Return Relay vars from torch input vars """
+    ir_inputs = list(graph_inputs)
+    input_vars = {}
+
+    for input_name, ir_input in zip(input_shapes, ir_inputs[1:]):
+        input_vars[input_name] = _expr.var(input_name,
+                                           shape=input_shapes[input_name])
+    return input_vars
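parse_inputs skips ir_inputs[0] because the first graph input of a TorchScript method is the module itself. A hedged sketch (the exact debug names can vary across PyTorch versions):

import torch

class Identity(torch.nn.Module):
    def forward(self, x):
        return x

trace = torch.jit.trace(Identity(), torch.rand(2))
print([i.debugName() for i in trace.graph.inputs()])
# e.g. ['self', 'x'] -- index 0 is the module, real inputs start at index 1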
+
+def get_use_chains(root_node, terminate=lambda _: False):
+    """
+    Track a chain of users of this node forward, returning a list of chains
+    See get_attr_chains below for its usage
+    """
+    def concat_lists(lists):
+        return itertools.chain.from_iterable(lists)
+
+    def inner(current, accum):
+        users = []
+        for output in current.outputs():
+            users += [use.user for use in output.uses()]
+
+        if not users or terminate(users):
+            return [accum]
+
+        return concat_lists([inner(nxt, accum + [nxt]) for nxt in users])
+
+    return inner(root_node, [root_node])
+
+
+def get_attr_chains(root_getattr_node):
+    """ Returns chains of attribute access starting from root_getattr_node
+
+    For example, given attribute "block", as in "self.block" when "self" points
+    to the top level torch.nn.Module, it returns lists of attribute "chains",
+    e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']
+
+    These sets of attributes form full attribute accessors. For example,
+    "self.block.1", "self.block.2" will return the second and third submodule,
+    and "self.block.0._packed_params" will return the parameters of the first
+    submodule.
+    """
+    def terminate(users):
+        next_attrs = [user for user in users if user.kind() == "prim::GetAttr"]
+        return len(next_attrs) == 0
+
+    return get_use_chains(root_getattr_node, terminate)
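A small sketch of the attribute chains this recovers, shown through the state_dict keys of a nested module (plain PyTorch, nothing frontend-specific):

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.block = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.ReLU())

    def forward(self, x):
        return self.block(x)

trace = torch.jit.trace(Net(), torch.rand(1, 3))
print(list(trace.state_dict().keys()))
# ['block.0.weight', 'block.0.bias'] -- the full names that
# _getattr_full_name rebuilds from chained prim::GetAttr nodes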
+
+def parse_params(graph, state_dict):
+    """
+    Return Relay vars and TVM NDArrays for input parameters
+    A chain of prim::GetAttr nodes is processed one at a time
+    """
+    getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
+    params = {}
+    param_tensors = {}
+    seen = set()
+
+    for node in getattr_nodes:
+        if _get_output_name(node) in seen:
+            continue
+
+        for getattrs in get_attr_chains(node):
+            seen.update(map(_get_output_name, getattrs))
+
+            full_attr = _getattr_full_name(getattrs)
+            full_attr_node_name = _get_output_name(getattrs[-1])
+
+            if full_attr in state_dict:
+                torch_tensor = state_dict[full_attr]
+                tensor, var = _get_tensor_and_var(torch_tensor,
+                                                  full_attr_node_name)
+                param_tensors[full_attr_node_name] = tensor
+                params[full_attr_node_name] = var
+
+    return params, param_tensors
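The tensor/var pairing produced per matched parameter follows _get_tensor_and_var; a minimal sketch using only public APIs (the parameter name is made up):

import torch
import tvm
from tvm import relay

torch_tensor = torch.rand(4, 3)
tensor = tvm.nd.array(torch_tensor.cpu().numpy())
var = relay.var("block.0.weight", shape=tensor.shape)
# parse_params stores these as param_tensors[name] and params[name]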
+
+def parse_operators(operators, outputs, output_index_map, ret_name):
+    """ Convert each Torch IR operator to its Relay equivalent """
+    for node_name, op_node in operators.items():
+        operator = op_node.kind()
+        inputs = _get_op_inputs(op_node, outputs, output_index_map)
+
+        if operator == "prim::Constant":
+            output_index_map[node_name] = len(outputs)
+            outputs.append(_get_constant(op_node))
+        elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
+            output_index_map[node_name] = len(outputs)
+            outputs.append(_expr.var(node_name, shape=inputs))
+        elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
+            output_index_map[node_name] = len(outputs)
+            outputs.append(inputs)
+        elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
+            assert len(inputs) == 1
+            unpacked_names = _get_output_names(op_node)
+            _update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
+                                       outputs, output_index_map)
+        else:
+            output_index_map[node_name] = len(outputs)
+            relay_op = _convert_map[operator]
+            outputs.append(relay_op(inputs, _get_input_types(op_node)))
+
+    return outputs[output_index_map[ret_name]]
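Whether a prim::ListConstruct becomes a shape variable or stays a plain Python list hinges on _is_int_seq; a standalone sketch of that predicate:

def _is_int_seq(seq):
    return len(seq) > 0 and all([isinstance(i, int) for i in seq])

print(_is_int_seq([1, 3, 224, 224]))  # True  -> lowered to _expr.var(name, shape=inputs)
print(_is_int_seq([]))                # False -> an empty list is not a shape
print(_is_int_seq([1, "x"]))          # False -> mixed inputs stay a plain list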
+
+def get_all_op_names(graph):
+    """ Return all operator names in the input graph """
+    nodes = list(graph.nodes())
+    return set(node.kind() for node in nodes)
+
+
+def get_graph_input_names(script_module):
+    """ Use this function to set the keys for input_shapes """
+    # It seems variable names could change the first time a copy is made
+    # Use the copy of the graph here to prevent troubles later
+    ir_inputs = _get_input_names(script_module.graph.copy())
+    return ir_inputs[1:]  # remove self at the 0th arg

 def from_pytorch(script_module, input_shapes):
     """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
@@ -1016,17 +1007,35 @@ def from_pytorch(script_module, input_shapes):
         TorchScripted PyTorch graph
         Note: We currently only support traces (ie: torch.jit.trace(model, input))

-    shape : Dictionary of input dimensions
+    input_shapes : Dictionary of input dimensions
         Graph level input shape dictionary
+        The keys should be the same ones returned by get_graph_input_names(...) above

     Returns
     -------
     mod : tvm.relay.Module
         The module that optimizations will be performed on.

-    params : dict of str to tvm.runtime
-        Dict of converted parameters stored in tvm.runtime format
+    params : dict of str to tvm.runtime.NDArray
+        Dict of converted parameters stored in tvm.runtime.ndarray format
     """
-    g = Graph(script_module, input_shapes)
-    mod, params = g.from_pytorch()
-    return mod, params
+    graph = script_module.graph.copy()
+    _run_jit_passes(graph)
+
+    op_names = get_all_op_names(graph)
+    _report_missing_conversion(op_names)
+
+    params = script_module.state_dict()
+    input_vars = parse_inputs(graph.inputs(), input_shapes)
+    param_vars, tensors = parse_params(graph, params)
+
+    input_vars.update(param_vars)
+    outputs = list(input_vars.values())
+    output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
+    ret_name = _get_input_names(graph.return_node())[0]
+
+    body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
+                           output_index_map, ret_name)
+    func = tvm.relay.Function(_analysis.free_vars(body), body)
+    tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
+
+    return _module.IRModule.from_expr(func), tvm_params
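Putting the new API together, a hedged end-to-end sketch (assumes torchvision is installed; this mirrors what verify_model in the tests below does):

import torch
import torchvision
from tvm import relay
from tvm.relay.frontend.pytorch import get_graph_input_names

model = torchvision.models.resnet18(pretrained=True).eval()
trace = torch.jit.trace(model, torch.rand(1, 3, 224, 224))

input_name = get_graph_input_names(trace)[0]  # only one input
input_shapes = {input_name: (1, 3, 224, 224)}
mod, params = relay.frontend.from_pytorch(trace, input_shapes)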
tests/python/frontend/pytorch/test_forward.py
@@ -31,6 +31,8 @@ import torchvision
 from tvm import relay
 from tvm.contrib import graph_runtime
 from tvm.relay.testing.config import ctx_list
+from tvm.relay.frontend.pytorch import get_graph_input_names
+
 sys.setrecursionlimit(10000)
@@ -94,6 +96,7 @@ def load_model(model_name):
     if hasattr(torchvision.models, model_name):
         return load_torchvision(model_name)
     try:
+        import pretrainedmodels
         if hasattr(pretrainedmodels, model_name):
             return load_pretrainedmodels(model_name)
     except ModuleNotFoundError:
@@ -167,16 +170,15 @@ def verify_model(model_name, input_data=[]):
         baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
     else:
         baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
-    output_shapes = [out.shape for out in baseline_outputs]
-    dtype = "float32"
-    input_name = "input0"
-    input_shapes = {input_name: list(baseline_input.shape)}
     trace = torch.jit.trace(baseline_model, baseline_input).float().eval()
+
     if torch.cuda.is_available():
         trace = trace.cuda()
     else:
         trace = trace.cpu()
+
+    input_name = get_graph_input_names(trace)[0]  # only one input
+    input_shapes = {input_name: list(baseline_input.shape)}
     mod, params = relay.frontend.from_pytorch(trace, input_shapes)
     compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())}
@@ -276,7 +278,7 @@ def test_forward_multiply():

     class Multiply2(Module):
         def forward(self, *args):
-            return args[0] * 1
+            return args[0] * 1.0

     class Multiply3(Module):
         def forward(self, *args):
@@ -507,7 +509,7 @@ def test_forward_size():

     class Size1(Module):
         def forward(self, *args):
-            return args[0].size(0) * args[0]
+            return float(args[0].size(0)) * args[0]

     with torch.no_grad():
         input_data = torch.rand(input_shape).float()
@@ -708,6 +710,10 @@
     torch.set_grad_enabled(False)
     verify_model("mnasnet0_5")

+def test_mobilenet_v2():
+    torch.set_grad_enabled(False)
+    verify_model("mobilenet_v2")
+
 """
 #TODO: Fix VGG and AlexNet issues (probably due to pooling)
 def test_alexnet():
@@ -721,13 +727,9 @@ def test_vgg11():

 def test_vgg11_bn():
     torch.set_grad_enabled(False)
     verify_model("vgg11_bn")

-#TODO: Need to update schedule in tophub file after PR #4787 updated workloads
-def test_mobilenet_v2():
-    torch.set_grad_enabled(False)
-    verify_model("mobilenet_v2")
 """
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -767,3 +769,4 @@
     test_inception_v3()
     test_googlenet()
     test_mnasnet0_5()
+    test_mobilenet_v2()
tutorials/frontend/from_pytorch.py
@@ -41,14 +41,13 @@ Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may
 be unstable.
 """

-# tvm, relay
 import tvm
 from tvm import relay

-# numpy, packaging
 import numpy as np
-from packaging import version

 from tvm.contrib.download import download_testdata
+from tvm.relay.frontend.pytorch import get_graph_input_names

 # PyTorch imports
 import torch
@@ -91,7 +90,8 @@ img = np.expand_dims(img, 0)
 # Import the graph to Relay
 # -------------------------
 # Convert PyTorch graph to Relay graph.
-shape_dict = {'img': img.shape}
+input_name = get_graph_input_names(scripted_model)[0]  # only one input
+shape_dict = {input_name: img.shape}
 mod, params = relay.frontend.from_pytorch(scripted_model,
                                           shape_dict)
@@ -116,12 +116,12 @@ from tvm.contrib import graph_runtime
 dtype = 'float32'
 m = graph_runtime.create(graph, lib, ctx)
 # Set inputs
-m.set_input('img', tvm.nd.array(img.astype(dtype)))
+m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
 m.set_input(**params)
 # Execute
 m.run()
 # Get outputs
-tvm_output = m.get_output(0, tvm.nd.empty(((1, 1000)), 'float32'))
+tvm_output = m.get_output(0)
 #####################################################################
 # Look up synset name
 ...