Unverified Commit 03cbf78e by Jeremy Johnson Committed by GitHub

[Frontend][Torch] Fix up graph input handling (#5204)

* [Frontend][Torch] Simplify operator input handling

* [Frontend][Torch] Allow user supplied input names to override graph inputs

* Fix pylint issues

* Updates from code review feedback

* Fix tutorial to use shape list input

* Disable intermittent test failure in topi vision test
parent 15b1751c
...@@ -101,20 +101,19 @@ def get_weight_quant_params(script_module): ...@@ -101,20 +101,19 @@ def get_weight_quant_params(script_module):
return quant_params return quant_params
def add_quant_params_to_outputs(outputs, output_index_map, def add_quant_params_to_outputs(outputs, packed_param_map,
packed_param_map, quant_params): quant_params):
""" """
Add quant params to outputs so that they can be referenced by other Add quant params to outputs so that they can be referenced by other
ops later. Weights are quantized here. ops later. Weights are quantized here.
""" """
for node_name, packed_param_name in packed_param_map.items(): for node_name, packed_param_name in packed_param_map.items():
qparam = quant_params[packed_param_name] qparam = quant_params[packed_param_name]
output_index_map[node_name] = len(outputs)
qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale, qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale,
qparam.zero_point, out_dtype="int8", qparam.zero_point, out_dtype="int8",
axis=0) axis=0)
param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var) param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
outputs.append(param_tup) outputs[node_name] = param_tup
def _get_quant_param_for_input(input_value): def _get_quant_param_for_input(input_value):
......
...@@ -28,7 +28,6 @@ from torch.quantization import fuse_modules, QuantWrapper ...@@ -28,7 +28,6 @@ from torch.quantization import fuse_modules, QuantWrapper
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.frontend.pytorch import get_graph_input_names
from tvm.contrib.download import download_testdata from tvm.contrib.download import download_testdata
...@@ -39,7 +38,7 @@ def torch_version_check(): ...@@ -39,7 +38,7 @@ def torch_version_check():
def get_tvm_runtime(script_module, input_name, ishape): def get_tvm_runtime(script_module, input_name, ishape):
input_shapes = {input_name: ishape} input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes) mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
...@@ -287,7 +286,7 @@ def test_quantized_modules(): ...@@ -287,7 +286,7 @@ def test_quantized_modules():
with torch.no_grad(): with torch.no_grad():
pt_result = script_module(inp.clone()).numpy() pt_result = script_module(inp.clone()).numpy()
input_name = get_graph_input_names(script_module)[0] input_name = "input"
runtime = get_tvm_runtime(script_module, input_name, ishape) runtime = get_tvm_runtime(script_module, input_name, ishape)
runtime.set_input(input_name, inp.numpy().copy()) runtime.set_input(input_name, inp.numpy().copy())
runtime.run() runtime.run()
...@@ -383,7 +382,7 @@ def test_quantized_imagenet(): ...@@ -383,7 +382,7 @@ def test_quantized_imagenet():
with torch.no_grad(): with torch.no_grad():
pt_result = script_module(pt_inp).numpy() pt_result = script_module(pt_inp).numpy()
input_name = get_graph_input_names(script_module)[0] input_name = "image"
runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224)) runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224))
runtime.set_input(input_name, inp) runtime.set_input(input_name, inp)
runtime.run() runtime.run()
......
...@@ -28,7 +28,6 @@ import torchvision ...@@ -28,7 +28,6 @@ import torchvision
from tvm import relay from tvm import relay
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list from tvm.relay.testing.config import ctx_list
from tvm.relay.frontend.pytorch import get_graph_input_names
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
...@@ -169,8 +168,8 @@ def verify_model(model_name, input_data=[], ...@@ -169,8 +168,8 @@ def verify_model(model_name, input_data=[],
else: else:
trace = trace.cpu() trace = trace.cpu()
input_names = get_graph_input_names(trace) input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
input_shapes = dict(zip(input_names, input_shapes = list(zip(input_names,
[inp.shape for inp in baseline_input])) [inp.shape for inp in baseline_input]))
mod, params = relay.frontend.from_pytorch(trace, input_shapes, mod, params = relay.frontend.from_pytorch(trace, input_shapes,
custom_convert_map) custom_convert_map)
...@@ -888,11 +887,12 @@ def test_3d_models(): ...@@ -888,11 +887,12 @@ def test_3d_models():
def verify_script_model(pt_model, ishapes): def verify_script_model(pt_model, ishapes):
script_module = torch.jit.script(pt_model) script_module = torch.jit.script(pt_model)
input_names = get_graph_input_names(script_module)
input_shapes = dict(zip(input_names, ishapes))
inputs = [torch.randn(input_shapes[input_name], dtype=torch.float) input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
for input_name in input_names] input_shapes = list(zip(input_names, ishapes))
inputs = [torch.randn(shape, dtype=torch.float)
for shape in ishapes]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes) mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
......
...@@ -103,11 +103,14 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): ...@@ -103,11 +103,14 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
""" Skip this test as it is intermittent
see https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094
for device in ['llvm', 'cuda', 'opencl']: for device in ['llvm', 'cuda', 'opencl']:
# Disable opencl test for now # Disable opencl test for now
if device != "llvm" and device != "cuda": if device != "llvm" and device != "cuda":
continue continue
check_device(device) check_device(device)
"""
def test_get_valid_counts(): def test_get_valid_counts():
......
...@@ -47,7 +47,6 @@ from tvm import relay ...@@ -47,7 +47,6 @@ from tvm import relay
import numpy as np import numpy as np
from tvm.contrib.download import download_testdata from tvm.contrib.download import download_testdata
from tvm.relay.frontend.pytorch import get_graph_input_names
# PyTorch imports # PyTorch imports
import torch import torch
...@@ -90,10 +89,10 @@ img = np.expand_dims(img, 0) ...@@ -90,10 +89,10 @@ img = np.expand_dims(img, 0)
# Import the graph to Relay # Import the graph to Relay
# ------------------------- # -------------------------
# Convert PyTorch graph to Relay graph. # Convert PyTorch graph to Relay graph.
input_name = get_graph_input_names(scripted_model)[0] # only one input input_name = 'input0' # only one input, set it to this name
shape_dict = {input_name: img.shape} shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, mod, params = relay.frontend.from_pytorch(scripted_model,
shape_dict) shape_list)
###################################################################### ######################################################################
# Relay Build # Relay Build
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment