Unverified commit 03cbf78e by Jeremy Johnson, committed by GitHub

[Frontend][Torch] Fix up graph input handling (#5204)

* [Frontend][Torch] Simplify operator input handling

* [Frontend][Torch] Allow user supplied input names to override graph inputs

* Fix pylint issues

* Updates from code review feedback

* Fix tutorial to use shape list input

* Disable intermittent test failure in topi vision test
parent 15b1751c
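The hunks below all follow from one API change: relay.frontend.from_pytorch now takes a list of (input_name, shape) tuples instead of a dict keyed by graph-derived names, so callers choose the input names themselves and get_graph_input_names is no longer needed. A minimal usage sketch, assuming a traced torchvision model (the model choice and the name "input0" are illustrative, not part of the patch):

    import torch
    import torchvision
    from tvm import relay

    # Trace any eager PyTorch model to TorchScript first.
    model = torchvision.models.resnet18(pretrained=True).eval()
    dummy = torch.randn(1, 3, 224, 224)
    scripted_model = torch.jit.trace(model, dummy).eval()

    # New-style input spec: a list of (name, shape) tuples.
    # The names are chosen by the caller and override the graph's own input names.
    shape_list = [("input0", (1, 3, 224, 224))]
    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)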
@@ -101,20 +101,19 @@ def get_weight_quant_params(script_module):
     return quant_params


-def add_quant_params_to_outputs(outputs, output_index_map,
-                                packed_param_map, quant_params):
+def add_quant_params_to_outputs(outputs, packed_param_map,
+                                quant_params):
     """
     Add quant params to outputs so that they can be referenced by other
     ops later. Weights are quantized here.
     """
     for node_name, packed_param_name in packed_param_map.items():
         qparam = quant_params[packed_param_name]
-        output_index_map[node_name] = len(outputs)
         qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale,
                                         qparam.zero_point, out_dtype="int8",
                                         axis=0)
         param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
-        outputs.append(param_tup)
+        outputs[node_name] = param_tup


 def _get_quant_param_for_input(input_value):
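The hunk above drops the separate output_index_map: converted outputs are now stored in a dict keyed directly by node name instead of a list plus an index map. A generic before/after sketch of that bookkeeping (the node name and tuple contents are placeholders):

    # Before: two structures had to stay in sync.
    outputs = []
    output_index_map = {}
    output_index_map["some_node"] = len(outputs)
    outputs.append(("qweight", "scale", "zero_point", "bias"))
    value = outputs[output_index_map["some_node"]]

    # After: one dict keyed directly by node name.
    outputs = {}
    outputs["some_node"] = ("qweight", "scale", "zero_point", "bias")
    value = outputs["some_node"]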
@@ -28,7 +28,6 @@ from torch.quantization import fuse_modules, QuantWrapper
 import tvm
 from tvm import relay
-from tvm.relay.frontend.pytorch import get_graph_input_names
 from tvm.contrib.download import download_testdata

@@ -39,7 +38,7 @@ def torch_version_check():

 def get_tvm_runtime(script_module, input_name, ishape):
-    input_shapes = {input_name: ishape}
+    input_shapes = [(input_name, ishape)]
     mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

     with relay.build_config(opt_level=3):
@@ -287,7 +286,7 @@ def test_quantized_modules():
         with torch.no_grad():
             pt_result = script_module(inp.clone()).numpy()

-        input_name = get_graph_input_names(script_module)[0]
+        input_name = "input"
         runtime = get_tvm_runtime(script_module, input_name, ishape)
         runtime.set_input(input_name, inp.numpy().copy())
         runtime.run()
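With the shape-list signature, the quantized tests pick an input name up front and reuse it for set_input; nothing is read back from the TorchScript graph. A condensed sketch of the flow, assuming the get_tvm_runtime helper and a script_module as in this test file (shape and data are placeholders):

    import numpy as np

    input_name = "input"                       # user-chosen, no graph lookup needed
    ishape = (1, 3, 224, 224)                  # placeholder shape
    runtime = get_tvm_runtime(script_module, input_name, ishape)
    runtime.set_input(input_name, np.random.rand(*ishape).astype("float32"))
    runtime.run()
    tvm_result = runtime.get_output(0).asnumpy()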
@@ -383,7 +382,7 @@ def test_quantized_imagenet():
     with torch.no_grad():
         pt_result = script_module(pt_inp).numpy()

-    input_name = get_graph_input_names(script_module)[0]
+    input_name = "image"
     runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224))
     runtime.set_input(input_name, inp)
     runtime.run()
@@ -28,7 +28,6 @@ import torchvision
 from tvm import relay
 from tvm.contrib import graph_runtime
 from tvm.relay.testing.config import ctx_list
-from tvm.relay.frontend.pytorch import get_graph_input_names

 sys.setrecursionlimit(10000)
@@ -169,8 +168,8 @@ def verify_model(model_name, input_data=[],
     else:
         trace = trace.cpu()

-    input_names = get_graph_input_names(trace)
-    input_shapes = dict(zip(input_names,
+    input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
+    input_shapes = list(zip(input_names,
                             [inp.shape for inp in baseline_input]))
     mod, params = relay.frontend.from_pytorch(trace, input_shapes,
                                               custom_convert_map)
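For multi-input models, verify_model now derives one name per baseline input and zips names with shapes. A small standalone sketch of that pattern (the tensor shapes are arbitrary):

    import torch

    baseline_input = [torch.randn(1, 3, 224, 224), torch.randn(1, 10)]
    input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
    input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
    # -> [('input0', torch.Size([1, 3, 224, 224])), ('input1', torch.Size([1, 10]))]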
@@ -888,11 +887,12 @@ def test_3d_models():

 def verify_script_model(pt_model, ishapes):
     script_module = torch.jit.script(pt_model)
-    input_names = get_graph_input_names(script_module)
-    input_shapes = dict(zip(input_names, ishapes))
-    inputs = [torch.randn(input_shapes[input_name], dtype=torch.float)
-              for input_name in input_names]
+    input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
+    input_shapes = list(zip(input_names, ishapes))
+
+    inputs = [torch.randn(shape, dtype=torch.float)
+              for shape in ishapes]

     mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
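verify_script_model follows the same pattern for scripted models: names and shapes come straight from the caller-supplied ishapes, and the random inputs no longer need a name lookup. A sketch under example shapes (pt_model stands in for the test helper's argument; the chosen names "i0", "i1", ... are what set_input would use later):

    import torch
    from tvm import relay

    ishapes = [(1, 16), (1, 16)]                       # example shapes
    input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
    input_shapes = list(zip(input_names, ishapes))
    inputs = [torch.randn(shape, dtype=torch.float) for shape in ishapes]

    script_module = torch.jit.script(pt_model)         # pt_model: the model under test
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)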
@@ -103,11 +103,14 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
         tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
         tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)

+    """ Skip this test as it is intermittent
+    see https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094
     for device in ['llvm', 'cuda', 'opencl']:
         # Disable opencl test for now
         if device != "llvm" and device != "cuda":
             continue
         check_device(device)
+    """


 def test_get_valid_counts():
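The hunk above disables the flaky device loop by wrapping it in a string literal, so the code stays in place and greppable. Purely as a hedged aside (not what this commit does), a common alternative for intermittent tests is pytest's skip marker:

    import pytest

    # Alternative approach: skip the whole test with an explanatory reason.
    @pytest.mark.skip(reason="intermittent failure, see apache/incubator-tvm#4901")
    def test_get_valid_counts():
        ...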
@@ -47,7 +47,6 @@ from tvm import relay
 import numpy as np

 from tvm.contrib.download import download_testdata
-from tvm.relay.frontend.pytorch import get_graph_input_names

 # PyTorch imports
 import torch

@@ -90,10 +89,10 @@ img = np.expand_dims(img, 0)
 # Import the graph to Relay
 # -------------------------
 # Convert PyTorch graph to Relay graph.
-input_name = get_graph_input_names(scripted_model)[0]  # only one input
-shape_dict = {input_name: img.shape}
+input_name = 'input0'  # only one input, set it to this name
+shape_list = [(input_name, img.shape)]
 mod, params = relay.frontend.from_pytorch(scripted_model,
-                                          shape_dict)
+                                          shape_list)

 ######################################################################
 # Relay Build
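Downstream of this hunk the tutorial builds and runs the module; the only user-visible difference is that set_input uses the name chosen above. A hedged sketch of that step, reusing mod, params, input_name and img from the tutorial and assuming the era's relay.build / graph_runtime API with an llvm target:

    import tvm
    from tvm.contrib import graph_runtime

    target = "llvm"
    ctx = tvm.cpu(0)
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(mod, target=target, params=params)

    # Feed the image under the user-chosen input name and run.
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(input_name, tvm.nd.array(img.astype("float32")))
    m.set_input(**params)
    m.run()
    tvm_output = m.get_output(0)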