Commit b3b3d28a by Hiroyuki Makino, committed by Yizhi Liu

[Relay][Frontend] Caffe2 Support (#2507)

* [Relay][Frontend] Add Caffe2 Support

* [Relay][Frontend] Add Caffe2 Support (fix unused import)

* [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import)

* [Relay][Frontend] Add Caffe2 Support (fix model install and reflect code reviews)

* [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import)

* [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import)

* [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import)

* [Relay][Frontend] Add Caffe2 Support (fix caffe2 frontend import)

* [Relay][Frontend] Add Caffe2 Support (rename function name in test_forward)

* [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import)

* [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import)

* [Doc] Caffe2 frontend tutorial

* [Doc] Caffe2 frontend tutorial

* [Doc] Caffe2 frontend tutorial

* [Relay][Frontend] Add Caffe2 Support (remove unused file)
parent e012f819
@@ -67,6 +67,9 @@ RUN bash /install/ubuntu_install_onnx.sh
COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh
RUN bash /install/ubuntu_install_tflite.sh
COPY install/ubuntu_install_caffe2.sh /install/ubuntu_install_caffe2.sh
RUN bash /install/ubuntu_install_caffe2.sh
RUN pip3 install Pillow
COPY install/ubuntu_install_vulkan.sh /install/ubuntu_install_vulkan.sh
......
python3 -m caffe2.python.models.download -i -f squeezenet
python3 -m caffe2.python.models.download -i -f resnet50
python3 -m caffe2.python.models.download -i -f vgg19
@@ -12,3 +12,4 @@ from .keras import from_keras
from .onnx import from_onnx
from .tflite import from_tflite
from .coreml import from_coreml
from .caffe2 import from_caffe2
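A minimal usage sketch of the new `from_caffe2` entry point, assuming the Caffe2 `init_net` and `predict_net` protobufs are already loaded and the input blob is named 'data' (both names are illustrative; the full flow appears in the tests and tutorial below):
from tvm import relay
shape_dict = {'data': (1, 3, 224, 224)}
dtype_dict = {'data': 'float32'}
func, params = relay.frontend.from_caffe2(init_net, predict_net, shape_dict, dtype_dict)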
"""Store for caffe2 examples and common models."""
from __future__ import absolute_import as _abs
import os
import sys
import importlib
from . import squeezenet
from caffe2.python.models.download import ModelDownloader
models = [
'squeezenet',
'resnet50',
'vgg19',
]
mf = ModelDownloader()
class Model:
def __init__(self, model_name):
self.init_net, self.predict_net, self.value_info = mf.get_c2_model(model_name)
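# Prefer the models bundled with caffe2 (caffe2.python.models.<name>); if a model
# is not packaged, fall back to fetching it with ModelDownloader. At module level
# locals() is the module namespace, so each entry becomes an importable
# 'c2_<model>' attribute.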
for model in models:
try:
locals()['c2_' + model] = importlib.import_module('caffe2.python.models.' + model)
except ImportError:
locals()['c2_' + model] = Model(model)
# squeezenet
def relay_squeezenet():
return squeezenet.get_workload()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=unused-argument
"""
Symbol of SqueezeNet
Reference:
Iandola, Forrest N., et al.
"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016).
"""
from tvm import relay
from tvm.relay.testing import create_workload
# Helpers
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels, prefix=""):
net = _make_fire_conv(net, squeeze_channels, 1, 0, "%s/squeeze1x1" % prefix)
left = _make_fire_conv(net, expand1x1_channels, 1, 0, "%s/expand1x1" % prefix)
right = _make_fire_conv(net, expand3x3_channels, 3, 1, "%s/expand3x3" % prefix)
# NOTE : Assume NCHW layout here
net = relay.concatenate((left, right), axis=1)
return net
def _make_fire_conv(net, channels, kernel_size, padding=0, prefix=""):
net = relay.nn.conv2d(net, relay.var("%s_weight" % prefix),
channels=channels,
kernel_size=(kernel_size, kernel_size),
padding=(padding, padding))
net = relay.nn.bias_add(net, relay.var("%s_bias" % prefix))
net = relay.nn.relu(net)
return net
# Net
def get_net(batch_size, image_shape, num_classes, dtype):
"""Get symbol of SqueezeNet
Parameters
----------
batch_size : int
The batch size used in the model
image_shape : tuple
The input image shape
num_classes: int
The number of classification results
dtype : str
The data type
"""
data_shape = (batch_size,) + image_shape
net = relay.var("data", shape=data_shape, dtype=dtype)
net = relay.nn.conv2d(net, relay.var("conv1_weight"),
channels=64,
kernel_size=(3, 3),
strides=(2, 2),
padding=(0, 0))
net = relay.nn.bias_add(net, relay.var("conv1_bias"))
net = relay.nn.relu(net)
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 16, 64, 64, 'fire2')
net = _make_fire(net, 16, 64, 64, "fire3")
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 32, 128, 128, "fire4")
net = _make_fire(net, 32, 128, 128, "fire5")
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 48, 192, 192, "fire6")
net = _make_fire(net, 48, 192, 192, "fire7")
net = _make_fire(net, 64, 256, 256, "fire8")
net = _make_fire(net, 64, 256, 256, "fire9")
net = relay.nn.dropout(net, rate=0.5)
net = relay.nn.conv2d(net, relay.var('conv10_weight'), channels=num_classes, kernel_size=(1, 1))
net = relay.nn.bias_add(net, relay.var("conv10_bias"))
net = relay.nn.relu(net)
net = relay.nn.global_avg_pool2d(net)
net = relay.nn.softmax(net, axis=1)
args = relay.ir_pass.free_vars(net)
return relay.Function(args, net)
def get_workload(batch_size=1,
image_shape=(3, 224, 224),
num_classes=1000,
dtype="float32"):
"""Get benchmark workload for SqueezeNet
Parameters
----------
batch_size : int, optional
The batch size used in the model
image_shape : tuple, optional
The input image shape
num_classes : int, optional
Number of classes
dtype : str, optional
The data type
Returns
-------
net : relay.Function
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_net(batch_size, image_shape, num_classes, dtype)
return create_workload(net)
import numpy as np
import tvm
from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list
from tvm import relay
from model_zoo import c2_squeezenet, c2_resnet50, c2_vgg19
from caffe2.python import workspace
def get_tvm_output(model,
input_data,
target,
ctx,
output_shape,
output_dtype='float32'):
""" Generic function to execute and get tvm output"""
# supporting multiple inputs in caffe2 is a bit tricky,
# because the input names can appear at the beginning or end of model.predict_net.external_input
assert isinstance(input_data, np.ndarray)
# here we use the first input blob to the first op to get the input name
input_names = model.predict_net.op[0].input[0]
shape_dict = {input_names: input_data.shape}
dtype_dict = {input_names: input_data.dtype}
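# note: a possible sketch for multiple inputs would be to keep the blobs in
# predict_net.external_input that the init_net does not produce, e.g.
#   init_outputs = {name for op in model.init_net.op for name in op.output}
#   extra_inputs = [b for b in model.predict_net.external_input if b not in init_outputs]
# the models tested here have a single input, so the first op's first input suffices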
func, params = relay.frontend.from_caffe2(model.init_net, model.predict_net, shape_dict, dtype_dict)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
if isinstance(output_shape, list) and isinstance(output_dtype, list):
tvm_output_list = []
for i, s in enumerate(output_shape):
tvm_output = m.get_output(i, tvm.nd.empty((s), output_dtype[i]))
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
else:
tvm_output = m.get_output(0, tvm.nd.empty((output_shape),
output_dtype))
return tvm_output.asnumpy()
def get_caffe2_output(model, x, dtype='float32'):
workspace.RunNetOnce(model.init_net)
input_blob = model.predict_net.op[0].input[0]
workspace.FeedBlob(input_blob, x.astype(dtype))
workspace.RunNetOnce(model.predict_net)
output_blob = model.predict_net.external_output[0]
c2_output = workspace.FetchBlob(output_blob)
return c2_output
def verify_caffe2_forward_impl(model, data_shape, out_shape):
dtype = 'float32'
data = np.random.uniform(size=data_shape).astype(dtype)
c2_out = get_caffe2_output(model, data, dtype)
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, data, target, ctx, out_shape, dtype)
tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_squeezenet1_1():
verify_caffe2_forward_impl(c2_squeezenet, (1, 3, 224, 224), (1, 1000, 1, 1))
def test_forward_resnet50():
verify_caffe2_forward_impl(c2_resnet50, (1, 3, 224, 224), (1, 1000))
def test_forward_vgg19():
verify_caffe2_forward_impl(c2_vgg19, (1, 3, 224, 224), (1, 1000))
if __name__ == '__main__':
test_forward_squeezenet1_1()
test_forward_resnet50()
test_forward_vgg19()
"""Test graph equality of caffe2 models."""
from tvm import relay
from model_zoo import c2_squeezenet, relay_squeezenet
def compare_graph(f1, f2):
f1 = relay.ir_pass.infer_type(f1)
f2 = relay.ir_pass.infer_type(f2)
assert relay.ir_pass.alpha_equal(f1, f2)
def test_squeeze_net():
shape_dict = {'data': (1, 3, 224, 224)}
dtype_dict = {'data': 'float32'}
from_c2_func, _ = relay.frontend.from_caffe2(c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict)
relay_func, _ = relay_squeezenet()
compare_graph(from_c2_func, relay_func)
if __name__ == '__main__':
test_squeeze_net()
@@ -47,3 +47,7 @@ python3 -m nose -v tests/python/frontend/nnvm_to_relay || exit -1
echo "Running relay TFLite frontend test..."
python3 -m nose -v tests/python/frontend/tflite || exit -1
echo "Running relay caffe2 frondend test..."
python3 -m nose -v tests/python/frontend/caffe2 || exit -1
"""
Compile Caffe2 Models
=====================
**Author**: `Hiroyuki Makino <https://makihiro.github.io/>`_
This article is an introductory tutorial on deploying Caffe2 models with Relay.
To begin, Caffe2 should be installed.
A quick solution is to install it via conda:
.. code-block:: bash
# for cpu
conda install pytorch-nightly-cpu -c pytorch
# for gpu with CUDA 8
conda install pytorch-nightly cuda80 -c pytorch
or please refer to the official site:
https://caffe2.ai/docs/getting-started.html
"""
######################################################################
# Utils for downloading files
# ----------------------------
def download(url, path, overwrite=False):
import os
if os.path.isfile(path) and not overwrite:
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
import urllib.request
urllib.request.urlretrieve(url, path)
except ImportError:
import urllib
urllib.urlretrieve(url, path)
######################################################################
# Load pretrained Caffe2 model
# ----------------------------
# We load a pretrained resnet50 classification model provided by Caffe2.
from caffe2.python.models.download import ModelDownloader
mf = ModelDownloader()
class Model:
def __init__(self, model_name):
self.init_net, self.predict_net, self.value_info = mf.get_c2_model(model_name)
resnet50 = Model('resnet50')
######################################################################
# Load a test image
# ------------------
# A single cat dominates the examples!
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
download(img_url, 'cat.png')
img = Image.open('cat.png').resize((224, 224))
plt.imshow(img)
plt.show()
# input preprocess
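# subtract the per-channel mean, divide by the per-channel standard deviation,
# convert HWC to CHW, and add a batch dimension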
def transform_image(image):
image = np.array(image) - np.array([123., 117., 104.])
image /= np.array([58.395, 57.12, 57.375])
image = image.transpose((2, 0, 1))
image = image[np.newaxis, :].astype('float32')
return image
data = transform_image(img)
######################################################################
# Compile the model on Relay
# --------------------------
# Caffe2 input tensor name, shape and type
input_name = resnet50.predict_net.op[0].input[0]
shape_dict = {input_name: data.shape}
dtype_dict = {input_name: data.dtype}
# parse Caffe2 model and convert into Relay computation graph
from tvm import relay
func, params = relay.frontend.from_caffe2(resnet50.init_net, resnet50.predict_net, shape_dict, dtype_dict)
# compile the model
# target x86 cpu
target = 'llvm'
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)
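######################################################################
# (Optional) Save the compiled module
# -----------------------------------
# A minimal sketch, assuming you want to reuse the compiled artifacts later;
# the file names below are illustrative.
#
# .. code-block:: python
#
#     with open('resnet50_graph.json', 'w') as fo:
#         fo.write(graph)
#     lib.export_library('resnet50_deploy.so')
#     with open('resnet50_params.params', 'wb') as fo:
#         fo.write(relay.save_param_dict(params))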
######################################################################
# Execute on TVM
# ---------------
# The process is no different from other examples.
import tvm
from tvm.contrib import graph_runtime
# context x86 cpu, use tvm.gpu(0) if you run on GPU
ctx = tvm.cpu(0)
# create a runtime executor module
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input(input_name, tvm.nd.array(data.astype('float32')))
# set related params
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_out = m.get_output(0)
top1_tvm = np.argmax(tvm_out.asnumpy()[0])
#####################################################################
# Look up synset name
# -------------------
# Look up prediction top 1 index in 1000 class synset.
from caffe2.python import workspace
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
'4d0b62f3d01426887599d4f7ede23ee5/raw/',
'596b27d23537e5a1b5751d2b0481ef172f58b539/',
'imagenet1000_clsid_to_human.txt'])
synset_name = 'synset.txt'
download(synset_url, synset_name)
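# the downloaded file holds a Python dict literal mapping ImageNet class ids to
# human-readable names, so eval() turns it into a lookup table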
with open(synset_name) as f:
synset = eval(f.read())
print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, synset[top1_tvm]))
# confirm correctness with caffe2 output
p = workspace.Predictor(resnet50.init_net, resnet50.predict_net)
caffe2_out = p.run({input_name: data})
top1_caffe2 = np.argmax(caffe2_out)
print('Caffe2 top-1 id: {}, class name: {}'.format(top1_caffe2, synset[top1_caffe2]))