Unverified Commit 03cbf78e by Jeremy Johnson Committed by GitHub

[Frontend][Torch] Fix up graph input handling (#5204)

* [Frontend][Torch] Simplify operator input handling

* [Frontend][Torch] Allow user supplied input names to override graph inputs

* Fix pylint issues

* Updates from code review feedback

* Fix tutorial to use shape list input

* Disable intermittent test failure in topi vision test
parent 15b1751c
...@@ -1071,16 +1071,8 @@ def _get_input_names(node_or_graph): ...@@ -1071,16 +1071,8 @@ def _get_input_names(node_or_graph):
return [inp.debugName() for inp in node_or_graph.inputs()] return [inp.debugName() for inp in node_or_graph.inputs()]
def _get_op_inputs(op_node, outputs, output_index_map): def _get_op_inputs(op_node, outputs):
input_names = [output_index_map[name] return [outputs[name] for name in _get_input_names(op_node)]
for name in _get_input_names(op_node)]
return [outputs[name] for name in input_names]
def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
for output_name, output in name_output_pairs:
output_index_map[output_name] = len(outputs)
outputs.append(output)
def _report_missing_conversion(op_names): def _report_missing_conversion(op_names):
...@@ -1100,18 +1092,31 @@ def _report_missing_conversion(op_names): ...@@ -1100,18 +1092,31 @@ def _report_missing_conversion(op_names):
raise NotImplementedError(msg) raise NotImplementedError(msg)
def _check_input_names(script_module, input_shapes): def _check_inputs(graph, input_shapes):
""" Check the graph inputs match the inputs """ """
ir_inputs = get_graph_input_names(script_module) Check the graph inputs match the expected number of inputs
and are in the correct format
"""
ir_inputs = _get_graph_input_names(graph)
for ir_input in ir_inputs: if not isinstance(input_shapes, list):
if ir_input not in input_shapes: msg = "Graph inputs input_shapes should be list"
msg = "Missing graph input {} in input_shapes".format(ir_input) raise RuntimeError(msg)
missing_inputs = len(ir_inputs) - len(input_shapes)
if missing_inputs > 0:
msg = "Missing {} graph input(s) in input_shapes".format(missing_inputs)
raise RuntimeError(msg) raise RuntimeError(msg)
for input_name in input_shapes: for num, inp in enumerate(input_shapes):
if input_name not in ir_inputs: if num < len(ir_inputs):
msg = "Unused graph input {} in input_shapes".format(input_name) if not isinstance(inp, tuple):
msg = "Graph input {} is not a tuple".format(num)
raise RuntimeError(msg)
if (len(inp) != 2 or not isinstance(inp[0], str)):
msg = "Graph input {} is not valid, expected ('name', shape)".format(inp)
raise RuntimeError(msg)
else:
msg = "Unused graph input {} in input_shapes".format(inp)
logging.warning(msg) logging.warning(msg)
...@@ -1203,10 +1208,19 @@ def _get_operator_nodes(nodes): ...@@ -1203,10 +1208,19 @@ def _get_operator_nodes(nodes):
return ops return ops
def _get_relay_input_vars(input_shapes): def _get_relay_input_vars(graph, input_shapes):
""" Return Relay vars from input shapes """ """
return {iname: _expr.var(iname, shape=ishape) Return Relay vars from input shapes and create entries based on
for iname, ishape in input_shapes.items()} expected graph inputs - to allow translation
"""
input_vars = {}
ir_inputs = _get_graph_input_names(graph)
for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
inp = _expr.var(name, shape=shape)
# Translate from graph input to user input name
input_vars[ir_input] = inp
return input_vars
def get_use_chains(root_node, terminate=lambda _: False): def get_use_chains(root_node, terminate=lambda _: False):
...@@ -1284,24 +1298,24 @@ def convert_params(graph, state_dict): ...@@ -1284,24 +1298,24 @@ def convert_params(graph, state_dict):
return params, param_tensors, packed_param_map return params, param_tensors, packed_param_map
def convert_block(block, outputs, output_index_map): def convert_block(block, outputs):
""" Translate Torch "Block", used for prim::If and prim::Loop """ """ Translate Torch "Block", used for prim::If and prim::Loop """
ops = _get_operator_nodes(block.nodes()) ops = _get_operator_nodes(block.nodes())
ret_names = _get_input_names(block.returnNode()) ret_names = _get_input_names(block.returnNode())
return convert_operators(ops, outputs, output_index_map, ret_names) return convert_operators(ops, outputs, ret_names)
def convert_if(if_node, outputs, output_index_map): def convert_if(if_node, outputs):
""" Translate Torch prim::If to Relay If """ """ Translate Torch prim::If to Relay If """
cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]] cond = outputs[if_node.inputsAt(0).debugName()]
blocks = list(if_node.blocks()) blocks = list(if_node.blocks())
true_branch = convert_block(blocks[0], outputs, output_index_map) true_branch = convert_block(blocks[0], outputs)
false_branch = convert_block(blocks[1], outputs, output_index_map) false_branch = convert_block(blocks[1], outputs)
assert len(true_branch) == 1 and len(false_branch) == 1 assert len(true_branch) == 1 and len(false_branch) == 1
return _expr.If(cond, true_branch[0], false_branch[0]) return _expr.If(cond, true_branch[0], false_branch[0])
def convert_loop(loop_node, outputs, output_index_map): def convert_loop(loop_node, outputs):
""" Translate Torch prim::Loop to Relay while_loop """ """ Translate Torch prim::Loop to Relay while_loop """
def get_input(index): def get_input(index):
ivalue = loop_node.inputsAt(index) ivalue = loop_node.inputsAt(index)
...@@ -1309,8 +1323,8 @@ def convert_loop(loop_node, outputs, output_index_map): ...@@ -1309,8 +1323,8 @@ def convert_loop(loop_node, outputs, output_index_map):
if inode.kind() == "prim::Constant": if inode.kind() == "prim::Constant":
return _expr.const(_get_constant(inode)) return _expr.const(_get_constant(inode))
var_name = ivalue.debugName() var_name = ivalue.debugName()
assert var_name in output_index_map assert var_name in outputs
return _wrap_const(outputs[output_index_map[var_name]]) return _wrap_const(outputs[var_name])
# Refer to the spec for prim::Loop below # Refer to the spec for prim::Loop below
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
...@@ -1342,9 +1356,9 @@ def convert_loop(loop_node, outputs, output_index_map): ...@@ -1342,9 +1356,9 @@ def convert_loop(loop_node, outputs, output_index_map):
# Update loop variables using the prev iteration outputs # Update loop variables using the prev iteration outputs
assert len(current_vals) == len(block_input_names) assert len(current_vals) == len(block_input_names)
for (i, iname) in enumerate(block_input_names): for (i, iname) in enumerate(block_input_names):
outputs[output_index_map[iname]] = current_vals[i] outputs[iname] = current_vals[i]
block_outputs = convert_block(body_block, outputs, output_index_map) block_outputs = convert_block(body_block, outputs)
if not is_while_loop: if not is_while_loop:
# iter var increment implicit in torch, so do it manually # iter var increment implicit in torch, so do it manually
...@@ -1374,7 +1388,7 @@ def convert_loop(loop_node, outputs, output_index_map): ...@@ -1374,7 +1388,7 @@ def convert_loop(loop_node, outputs, output_index_map):
name_val_pairs = list(zip(block_input_names, name_val_pairs = list(zip(block_input_names,
[init_loop_iter_val] + init_vals)) [init_loop_iter_val] + init_vals))
_update_outputs_from_pairs(name_val_pairs, outputs, output_index_map) outputs.update(name_val_pairs)
loop_iter_var = _expr.var(block_input_names[0], shape=(), loop_iter_var = _expr.var(block_input_names[0], shape=(),
dtype=loop_iter_dtype) dtype=loop_iter_dtype)
...@@ -1386,36 +1400,30 @@ def convert_loop(loop_node, outputs, output_index_map): ...@@ -1386,36 +1400,30 @@ def convert_loop(loop_node, outputs, output_index_map):
return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)] return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
def convert_operators(operators, outputs, output_index_map, ret_names): def convert_operators(operators, outputs, ret_names):
""" Convert each Torch IR operators to Relay equivalent """ """ Convert each Torch IR operators to Relay equivalent """
for node_name, op_node in operators: for node_name, op_node in operators:
operator = op_node.kind() operator = op_node.kind()
inputs = _get_op_inputs(op_node, outputs, output_index_map) inputs = _get_op_inputs(op_node, outputs)
if operator == "prim::Constant": if operator == "prim::Constant":
output_index_map[node_name] = len(outputs) outputs[node_name] = _get_constant(op_node)
outputs.append(_get_constant(op_node))
elif operator == 'prim::ListConstruct' and _is_int_seq(inputs): elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
output_index_map[node_name] = len(outputs) outputs[node_name] = _expr.var(node_name, shape=inputs)
outputs.append(_expr.var(node_name, shape=inputs))
elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']: elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
output_index_map[node_name] = len(outputs) outputs[node_name] = inputs
outputs.append(inputs)
elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']: elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
assert len(inputs) == 1 assert len(inputs) == 1
unpacked_names = _get_output_names(op_node) unpacked_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(unpacked_names, inputs[0]), outputs.update(zip(unpacked_names, inputs[0]))
outputs, output_index_map)
elif operator == "prim::If": elif operator == "prim::If":
if_out = convert_if(op_node, outputs, output_index_map) if_out = convert_if(op_node, outputs)
output_index_map[node_name] = len(outputs) outputs[node_name] = if_out
outputs.append(if_out)
elif operator == "prim::Loop": elif operator == "prim::Loop":
loop_out = convert_loop(op_node, outputs, output_index_map) loop_out = convert_loop(op_node, outputs)
unpacked_names = _get_output_names(op_node) unpacked_names = _get_output_names(op_node)
assert len(loop_out) == len(unpacked_names) assert len(loop_out) == len(unpacked_names)
_update_outputs_from_pairs(zip(unpacked_names, loop_out), outputs.update(zip(unpacked_names, loop_out))
outputs, output_index_map)
else: else:
relay_op = _convert_map[operator] relay_op = _convert_map[operator]
relay_out = relay_op(inputs, _get_input_types(op_node)) relay_out = relay_op(inputs, _get_input_types(op_node))
...@@ -1424,13 +1432,11 @@ def convert_operators(operators, outputs, output_index_map, ret_names): ...@@ -1424,13 +1432,11 @@ def convert_operators(operators, outputs, output_index_map, ret_names):
# This is for torch operators that return multiple outputs # This is for torch operators that return multiple outputs
# See _adaptive_max_2d above for example # See _adaptive_max_2d above for example
out_names = _get_output_names(op_node) out_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(out_names, relay_out), outputs.update(zip(out_names, relay_out))
outputs, output_index_map)
else: else:
output_index_map[node_name] = len(outputs) outputs[node_name] = relay_out
outputs.append(relay_out)
return [_wrap_const(outputs[output_index_map[ret_name]]) return [_wrap_const(outputs[ret_name])
for ret_name in ret_names] for ret_name in ret_names]
...@@ -1446,11 +1452,11 @@ def get_all_op_names(graph): ...@@ -1446,11 +1452,11 @@ def get_all_op_names(graph):
return set(node.kind() for node in nodes) return set(node.kind() for node in nodes)
def get_graph_input_names(script_module): def _get_graph_input_names(graph):
""" Use this function to set the keys for input_shapes""" """ Get the graph input names (use after graph copy and run jit passes) """
# It seems variable names could change the first time a copy is made # Variable names could change the first time a copy is made and after
# Use the copy of the graph here to prevent troubles later # _run_jit_passes is called, expected that those functions already invoked
ir_inputs = _get_input_names(script_module.graph.copy()) ir_inputs = _get_input_names(graph)
return ir_inputs[1:] # remove self at the 0th arg return ir_inputs[1:] # remove self at the 0th arg
...@@ -1464,9 +1470,10 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): ...@@ -1464,9 +1470,10 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
TorchScripted PyTorch graph TorchScripted PyTorch graph
Note: We currently only support traces (ie: torch.jit.trace(model, input)) Note: We currently only support traces (ie: torch.jit.trace(model, input))
input_shapes : Dictionary of input dimensions input_shapes : List of tuples of input name and input dimensions
Graph level input shape dictionary Graph level input shape list
The keys should be the same one returned by get_graph_input_names(...) above The same input names need to be used for deployment, so choose easy to
remember names (such as: input0, input1)
custom_convert_map: Dictionary of str to Relay op custom_convert_map: Dictionary of str to Relay op
A custom op conversion map in the same format as _convert_map above A custom op conversion map in the same format as _convert_map above
...@@ -1487,30 +1494,28 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): ...@@ -1487,30 +1494,28 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
op_names = get_all_op_names(graph) op_names = get_all_op_names(graph)
_report_missing_conversion(op_names) _report_missing_conversion(op_names)
_check_input_names(script_module, input_shapes) _check_inputs(graph, input_shapes)
params = script_module.state_dict() params = script_module.state_dict()
input_vars = _get_relay_input_vars(input_shapes) outputs = _get_relay_input_vars(graph, input_shapes)
param_vars, tensors, packed_param_map = convert_params(graph, params) param_vars, tensors, packed_param_map = convert_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
input_vars.update(param_vars) outputs.update(param_vars)
outputs = list(input_vars.values())
output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
ret_name = _get_input_names(graph.return_node()) ret_name = _get_input_names(graph.return_node())
# For quantized models # For quantized models
if "aten::quantize_per_tensor" in op_names: if "aten::quantize_per_tensor" in op_names:
weight_quant_params = qnn_torch.get_weight_quant_params(script_module) weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
qnn_torch.add_input_quant_params_to_op_inputs(graph) qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(outputs, output_index_map, qnn_torch.add_quant_params_to_outputs(outputs,
packed_param_map, packed_param_map,
weight_quant_params) weight_quant_params)
qnn_torch.add_quant_params(tvm_params, weight_quant_params) qnn_torch.add_quant_params(tvm_params, weight_quant_params)
_convert_map.update(qnn_torch.convert_map) _convert_map.update(qnn_torch.convert_map)
ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret = convert_operators(_get_operator_nodes(graph.nodes()),
output_index_map, ret_name) outputs, ret_name)
if isinstance(ret[0], list): if isinstance(ret[0], list):
ret[0] = _expr.Tuple(ret[0]) ret[0] = _expr.Tuple(ret[0])
......
...@@ -101,20 +101,19 @@ def get_weight_quant_params(script_module): ...@@ -101,20 +101,19 @@ def get_weight_quant_params(script_module):
return quant_params return quant_params
def add_quant_params_to_outputs(outputs, output_index_map, def add_quant_params_to_outputs(outputs, packed_param_map,
packed_param_map, quant_params): quant_params):
""" """
Add quant params to outputs so that they can be referenced by other Add quant params to outputs so that they can be referenced by other
ops later. Weights are quantized here. ops later. Weights are quantized here.
""" """
for node_name, packed_param_name in packed_param_map.items(): for node_name, packed_param_name in packed_param_map.items():
qparam = quant_params[packed_param_name] qparam = quant_params[packed_param_name]
output_index_map[node_name] = len(outputs)
qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale, qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale,
qparam.zero_point, out_dtype="int8", qparam.zero_point, out_dtype="int8",
axis=0) axis=0)
param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var) param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
outputs.append(param_tup) outputs[node_name] = param_tup
def _get_quant_param_for_input(input_value): def _get_quant_param_for_input(input_value):
......
...@@ -28,7 +28,6 @@ from torch.quantization import fuse_modules, QuantWrapper ...@@ -28,7 +28,6 @@ from torch.quantization import fuse_modules, QuantWrapper
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.frontend.pytorch import get_graph_input_names
from tvm.contrib.download import download_testdata from tvm.contrib.download import download_testdata
...@@ -39,7 +38,7 @@ def torch_version_check(): ...@@ -39,7 +38,7 @@ def torch_version_check():
def get_tvm_runtime(script_module, input_name, ishape): def get_tvm_runtime(script_module, input_name, ishape):
input_shapes = {input_name: ishape} input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes) mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
...@@ -287,7 +286,7 @@ def test_quantized_modules(): ...@@ -287,7 +286,7 @@ def test_quantized_modules():
with torch.no_grad(): with torch.no_grad():
pt_result = script_module(inp.clone()).numpy() pt_result = script_module(inp.clone()).numpy()
input_name = get_graph_input_names(script_module)[0] input_name = "input"
runtime = get_tvm_runtime(script_module, input_name, ishape) runtime = get_tvm_runtime(script_module, input_name, ishape)
runtime.set_input(input_name, inp.numpy().copy()) runtime.set_input(input_name, inp.numpy().copy())
runtime.run() runtime.run()
...@@ -383,7 +382,7 @@ def test_quantized_imagenet(): ...@@ -383,7 +382,7 @@ def test_quantized_imagenet():
with torch.no_grad(): with torch.no_grad():
pt_result = script_module(pt_inp).numpy() pt_result = script_module(pt_inp).numpy()
input_name = get_graph_input_names(script_module)[0] input_name = "image"
runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224)) runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224))
runtime.set_input(input_name, inp) runtime.set_input(input_name, inp)
runtime.run() runtime.run()
......
...@@ -28,7 +28,6 @@ import torchvision ...@@ -28,7 +28,6 @@ import torchvision
from tvm import relay from tvm import relay
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list from tvm.relay.testing.config import ctx_list
from tvm.relay.frontend.pytorch import get_graph_input_names
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
...@@ -169,8 +168,8 @@ def verify_model(model_name, input_data=[], ...@@ -169,8 +168,8 @@ def verify_model(model_name, input_data=[],
else: else:
trace = trace.cpu() trace = trace.cpu()
input_names = get_graph_input_names(trace) input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
input_shapes = dict(zip(input_names, input_shapes = list(zip(input_names,
[inp.shape for inp in baseline_input])) [inp.shape for inp in baseline_input]))
mod, params = relay.frontend.from_pytorch(trace, input_shapes, mod, params = relay.frontend.from_pytorch(trace, input_shapes,
custom_convert_map) custom_convert_map)
...@@ -888,11 +887,12 @@ def test_3d_models(): ...@@ -888,11 +887,12 @@ def test_3d_models():
def verify_script_model(pt_model, ishapes): def verify_script_model(pt_model, ishapes):
script_module = torch.jit.script(pt_model) script_module = torch.jit.script(pt_model)
input_names = get_graph_input_names(script_module)
input_shapes = dict(zip(input_names, ishapes))
inputs = [torch.randn(input_shapes[input_name], dtype=torch.float) input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
for input_name in input_names] input_shapes = list(zip(input_names, ishapes))
inputs = [torch.randn(shape, dtype=torch.float)
for shape in ishapes]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes) mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
......
...@@ -103,11 +103,14 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): ...@@ -103,11 +103,14 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
""" Skip this test as it is intermittent
see https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094
for device in ['llvm', 'cuda', 'opencl']: for device in ['llvm', 'cuda', 'opencl']:
# Disable opencl test for now # Disable opencl test for now
if device != "llvm" and device != "cuda": if device != "llvm" and device != "cuda":
continue continue
check_device(device) check_device(device)
"""
def test_get_valid_counts(): def test_get_valid_counts():
......
...@@ -47,7 +47,6 @@ from tvm import relay ...@@ -47,7 +47,6 @@ from tvm import relay
import numpy as np import numpy as np
from tvm.contrib.download import download_testdata from tvm.contrib.download import download_testdata
from tvm.relay.frontend.pytorch import get_graph_input_names
# PyTorch imports # PyTorch imports
import torch import torch
...@@ -90,10 +89,10 @@ img = np.expand_dims(img, 0) ...@@ -90,10 +89,10 @@ img = np.expand_dims(img, 0)
# Import the graph to Relay # Import the graph to Relay
# ------------------------- # -------------------------
# Convert PyTorch graph to Relay graph. # Convert PyTorch graph to Relay graph.
input_name = get_graph_input_names(scripted_model)[0] # only one input input_name = 'input0' # only one input, set it to this name
shape_dict = {input_name: img.shape} shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, mod, params = relay.frontend.from_pytorch(scripted_model,
shape_dict) shape_list)
###################################################################### ######################################################################
# Relay Build # Relay Build
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment