Unverified Commit 03cbf78e by Jeremy Johnson Committed by GitHub

[Frontend][Torch] Fix up graph input handling (#5204)

* [Frontend][Torch] Simplify operator input handling

* [Frontend][Torch] Allow user supplied input names to override graph inputs

* Fix pylint issues

* Updates from code review feedback

* Fix tutorial to use shape list input

* Disable intermittent test failure in topi vision test
parent 15b1751c
......@@ -1071,16 +1071,8 @@ def _get_input_names(node_or_graph):
return [inp.debugName() for inp in node_or_graph.inputs()]
def _get_op_inputs(op_node, outputs, output_index_map):
input_names = [output_index_map[name]
for name in _get_input_names(op_node)]
return [outputs[name] for name in input_names]
def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
for output_name, output in name_output_pairs:
output_index_map[output_name] = len(outputs)
outputs.append(output)
def _get_op_inputs(op_node, outputs):
return [outputs[name] for name in _get_input_names(op_node)]
def _report_missing_conversion(op_names):
......@@ -1100,18 +1092,31 @@ def _report_missing_conversion(op_names):
raise NotImplementedError(msg)
def _check_input_names(script_module, input_shapes):
""" Check the graph inputs match the inputs """
ir_inputs = get_graph_input_names(script_module)
def _check_inputs(graph, input_shapes):
"""
Check the graph inputs match the expected number of inputs
and are in the correct format
"""
ir_inputs = _get_graph_input_names(graph)
for ir_input in ir_inputs:
if ir_input not in input_shapes:
msg = "Missing graph input {} in input_shapes".format(ir_input)
if not isinstance(input_shapes, list):
msg = "Graph inputs input_shapes should be list"
raise RuntimeError(msg)
missing_inputs = len(ir_inputs) - len(input_shapes)
if missing_inputs > 0:
msg = "Missing {} graph input(s) in input_shapes".format(missing_inputs)
raise RuntimeError(msg)
for input_name in input_shapes:
if input_name not in ir_inputs:
msg = "Unused graph input {} in input_shapes".format(input_name)
for num, inp in enumerate(input_shapes):
if num < len(ir_inputs):
if not isinstance(inp, tuple):
msg = "Graph input {} is not a tuple".format(num)
raise RuntimeError(msg)
if (len(inp) != 2 or not isinstance(inp[0], str)):
msg = "Graph input {} is not valid, expected ('name', shape)".format(inp)
raise RuntimeError(msg)
else:
msg = "Unused graph input {} in input_shapes".format(inp)
logging.warning(msg)
......@@ -1203,10 +1208,19 @@ def _get_operator_nodes(nodes):
return ops
def _get_relay_input_vars(input_shapes):
""" Return Relay vars from input shapes """
return {iname: _expr.var(iname, shape=ishape)
for iname, ishape in input_shapes.items()}
def _get_relay_input_vars(graph, input_shapes):
"""
Return Relay vars from input shapes and create entries based on
expected graph inputs - to allow translation
"""
input_vars = {}
ir_inputs = _get_graph_input_names(graph)
for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
inp = _expr.var(name, shape=shape)
# Translate from graph input to user input name
input_vars[ir_input] = inp
return input_vars
def get_use_chains(root_node, terminate=lambda _: False):
......@@ -1284,24 +1298,24 @@ def convert_params(graph, state_dict):
return params, param_tensors, packed_param_map
def convert_block(block, outputs, output_index_map):
def convert_block(block, outputs):
""" Translate Torch "Block", used for prim::If and prim::Loop """
ops = _get_operator_nodes(block.nodes())
ret_names = _get_input_names(block.returnNode())
return convert_operators(ops, outputs, output_index_map, ret_names)
return convert_operators(ops, outputs, ret_names)
def convert_if(if_node, outputs, output_index_map):
def convert_if(if_node, outputs):
""" Translate Torch prim::If to Relay If """
cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]]
cond = outputs[if_node.inputsAt(0).debugName()]
blocks = list(if_node.blocks())
true_branch = convert_block(blocks[0], outputs, output_index_map)
false_branch = convert_block(blocks[1], outputs, output_index_map)
true_branch = convert_block(blocks[0], outputs)
false_branch = convert_block(blocks[1], outputs)
assert len(true_branch) == 1 and len(false_branch) == 1
return _expr.If(cond, true_branch[0], false_branch[0])
def convert_loop(loop_node, outputs, output_index_map):
def convert_loop(loop_node, outputs):
""" Translate Torch prim::Loop to Relay while_loop """
def get_input(index):
ivalue = loop_node.inputsAt(index)
......@@ -1309,8 +1323,8 @@ def convert_loop(loop_node, outputs, output_index_map):
if inode.kind() == "prim::Constant":
return _expr.const(_get_constant(inode))
var_name = ivalue.debugName()
assert var_name in output_index_map
return _wrap_const(outputs[output_index_map[var_name]])
assert var_name in outputs
return _wrap_const(outputs[var_name])
# Refer to the spec for prim::Loop below
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
......@@ -1342,9 +1356,9 @@ def convert_loop(loop_node, outputs, output_index_map):
# Update loop variables using the prev iteration outputs
assert len(current_vals) == len(block_input_names)
for (i, iname) in enumerate(block_input_names):
outputs[output_index_map[iname]] = current_vals[i]
outputs[iname] = current_vals[i]
block_outputs = convert_block(body_block, outputs, output_index_map)
block_outputs = convert_block(body_block, outputs)
if not is_while_loop:
# iter var increment implicit in torch, so do it manually
......@@ -1374,7 +1388,7 @@ def convert_loop(loop_node, outputs, output_index_map):
name_val_pairs = list(zip(block_input_names,
[init_loop_iter_val] + init_vals))
_update_outputs_from_pairs(name_val_pairs, outputs, output_index_map)
outputs.update(name_val_pairs)
loop_iter_var = _expr.var(block_input_names[0], shape=(),
dtype=loop_iter_dtype)
......@@ -1386,36 +1400,30 @@ def convert_loop(loop_node, outputs, output_index_map):
return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
def convert_operators(operators, outputs, output_index_map, ret_names):
def convert_operators(operators, outputs, ret_names):
""" Convert each Torch IR operators to Relay equivalent """
for node_name, op_node in operators:
operator = op_node.kind()
inputs = _get_op_inputs(op_node, outputs, output_index_map)
inputs = _get_op_inputs(op_node, outputs)
if operator == "prim::Constant":
output_index_map[node_name] = len(outputs)
outputs.append(_get_constant(op_node))
outputs[node_name] = _get_constant(op_node)
elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
output_index_map[node_name] = len(outputs)
outputs.append(_expr.var(node_name, shape=inputs))
outputs[node_name] = _expr.var(node_name, shape=inputs)
elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
output_index_map[node_name] = len(outputs)
outputs.append(inputs)
outputs[node_name] = inputs
elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
assert len(inputs) == 1
unpacked_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
outputs, output_index_map)
outputs.update(zip(unpacked_names, inputs[0]))
elif operator == "prim::If":
if_out = convert_if(op_node, outputs, output_index_map)
output_index_map[node_name] = len(outputs)
outputs.append(if_out)
if_out = convert_if(op_node, outputs)
outputs[node_name] = if_out
elif operator == "prim::Loop":
loop_out = convert_loop(op_node, outputs, output_index_map)
loop_out = convert_loop(op_node, outputs)
unpacked_names = _get_output_names(op_node)
assert len(loop_out) == len(unpacked_names)
_update_outputs_from_pairs(zip(unpacked_names, loop_out),
outputs, output_index_map)
outputs.update(zip(unpacked_names, loop_out))
else:
relay_op = _convert_map[operator]
relay_out = relay_op(inputs, _get_input_types(op_node))
......@@ -1424,13 +1432,11 @@ def convert_operators(operators, outputs, output_index_map, ret_names):
# This is for torch operators that return multiple outputs
# See _adaptive_max_2d above for example
out_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(out_names, relay_out),
outputs, output_index_map)
outputs.update(zip(out_names, relay_out))
else:
output_index_map[node_name] = len(outputs)
outputs.append(relay_out)
outputs[node_name] = relay_out
return [_wrap_const(outputs[output_index_map[ret_name]])
return [_wrap_const(outputs[ret_name])
for ret_name in ret_names]
......@@ -1446,11 +1452,11 @@ def get_all_op_names(graph):
return set(node.kind() for node in nodes)
def get_graph_input_names(script_module):
""" Use this function to set the keys for input_shapes"""
# It seems variable names could change the first time a copy is made
# Use the copy of the graph here to prevent troubles later
ir_inputs = _get_input_names(script_module.graph.copy())
def _get_graph_input_names(graph):
""" Get the graph input names (use after graph copy and run jit passes) """
# Variable names could change the first time a copy is made and after
# _run_jit_passes is called, expected that those functions already invoked
ir_inputs = _get_input_names(graph)
return ir_inputs[1:] # remove self at the 0th arg
......@@ -1464,9 +1470,10 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
TorchScripted PyTorch graph
Note: We currently only support traces (ie: torch.jit.trace(model, input))
input_shapes : Dictionary of input dimensions
Graph level input shape dictionary
The keys should be the same one returned by get_graph_input_names(...) above
input_shapes : List of tuples of input name and input dimensions
Graph level input shape list
The same input names need to be used for deployment, so choose easy to
remember names (such as: input0, input1)
custom_convert_map: Dictionary of str to Relay op
A custom op conversion map in the same format as _convert_map above
......@@ -1487,30 +1494,28 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
op_names = get_all_op_names(graph)
_report_missing_conversion(op_names)
_check_input_names(script_module, input_shapes)
_check_inputs(graph, input_shapes)
params = script_module.state_dict()
input_vars = _get_relay_input_vars(input_shapes)
outputs = _get_relay_input_vars(graph, input_shapes)
param_vars, tensors, packed_param_map = convert_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
input_vars.update(param_vars)
outputs = list(input_vars.values())
output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
outputs.update(param_vars)
ret_name = _get_input_names(graph.return_node())
# For quantized models
if "aten::quantize_per_tensor" in op_names:
weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(outputs, output_index_map,
qnn_torch.add_quant_params_to_outputs(outputs,
packed_param_map,
weight_quant_params)
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
_convert_map.update(qnn_torch.convert_map)
ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)
ret = convert_operators(_get_operator_nodes(graph.nodes()),
outputs, ret_name)
if isinstance(ret[0], list):
ret[0] = _expr.Tuple(ret[0])
......
......@@ -101,20 +101,19 @@ def get_weight_quant_params(script_module):
return quant_params
def add_quant_params_to_outputs(outputs, output_index_map,
packed_param_map, quant_params):
def add_quant_params_to_outputs(outputs, packed_param_map,
quant_params):
"""
Add quant params to outputs so that they can be referenced by other
ops later. Weights are quantized here.
"""
for node_name, packed_param_name in packed_param_map.items():
qparam = quant_params[packed_param_name]
output_index_map[node_name] = len(outputs)
qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale,
qparam.zero_point, out_dtype="int8",
axis=0)
param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
outputs.append(param_tup)
outputs[node_name] = param_tup
def _get_quant_param_for_input(input_value):
......
......@@ -28,7 +28,6 @@ from torch.quantization import fuse_modules, QuantWrapper
import tvm
from tvm import relay
from tvm.relay.frontend.pytorch import get_graph_input_names
from tvm.contrib.download import download_testdata
......@@ -39,7 +38,7 @@ def torch_version_check():
def get_tvm_runtime(script_module, input_name, ishape):
input_shapes = {input_name: ishape}
input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
with relay.build_config(opt_level=3):
......@@ -287,7 +286,7 @@ def test_quantized_modules():
with torch.no_grad():
pt_result = script_module(inp.clone()).numpy()
input_name = get_graph_input_names(script_module)[0]
input_name = "input"
runtime = get_tvm_runtime(script_module, input_name, ishape)
runtime.set_input(input_name, inp.numpy().copy())
runtime.run()
......@@ -383,7 +382,7 @@ def test_quantized_imagenet():
with torch.no_grad():
pt_result = script_module(pt_inp).numpy()
input_name = get_graph_input_names(script_module)[0]
input_name = "image"
runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224))
runtime.set_input(input_name, inp)
runtime.run()
......
......@@ -28,7 +28,6 @@ import torchvision
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list
from tvm.relay.frontend.pytorch import get_graph_input_names
sys.setrecursionlimit(10000)
......@@ -169,8 +168,8 @@ def verify_model(model_name, input_data=[],
else:
trace = trace.cpu()
input_names = get_graph_input_names(trace)
input_shapes = dict(zip(input_names,
input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
input_shapes = list(zip(input_names,
[inp.shape for inp in baseline_input]))
mod, params = relay.frontend.from_pytorch(trace, input_shapes,
custom_convert_map)
......@@ -888,11 +887,12 @@ def test_3d_models():
def verify_script_model(pt_model, ishapes):
script_module = torch.jit.script(pt_model)
input_names = get_graph_input_names(script_module)
input_shapes = dict(zip(input_names, ishapes))
inputs = [torch.randn(input_shapes[input_name], dtype=torch.float)
for input_name in input_names]
input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
input_shapes = list(zip(input_names, ishapes))
inputs = [torch.randn(shape, dtype=torch.float)
for shape in ishapes]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
......
......@@ -103,11 +103,14 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
""" Skip this test as it is intermittent
see https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094
for device in ['llvm', 'cuda', 'opencl']:
# Disable opencl test for now
if device != "llvm" and device != "cuda":
continue
check_device(device)
"""
def test_get_valid_counts():
......
......@@ -47,7 +47,6 @@ from tvm import relay
import numpy as np
from tvm.contrib.download import download_testdata
from tvm.relay.frontend.pytorch import get_graph_input_names
# PyTorch imports
import torch
......@@ -90,10 +89,10 @@ img = np.expand_dims(img, 0)
# Import the graph to Relay
# -------------------------
# Convert PyTorch graph to Relay graph.
input_name = get_graph_input_names(scripted_model)[0] # only one input
shape_dict = {input_name: img.shape}
input_name = 'input0' # only one input, set it to this name
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model,
shape_dict)
shape_list)
######################################################################
# Relay Build
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment