Commit ade98e14 by hlu1 Committed by Tianqi Chen

[nnvm] Add caffe2 frontend (#1981)

parent c5e1da93
......@@ -6,3 +6,4 @@ from .coreml import from_coreml
from .keras import from_keras
from .darknet import from_darknet
from .tensorflow import from_tensorflow
from .caffe2 import from_caffe2
......@@ -4,9 +4,9 @@ from __future__ import absolute_import as _abs
import numpy as np
import tvm
from .. import symbol as _sym
from .. import graph as _graph
from ..compiler import graph_util
from .common import get_nnvm_op, Renamer, SymbolTable, AttrConverter as AttrCvt
from .onnx_caffe2_utils import dimension_picker, dimension_constraint, \
infer_channels, revert_caffe2_pad
__all__ = ['from_onnx']
......@@ -74,16 +74,16 @@ class Pool(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
return AttrCvt(
op_name=_dimension_picker(cls.name),
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), _revert_caffe2_pad)
'pads': ('padding', (0, 0), revert_caffe2_pad)
},
# very weird attributes here in onnx, force check
ignores=['dilations'],
# TODO(zhreshold): make sure ceil_mode in onnx, and layout?
extras={'ceil_mode': False},
custom_check=_dimension_constraint())(inputs, attr, params)
custom_check=dimension_constraint())(inputs, attr, params)
class Absolute(OnnxOpConverter):
......@@ -118,18 +118,18 @@ class Conv(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params)
channels = infer_channels(inputs[1], params)
attr['channels'] = channels
return AttrCvt(
op_name=_dimension_picker('conv'),
op_name=dimension_picker('conv'),
transforms={
'kernel_shape': 'kernel_size',
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), _revert_caffe2_pad),
'pads': ('padding', (0, 0), revert_caffe2_pad),
'group': ('groups', 1)
},
extras={'use_bias': len(inputs) == 3},
custom_check=_dimension_constraint())(inputs, attr, params)
custom_check=dimension_constraint())(inputs, attr, params)
class ConvTranspose(OnnxOpConverter):
......@@ -137,20 +137,20 @@ class ConvTranspose(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params, True)
channels = infer_channels(inputs[1], params, True)
attr['channels'] = channels
groups = attr.pop('group')
attr['groups'] = groups
return AttrCvt(
op_name=_dimension_picker('conv', '_transpose'),
op_name=dimension_picker('conv', '_transpose'),
transforms={
'kernel_shape': 'kernel_size',
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), _revert_caffe2_pad)
'pads': ('padding', (0, 0), revert_caffe2_pad)
},
disables=['output_shape'],
extras={'use_bias': len(inputs) == 3},
custom_check=_dimension_constraint())(inputs, attr, params)
custom_check=dimension_constraint())(inputs, attr, params)
class Div(Elemwise):
......@@ -180,7 +180,7 @@ class Gemm(OnnxOpConverter):
transA = int(attr.get('transA', 0))
transB = int(attr.get('transB', 0))
# get number of channels
channels = _infer_channels(inputs[1], params, not transB)
channels = infer_channels(inputs[1], params, not transB)
if transA:
inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
if not transB:
......@@ -254,7 +254,7 @@ class Prelu(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(
len(inputs))
channels = _infer_channels(inputs[1], params, False)
channels = infer_channels(inputs[1], params, False)
if channels == 1:
return inputs[0] * inputs[1]
return _sym.broadcast_mul(inputs[0], inputs[1])
......@@ -362,17 +362,6 @@ class ImageScaler(OnnxOpConverter):
return ret
def _revert_caffe2_pad(attr):
"""Caffe2 require two times the normal padding."""
if len(attr) == 4:
attr = attr[:2]
elif len(attr) == 2:
pass
else:
raise ValueError("Invalid caffe2 type padding: {}".format(attr))
return attr
def _broadcast_constraint():
def _broadcast_check(attrs):
......@@ -383,43 +372,11 @@ def _broadcast_constraint():
return _broadcast_check, "Specifying broadcast axis not allowed."
def _dimension_picker(prefix, surfix=''):
def _impl(attr):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
raise NotImplementedError("Only 2d kernel supported.")
return _impl
def _dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
return True
return False
return _dim_check, "Only 2d kernel supported."
def _infer_channels(inputs, params, transpose=False):
"""A hack for getting 'channles' or 'units' since onnx don't provide
these attributes. We check the shape of weights provided to get the number.
"""
g = _graph.create(inputs)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
def _fully_connected(opset):
def _impl(inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params)
channels = infer_channels(inputs[1], params)
attr['units'] = channels
return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)
......
"""Util functions shared by the ONNX and Caffe2 frontends."""
from __future__ import absolute_import as _abs
from nnvm import graph as _graph
from nnvm.compiler import graph_util
def dimension_picker(prefix, surfix=''):
def _impl(attr):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
else:
raise NotImplementedError("Only 2d kernel supported.")
return _impl
def dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
return True
return False
return _dim_check, "Only 2d kernel supported."
def infer_channels(inputs, params, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 don't provide
these attributes. We check the shape of weights provided to get the number.
"""
g = _graph.create(inputs)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
def revert_caffe2_pad(pads):
"""Caffe2 require two times the normal padding."""
if len(pads) == 4:
pads = pads[:2]
elif len(pads) == 2:
pass
else:
raise ValueError("Invalid caffe2 type padding: {}".format(pads))
return pads
"""Store for caffe2 examples and common models."""
from __future__ import absolute_import as _abs
import os
import importlib
models = [
'squeezenet',
'resnet50',
'vgg19',
]
# skip download if model exist
for model in models:
try:
locals()['c2_' + model] = importlib.import_module('caffe2.python.models.' + model)
except ImportError:
os.system("python -m caffe2.python.models.download -i -f " + model)
locals()['c2_' + model] = importlib.import_module('caffe2.python.models.' + model)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=unused-argument
"""
Symbol of SqueezeNet
Reference:
Iandola, Forrest N., et al.
"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016).
"""
from nnvm import symbol as sym
from nnvm.testing.utils import create_workload
# Helpers
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
net = _make_fire_conv(net, squeeze_channels, 1, 0)
left = _make_fire_conv(net, expand1x1_channels, 1, 0)
right = _make_fire_conv(net, expand3x3_channels, 3, 1)
# NOTE : Assume NCHW layout here
net = sym.concatenate(left, right, axis=1)
return net
def _make_fire_conv(net, channels, kernel_size, padding=0):
net = sym.conv2d(net, channels=channels, kernel_size=(kernel_size, kernel_size),
padding=(padding, padding))
net = sym.relu(net)
return net
# Net
def get_symbol(num_classes, version, **kwargs):
"""Get symbol of SqueezeNet
Parameters
----------
num_classes: int
The number of classification results
version : str, optional
"1.0" or "1.1" of SqueezeNet
"""
assert version == '1.1', ("Unsupported SqueezeNet version {version}:"
"1.1 expected".format(version=version))
net = sym.Variable("data")
net = sym.conv2d(net, channels=64, kernel_size=(3, 3), strides=(2, 2))
net = sym.relu(net)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 32, 128, 128)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 64, 256, 256)
net = sym.dropout(net, rate=0.5)
net = sym.conv2d(net, channels=num_classes, kernel_size=(1, 1))
net = sym.relu(net)
net = sym.global_avg_pool2d(net)
return sym.softmax(net, axis=1)
def get_workload(batch_size=1, num_classes=1000, version='1.0',
image_shape=(3, 224, 224), dtype="float32", **kwargs):
"""Get benchmark workload for SqueezeNet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
version : str, optional
"1.0" or "1.1" of SqueezeNet
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
kwargs : dict
Extra arguments
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes, version=version, **kwargs)
return create_workload(net, batch_size, image_shape, dtype)
import numpy as np
import nnvm
import tvm
from tvm.contrib import graph_runtime
from nnvm.testing.config import ctx_list
from model_zoo import c2_squeezenet, c2_resnet50, c2_vgg19
from caffe2.python import workspace
def get_tvm_output(model,
input_data,
target,
ctx,
output_shape,
output_dtype='float32'):
""" Generic function to execute and get tvm output"""
sym, params = nnvm.frontend.from_caffe2(model.init_net, model.predict_net)
# supporting multiple inputs in caffe2 in a bit tricky,
# because the input names can appear at the beginning or end of model.predict_net.external_input
assert isinstance(input_data, np.ndarray)
# here we use the first input blob to the first op to get the input name
input_names = model.predict_net.op[0].input[0]
shape_dict = {input_names: input_data.shape}
dtype_dict = {input_names: input_data.dtype}
graph, lib, params = nnvm.compiler.build(
sym, target, shape=shape_dict, dtype=dtype_dict, params=params)
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
if isinstance(output_shape, list) and isinstance(output_dtype, list):
tvm_output_list = []
for i, s in enumerate(output_shape):
tvm_output = m.get_output(i, tvm.nd.empty((s), output_dtype[i]))
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
else:
tvm_output = m.get_output(0, tvm.nd.empty((output_shape),
output_dtype))
return tvm_output.asnumpy()
def get_caffe2_output(model, x, dtype='float32'):
workspace.RunNetOnce(model.init_net)
input_blob = model.predict_net.op[0].input[0]
workspace.FeedBlob(input_blob, x.astype(dtype))
workspace.RunNetOnce(model.predict_net)
output_blob = model.predict_net.external_output[0]
c2_output = workspace.FetchBlob(output_blob)
return c2_output
def verify_caffe2_forward_impl(model, data_shape, out_shape):
dtype = 'float32'
data = np.random.uniform(size=data_shape).astype(dtype)
c2_out = get_caffe2_output(model, data, dtype)
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, data, target, ctx, out_shape, dtype)
tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
def verify_squeezenet1_1():
verify_caffe2_forward_impl(c2_squeezenet, (1, 3, 224, 224),
(1, 1000, 1, 1))
def verify_resnet50():
verify_caffe2_forward_impl(c2_resnet50, (1, 3, 224, 224),
(1, 1000))
def verify_vgg19():
verify_caffe2_forward_impl(c2_vgg19, (1, 3, 224, 224), (1, 1000))
if __name__ == '__main__':
verify_squeezenet1_1()
verify_resnet50()
verify_vgg19()
"""Test graph equality of caffe2 models."""
import nnvm
from nnvm.compiler import graph_util, graph_attr
from model_zoo import c2_squeezenet, squeezenet
def compare_graph(init, predict, nnvm_sym, ishape):
caffe2_sym, params = nnvm.frontend.from_caffe2(init, predict)
g1 = nnvm.graph.create(caffe2_sym)
g2 = nnvm.graph.create(nnvm_sym)
input_name = predict.external_input[0]
ishapes = {input_name: ishape}
graph_attr.set_shape_inputs(g1, ishapes)
graph_attr.set_shape_inputs(g2, ishapes)
g1 = g1.apply("InferShape").apply("SimplifyInference")
g2 = g2.apply("InferShape").apply("SimplifyInference")
graph_util.check_graph_equal(g1, g2)
def test_squeeze_net():
symbol, params = squeezenet.get_workload(version='1.1')
compare_graph(c2_squeezenet.init_net, c2_squeezenet.predict_net, symbol, ishape=(1, 3, 224, 224))
if __name__ == '__main__':
test_squeeze_net()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment