Unverified Commit 7ccb4363 by masahi Committed by GitHub

[Relay, Torch] Clean up and refactor PyTorch frontend (#4944)

* The initial import of refactored implementation, all tests passed

* enable mobilenet v2 test

* minor cleanup

* reorg

* fix lint

* use input names that come with torch IR

* fix typo

* introduce parse_operators

* fix lint

* add _ prefix
parent a6fae5ed
...@@ -31,6 +31,8 @@ import torchvision ...@@ -31,6 +31,8 @@ import torchvision
from tvm import relay from tvm import relay
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list from tvm.relay.testing.config import ctx_list
from tvm.relay.frontend.pytorch import get_graph_input_names
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
...@@ -94,6 +96,7 @@ def load_model(model_name): ...@@ -94,6 +96,7 @@ def load_model(model_name):
if hasattr(torchvision.models, model_name): if hasattr(torchvision.models, model_name):
return load_torchvision(model_name) return load_torchvision(model_name)
try: try:
import pretrainedmodels
if hasattr(pretrainedmodels, model_name): if hasattr(pretrainedmodels, model_name):
return load_pretrainedmodels(model_name) return load_pretrainedmodels(model_name)
except ModuleNotFoundError: except ModuleNotFoundError:
...@@ -167,16 +170,15 @@ def verify_model(model_name, input_data=[]): ...@@ -167,16 +170,15 @@ def verify_model(model_name, input_data=[]):
baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs) baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
else: else:
baseline_outputs = (baseline_outputs.float().cpu().numpy(),) baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
output_shapes = [out.shape for out in baseline_outputs]
dtype = "float32"
input_name = "input0"
input_shapes = {input_name: list(baseline_input.shape)}
trace = torch.jit.trace(baseline_model, baseline_input).float().eval() trace = torch.jit.trace(baseline_model, baseline_input).float().eval()
if torch.cuda.is_available(): if torch.cuda.is_available():
trace = trace.cuda() trace = trace.cuda()
else: else:
trace = trace.cpu() trace = trace.cpu()
input_name = get_graph_input_names(trace)[0] # only one input
input_shapes = {input_name: list(baseline_input.shape)}
mod, params = relay.frontend.from_pytorch(trace, input_shapes) mod, params = relay.frontend.from_pytorch(trace, input_shapes)
compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())} compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())}
...@@ -276,7 +278,7 @@ def test_forward_multiply(): ...@@ -276,7 +278,7 @@ def test_forward_multiply():
class Multiply2(Module): class Multiply2(Module):
def forward(self, *args): def forward(self, *args):
return args[0] * 1 return args[0] * 1.0
class Multiply3(Module): class Multiply3(Module):
def forward(self, *args): def forward(self, *args):
...@@ -507,7 +509,7 @@ def test_forward_size(): ...@@ -507,7 +509,7 @@ def test_forward_size():
class Size1(Module): class Size1(Module):
def forward(self, *args): def forward(self, *args):
return args[0].size(0) * args[0] return float(args[0].size(0)) * args[0]
with torch.no_grad(): with torch.no_grad():
input_data = torch.rand(input_shape).float() input_data = torch.rand(input_shape).float()
...@@ -708,6 +710,10 @@ def test_mnasnet0_5(): ...@@ -708,6 +710,10 @@ def test_mnasnet0_5():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
verify_model("mnasnet0_5") verify_model("mnasnet0_5")
def test_mobilenet_v2():
torch.set_grad_enabled(False)
verify_model("mobilenet_v2")
""" """
#TODO: Fix VGG and AlexNet issues (probably due to pooling) #TODO: Fix VGG and AlexNet issues (probably due to pooling)
def test_alexnet(): def test_alexnet():
...@@ -721,13 +727,9 @@ def test_vgg11(): ...@@ -721,13 +727,9 @@ def test_vgg11():
def test_vgg11_bn(): def test_vgg11_bn():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
verify_model("vgg11_bn") verify_model("vgg11_bn")
#TODO: Need to update schedule in tophub file after PR #4787 updated workloads
def test_mobilenet_v2():
torch.set_grad_enabled(False)
verify_model("mobilenet_v2")
""" """
if __name__ == "__main__": if __name__ == "__main__":
# Single operator tests # Single operator tests
test_forward_add() test_forward_add()
...@@ -767,3 +769,4 @@ if __name__ == "__main__": ...@@ -767,3 +769,4 @@ if __name__ == "__main__":
test_inception_v3() test_inception_v3()
test_googlenet() test_googlenet()
test_mnasnet0_5() test_mnasnet0_5()
test_mobilenet_v2()
...@@ -41,14 +41,13 @@ Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may ...@@ -41,14 +41,13 @@ Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may
be unstable. be unstable.
""" """
# tvm, relay
import tvm import tvm
from tvm import relay from tvm import relay
# numpy, packaging
import numpy as np import numpy as np
from packaging import version
from tvm.contrib.download import download_testdata from tvm.contrib.download import download_testdata
from tvm.relay.frontend.pytorch import get_graph_input_names
# PyTorch imports # PyTorch imports
import torch import torch
...@@ -91,7 +90,8 @@ img = np.expand_dims(img, 0) ...@@ -91,7 +90,8 @@ img = np.expand_dims(img, 0)
# Import the graph to Relay # Import the graph to Relay
# ------------------------- # -------------------------
# Convert PyTorch graph to Relay graph. # Convert PyTorch graph to Relay graph.
shape_dict = {'img': img.shape} input_name = get_graph_input_names(scripted_model)[0] # only one input
shape_dict = {input_name: img.shape}
mod, params = relay.frontend.from_pytorch(scripted_model, mod, params = relay.frontend.from_pytorch(scripted_model,
shape_dict) shape_dict)
...@@ -116,12 +116,12 @@ from tvm.contrib import graph_runtime ...@@ -116,12 +116,12 @@ from tvm.contrib import graph_runtime
dtype = 'float32' dtype = 'float32'
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
# Set inputs # Set inputs
m.set_input('img', tvm.nd.array(img.astype(dtype))) m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
m.set_input(**params) m.set_input(**params)
# Execute # Execute
m.run() m.run()
# Get outputs # Get outputs
tvm_output = m.get_output(0, tvm.nd.empty(((1, 1000)), 'float32')) tvm_output = m.get_output(0)
##################################################################### #####################################################################
# Look up synset name # Look up synset name
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment