Unverified Commit 7ccb4363 by masahi Committed by GitHub

[Relay, Torch] Clean up and refactor PyTorch frontend (#4944)

* The initial import of refactored implementation, all tests passed

* enable mobilenet v2 test

* minor cleanup

* reorg

* fix lint

* use input names that come with torch IR

* fix typo

* introduce parse_operators

* fix lint

* add _ prefix
parent a6fae5ed
......@@ -31,6 +31,8 @@ import torchvision
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list
from tvm.relay.frontend.pytorch import get_graph_input_names
sys.setrecursionlimit(10000)
......@@ -94,6 +96,7 @@ def load_model(model_name):
if hasattr(torchvision.models, model_name):
return load_torchvision(model_name)
try:
import pretrainedmodels
if hasattr(pretrainedmodels, model_name):
return load_pretrainedmodels(model_name)
except ModuleNotFoundError:
......@@ -167,16 +170,15 @@ def verify_model(model_name, input_data=[]):
baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
else:
baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
output_shapes = [out.shape for out in baseline_outputs]
dtype = "float32"
input_name = "input0"
input_shapes = {input_name: list(baseline_input.shape)}
trace = torch.jit.trace(baseline_model, baseline_input).float().eval()
if torch.cuda.is_available():
trace = trace.cuda()
else:
trace = trace.cpu()
input_name = get_graph_input_names(trace)[0] # only one input
input_shapes = {input_name: list(baseline_input.shape)}
mod, params = relay.frontend.from_pytorch(trace, input_shapes)
compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())}
......@@ -276,7 +278,7 @@ def test_forward_multiply():
class Multiply2(Module):
def forward(self, *args):
return args[0] * 1
return args[0] * 1.0
class Multiply3(Module):
def forward(self, *args):
......@@ -507,7 +509,7 @@ def test_forward_size():
class Size1(Module):
def forward(self, *args):
return args[0].size(0) * args[0]
return float(args[0].size(0)) * args[0]
with torch.no_grad():
input_data = torch.rand(input_shape).float()
......@@ -708,6 +710,10 @@ def test_mnasnet0_5():
torch.set_grad_enabled(False)
verify_model("mnasnet0_5")
def test_mobilenet_v2():
torch.set_grad_enabled(False)
verify_model("mobilenet_v2")
"""
#TODO: Fix VGG and AlexNet issues (probably due to pooling)
def test_alexnet():
......@@ -721,13 +727,9 @@ def test_vgg11():
def test_vgg11_bn():
torch.set_grad_enabled(False)
verify_model("vgg11_bn")
#TODO: Need to update schedule in tophub file after PR #4787 updated workloads
def test_mobilenet_v2():
torch.set_grad_enabled(False)
verify_model("mobilenet_v2")
"""
if __name__ == "__main__":
# Single operator tests
test_forward_add()
......@@ -767,3 +769,4 @@ if __name__ == "__main__":
test_inception_v3()
test_googlenet()
test_mnasnet0_5()
test_mobilenet_v2()
......@@ -41,14 +41,13 @@ Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may
be unstable.
"""
# tvm, relay
import tvm
from tvm import relay
# numpy, packaging
import numpy as np
from packaging import version
from tvm.contrib.download import download_testdata
from tvm.relay.frontend.pytorch import get_graph_input_names
# PyTorch imports
import torch
......@@ -91,7 +90,8 @@ img = np.expand_dims(img, 0)
# Import the graph to Relay
# -------------------------
# Convert PyTorch graph to Relay graph.
shape_dict = {'img': img.shape}
input_name = get_graph_input_names(scripted_model)[0] # only one input
shape_dict = {input_name: img.shape}
mod, params = relay.frontend.from_pytorch(scripted_model,
shape_dict)
......@@ -116,12 +116,12 @@ from tvm.contrib import graph_runtime
dtype = 'float32'
m = graph_runtime.create(graph, lib, ctx)
# Set inputs
m.set_input('img', tvm.nd.array(img.astype(dtype)))
m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
m.set_input(**params)
# Execute
m.run()
# Get outputs
tvm_output = m.get_output(0, tvm.nd.empty(((1, 1000)), 'float32'))
tvm_output = m.get_output(0)
#####################################################################
# Look up synset name
......@@ -163,4 +163,4 @@ with torch.no_grad():
torch_class_key = class_id_to_key[top1_torch]
print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, key_to_classname[tvm_class_key]))
print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key]))
\ No newline at end of file
print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key]))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment