Commit ade98e14 by hlu1 Committed by Tianqi Chen

[nnvm] Add caffe2 frontend (#1981)

parent c5e1da93
......@@ -6,3 +6,4 @@ from .coreml import from_coreml
from .keras import from_keras
from .darknet import from_darknet
from .tensorflow import from_tensorflow
from .caffe2 import from_caffe2
# pylint: disable=import-self, invalid-name, line-too-long, unused-argument
"""Caffe2 frontend"""
from __future__ import absolute_import as _abs
import tvm
from nnvm import symbol as _sym
from nnvm.frontend.common import get_nnvm_op, Renamer, AttrConverter as AttrCvt
from .onnx_caffe2_utils import dimension_picker, dimension_constraint, infer_channels, revert_caffe2_pad
from . import onnx
__all__ = ['from_caffe2']
def _clean_up_pool_args(args):
""" A helper function to clean up common arguments in conv and pooling ops.
"""
assert isinstance(args, dict)
if 'stride_h' in args and 'stride_w' in args:
assert 'stride' not in args and 'strides' not in args
args['strides'] = [args['stride_h'], args['stride_w']]
args.pop('stride_h')
args.pop('stride_w')
elif 'stride' in args:
args['strides'] = [args['stride'], args['stride']]
args.pop('stride')
# rename 'kernel', 'kernels', to 'kernel_shape'
if 'kernel_h' in args and 'kernel_w' in args:
assert 'kernel' not in args and 'kernels' not in args
args['kernel_shape'] = [args['kernel_h'], args['kernel_w']]
args.pop('kernel_h')
args.pop('kernel_w')
elif 'kernel' in args:
args['kernel_shape'] = [args['kernel'], args['kernel']]
args.pop('kernel')
elif 'kernels' in args:
args['kernel_shape'] = args['kernels']
args.pop('kernels')
if 'pad_t' in args and 'pad_l' in args and 'pad_b' in args and 'pad_r' in args:
assert 'pad' not in args and 'pads' not in args
args['pads'] = [
args['pad_t'], args['pad_l'], args['pad_b'], args['pad_r']
]
for pad in ['pad_t', 'pad_l', 'pad_b', 'pad_r']:
args.pop(pad)
elif 'pad' in args:
args['pads'] = [args['pad'], args['pad']]
args.pop('pad')
if 'dilation_h' in args and 'dilation_w' in args:
assert 'dilation' not in args and 'dilations' not in args
args['dilations'] = [args['dilation_h'], args['dilation_w']]
args.pop('dilation_h')
args.pop('dilation_w')
elif 'dilation' in args:
args['dilations'] = [args['dilation'], args['dilation']]
args.pop('dilation')
return args
class Caffe2OpConverter(object):
""" A helper class for holding Caffe2 op converters.
"""
@classmethod
def get_converter(cls):
""" Get converter.
:return: converter, which should be `_impl`.
"""
if hasattr(cls, '_impl'):
return getattr(cls, '_impl')
else:
raise NotImplementedError('{} not implemented'.format(
cls.__name__))
_caffe2_internal_args = {
# nnpack args
'algo',
'convolution_transform_strategy',
'float16_compute',
'shared_buffer',
# training args
'init_params',
'cudnn_exhaustive_search',
'exhaustive_search',
# training args
'adj',
'hwgq',
# args that we don't care
'legacy_pad',
}
class Pool(Caffe2OpConverter):
""" A helper class for pool op converters.
"""
name = ''
@classmethod
def _impl(cls, inputs, args, params):
_clean_up_pool_args(args)
if 'global_pooling' in args and args['global_pooling'] == 1:
op_name = dimension_picker('global_' + cls.name)
return get_nnvm_op(op_name(args))(*inputs)
return AttrCvt(
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), revert_caffe2_pad),
'strides': 'strides',
},
excludes={
# TVM poolop does not support dilation
'dilations',
},
ignores=_caffe2_internal_args | {'global_pooling', 'order'},
custom_check=dimension_constraint())(inputs, args, params)
class AveragePool(Pool):
name = 'avg_pool'
class MaxPool(Pool):
name = 'max_pool'
class Conv(Caffe2OpConverter):
""" Operator converter for Conv.
"""
@classmethod
def _impl(cls, inputs, args, params):
# get number of channels
channels = infer_channels(inputs[1], params)
args['channels'] = channels
_clean_up_pool_args(args)
return AttrCvt(
op_name=dimension_picker('conv'),
transforms={
'group': ('groups', 1),
'kernel_shape':
'kernel_size',
'pads': ('padding', (0, 0), revert_caffe2_pad),
'strides':
'strides',
'dilations': ('dilation', (1, 1)),
'order':
('layout', ("NCHW"),
lambda x: x if isinstance(x, str) else x.decode('UTF-8')),
},
excludes={},
ignores=_caffe2_internal_args,
extras={'use_bias': len(inputs) == 3},
custom_check=dimension_constraint())(inputs, args, params)
class Concat(Caffe2OpConverter):
""" Operator converter for Concat.
"""
@classmethod
def _impl(cls, inputs, args, params):
def _get_axis_from_order_str(order):
order = order if isinstance(order, str) else order.decode('UTF-8')
if order == 'NCHW':
return 1
elif order == 'NHWC':
return 3
else:
raise RuntimeError(
"Unsupported storage order: {} in caffe2".format(order))
return AttrCvt(
op_name='concatenate',
transforms={
'order': ('axis', (1), _get_axis_from_order_str),
},
excludes={
'add_axis',
})(inputs, args, params)
class NormalizePlanarYUV(Caffe2OpConverter):
""" Operator converter for NormalizePlanarYUV.
caffe2 definition: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/norm_planar_yuv_op.cc
"""
@classmethod
def _impl(cls, inputs, args, params):
assert len(inputs) == 3
mean = _sym.expand_dims(inputs[1], axis=2, num_newaxis=2)
std = _sym.expand_dims(inputs[2], axis=2, num_newaxis=2)
return _sym.broadcast_div(_sym.broadcast_sub(inputs[0], mean), std)
class ResizeNearest(Caffe2OpConverter):
""" Operator converter for Upsample (nearest mode).
"""
@classmethod
def _impl(cls, inputs, args, params):
width_scale = args['width_scale'] if 'width_scale' in args else 1
height_scale = args['height_scale'] if 'height_scale' in args else 1
assert width_scale == height_scale
return _sym.upsampling(
inputs[0], scale=int(width_scale), method="NEAREST_NEIGHBOR")
class FC(Caffe2OpConverter):
""" Operator converter for FC.
"""
@classmethod
def _impl(cls, inputs, args, params):
inputs[0] = _sym.flatten(inputs[0])
args['units'] = infer_channels(inputs[1], params)
return AttrCvt(
'dense',
ignores=['axis', 'axis_w'],
extras={'use_bias': len(inputs) == 3},
)(inputs, args, params)
class SpatialBN(Caffe2OpConverter):
""" Operator converter for SpatialBN.
"""
@classmethod
def _impl(cls, inputs, args, params):
return AttrCvt(
op_name='batch_norm',
disables=['momentum'],
ignores=[
'order', 'spatial', 'is_test', 'consumed_inputs', 'num_batches'
])(inputs, args, params)
# compatible operators that do NOT require any conversion.
_identity_list = []
# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
# Minimal set of ops for squeezenet and resnet50
def _get_convert_map():
return {
# caffe2/onnx common operators
'Add': onnx.Add.get_converter(opset=1),
'Sum': onnx.Sum.get_converter(opset=1),
'Softmax': onnx.Softmax.get_converter(opset=1),
# nn
'AveragePool': AveragePool.get_converter(),
'MaxPool': MaxPool.get_converter(),
'Conv': Conv.get_converter(),
'Concat': Concat.get_converter(),
'FC': FC.get_converter(),
'SpatialBN': SpatialBN.get_converter(),
'ResizeNearest': ResizeNearest.get_converter(),
'Relu': AttrCvt('relu', {}, ignores=['order']),
'Sigmoid': Renamer('sigmoid'),
'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
# c2 image preprocessing ops
'NormalizePlanarYUV': NormalizePlanarYUV.get_converter(),
}
class Caffe2NetDef(object):
"""A helper class for handling nnvm graph copying from pb2.GraphProto.
Definition: https://github.com/pytorch/pytorch/blob/master/caffe2/proto/caffe2.proto
"""
def __init__(self):
self._nodes = {}
self._params = {}
self._visited_nodes = set()
self._ops = {}
def from_caffe2(self, init_net, predict_net):
"""Construct nnvm nodes from caffe2 graph.
Parameters
----------
workspace : Caffe2 workspace
predict_net : protobuf object
Returns
-------
sym : nnvm.sym.Symbol
The returned nnvm symbol
params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
from caffe2.python import workspace
workspace.RunNetOnce(init_net)
# Input
input_name = predict_net.op[0].input[0]
# Params
self._params = {}
used_blobs = set()
for c2_op in predict_net.op:
for i in c2_op.input:
used_blobs.add(i)
for blob in workspace.Blobs():
if blob in used_blobs and blob != input_name:
self._params[blob] = tvm.nd.array(workspace.FetchBlob(blob))
# Variables
self._nodes = {}
for blob in predict_net.external_input:
self._nodes[blob] = _sym.Variable(name=blob)
# Ops
for c2_op in predict_net.op:
for blob in c2_op.output:
self._ops[blob] = c2_op
for c2_op in predict_net.op:
self._process_op(c2_op)
# Outputs
out = []
for blob in predict_net.external_output:
out.append(self._nodes[blob])
if len(out) > 1:
sym = _sym.Group(out)
else:
sym = out[0]
return sym, self._params
def _get_node(self, blob):
"""Get the nnvm Symbol of blob and detect cyclic dependency in the graph."""
if blob in self._nodes:
return self._nodes[blob]
assert blob not in self._visited_nodes, 'Cyclic dependency in the graph (in {})'.format(
blob)
self._visited_nodes.add(blob)
self._process_op(self._ops[blob])
return self._nodes[blob]
def _process_op(self, c2_op):
op_type = c2_op.type
args = self._parse_arg(c2_op.arg)
inputs = [self._get_node(i) for i in c2_op.input]
tvm_op = self._convert_operator(op_type, inputs, args)
# Ignore all outputs except the first one
self._nodes[c2_op.output[0]] = tvm_op[0]
def _parse_arg(self, arg):
"""Convert a list of Argument to a dict, with names as keys."""
args = {}
for a in arg:
for f in ['f', 'i', 's']:
if a.HasField(f):
args[a.name] = getattr(a, f)
for f in ['floats', 'ints', 'strings']:
if list(getattr(a, f)):
assert a.name not in args, "Only one type of attr is allowed"
args[a.name] = tuple(getattr(a, f))
for f in ['n']:
if a.HasField(f):
raise NotImplementedError(
"Field {} is not supported in nnvm.".format(f))
for f in ['nets']:
if list(getattr(a, f)):
raise NotImplementedError(
"Field {} is not supported in nnvm.".format(f))
if a.name not in args:
raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
return args
def _convert_operator(self,
op_type,
inputs,
args,
identity_list=None,
convert_map=None):
"""Convert from Caffe2 operator to nnvm operator.
The converter must specify conversions explicity for incompatible name, and
apply handlers to operator attributes.
Parameters
----------
op_type : str
Operator name, such as Convolution, FullyConnected
inputs : list of nnvm.Symbol
List of input symbols.
args : dict
Dict of operator attributes
identity_list : list
List of operators that don't require conversion
convert_map : dict
Dict of name : callable, where name is the op's name that
require conversion to nnvm, callable are functions which
take args and return (new_op_type, new_args)
Returns
-------
sym : nnvm.Symbol
Converted nnvm Symbol
"""
identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _get_convert_map()
if op_type in identity_list:
sym = get_nnvm_op(op_type)(*inputs, **args)
elif op_type in convert_map:
# Add a sanitizing step to convert all byte strings in args to strings
sym = convert_map[op_type](inputs, args, self._params)
else:
raise NotImplementedError(
"Operator {} not implemented.".format(op_type))
return sym
def from_caffe2(init_net, predict_net):
"""Load caffe2 graph which contains init_net and predict_net into nnvm graph.
Parameters
----------
init_net : protobuf object
Caffe2 NetDef containing the weights
predict_net : protobuf object
Caffe2 NetDef containing the graph
Returns
-------
sym : nnvm.Symbol
Compatible nnvm symbol
params : dict of str to tvm.ndarray
Dict of converted parameters stored in tvm.ndarray format
"""
caffe2 = Caffe2NetDef()
return caffe2.from_caffe2(init_net, predict_net)
......@@ -4,9 +4,9 @@ from __future__ import absolute_import as _abs
import numpy as np
import tvm
from .. import symbol as _sym
from .. import graph as _graph
from ..compiler import graph_util
from .common import get_nnvm_op, Renamer, SymbolTable, AttrConverter as AttrCvt
from .onnx_caffe2_utils import dimension_picker, dimension_constraint, \
infer_channels, revert_caffe2_pad
__all__ = ['from_onnx']
......@@ -74,16 +74,16 @@ class Pool(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
return AttrCvt(
op_name=_dimension_picker(cls.name),
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), _revert_caffe2_pad)
'pads': ('padding', (0, 0), revert_caffe2_pad)
},
# very weird attributes here in onnx, force check
ignores=['dilations'],
# TODO(zhreshold): make sure ceil_mode in onnx, and layout?
extras={'ceil_mode': False},
custom_check=_dimension_constraint())(inputs, attr, params)
custom_check=dimension_constraint())(inputs, attr, params)
class Absolute(OnnxOpConverter):
......@@ -118,18 +118,18 @@ class Conv(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params)
channels = infer_channels(inputs[1], params)
attr['channels'] = channels
return AttrCvt(
op_name=_dimension_picker('conv'),
op_name=dimension_picker('conv'),
transforms={
'kernel_shape': 'kernel_size',
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), _revert_caffe2_pad),
'pads': ('padding', (0, 0), revert_caffe2_pad),
'group': ('groups', 1)
},
extras={'use_bias': len(inputs) == 3},
custom_check=_dimension_constraint())(inputs, attr, params)
custom_check=dimension_constraint())(inputs, attr, params)
class ConvTranspose(OnnxOpConverter):
......@@ -137,20 +137,20 @@ class ConvTranspose(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params, True)
channels = infer_channels(inputs[1], params, True)
attr['channels'] = channels
groups = attr.pop('group')
attr['groups'] = groups
return AttrCvt(
op_name=_dimension_picker('conv', '_transpose'),
op_name=dimension_picker('conv', '_transpose'),
transforms={
'kernel_shape': 'kernel_size',
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), _revert_caffe2_pad)
'pads': ('padding', (0, 0), revert_caffe2_pad)
},
disables=['output_shape'],
extras={'use_bias': len(inputs) == 3},
custom_check=_dimension_constraint())(inputs, attr, params)
custom_check=dimension_constraint())(inputs, attr, params)
class Div(Elemwise):
......@@ -180,7 +180,7 @@ class Gemm(OnnxOpConverter):
transA = int(attr.get('transA', 0))
transB = int(attr.get('transB', 0))
# get number of channels
channels = _infer_channels(inputs[1], params, not transB)
channels = infer_channels(inputs[1], params, not transB)
if transA:
inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
if not transB:
......@@ -254,7 +254,7 @@ class Prelu(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(
len(inputs))
channels = _infer_channels(inputs[1], params, False)
channels = infer_channels(inputs[1], params, False)
if channels == 1:
return inputs[0] * inputs[1]
return _sym.broadcast_mul(inputs[0], inputs[1])
......@@ -362,17 +362,6 @@ class ImageScaler(OnnxOpConverter):
return ret
def _revert_caffe2_pad(attr):
"""Caffe2 require two times the normal padding."""
if len(attr) == 4:
attr = attr[:2]
elif len(attr) == 2:
pass
else:
raise ValueError("Invalid caffe2 type padding: {}".format(attr))
return attr
def _broadcast_constraint():
def _broadcast_check(attrs):
......@@ -383,43 +372,11 @@ def _broadcast_constraint():
return _broadcast_check, "Specifying broadcast axis not allowed."
def _dimension_picker(prefix, surfix=''):
def _impl(attr):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
raise NotImplementedError("Only 2d kernel supported.")
return _impl
def _dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
return True
return False
return _dim_check, "Only 2d kernel supported."
def _infer_channels(inputs, params, transpose=False):
"""A hack for getting 'channles' or 'units' since onnx don't provide
these attributes. We check the shape of weights provided to get the number.
"""
g = _graph.create(inputs)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
def _fully_connected(opset):
def _impl(inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params)
channels = infer_channels(inputs[1], params)
attr['units'] = channels
return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)
......
"""Util functions shared by the ONNX and Caffe2 frontends."""
from __future__ import absolute_import as _abs
from nnvm import graph as _graph
from nnvm.compiler import graph_util
def dimension_picker(prefix, surfix=''):
def _impl(attr):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
else:
raise NotImplementedError("Only 2d kernel supported.")
return _impl
def dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
return True
return False
return _dim_check, "Only 2d kernel supported."
def infer_channels(inputs, params, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 don't provide
these attributes. We check the shape of weights provided to get the number.
"""
g = _graph.create(inputs)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
def revert_caffe2_pad(pads):
"""Caffe2 require two times the normal padding."""
if len(pads) == 4:
pads = pads[:2]
elif len(pads) == 2:
pass
else:
raise ValueError("Invalid caffe2 type padding: {}".format(pads))
return pads
"""Store for caffe2 examples and common models."""
from __future__ import absolute_import as _abs
import os
import importlib
models = [
'squeezenet',
'resnet50',
'vgg19',
]
# skip download if model exist
for model in models:
try:
locals()['c2_' + model] = importlib.import_module('caffe2.python.models.' + model)
except ImportError:
os.system("python -m caffe2.python.models.download -i -f " + model)
locals()['c2_' + model] = importlib.import_module('caffe2.python.models.' + model)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=unused-argument
"""
Symbol of SqueezeNet
Reference:
Iandola, Forrest N., et al.
"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016).
"""
from nnvm import symbol as sym
from nnvm.testing.utils import create_workload
# Helpers
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
net = _make_fire_conv(net, squeeze_channels, 1, 0)
left = _make_fire_conv(net, expand1x1_channels, 1, 0)
right = _make_fire_conv(net, expand3x3_channels, 3, 1)
# NOTE : Assume NCHW layout here
net = sym.concatenate(left, right, axis=1)
return net
def _make_fire_conv(net, channels, kernel_size, padding=0):
net = sym.conv2d(net, channels=channels, kernel_size=(kernel_size, kernel_size),
padding=(padding, padding))
net = sym.relu(net)
return net
# Net
def get_symbol(num_classes, version, **kwargs):
"""Get symbol of SqueezeNet
Parameters
----------
num_classes: int
The number of classification results
version : str, optional
"1.0" or "1.1" of SqueezeNet
"""
assert version == '1.1', ("Unsupported SqueezeNet version {version}:"
"1.1 expected".format(version=version))
net = sym.Variable("data")
net = sym.conv2d(net, channels=64, kernel_size=(3, 3), strides=(2, 2))
net = sym.relu(net)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 32, 128, 128)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 64, 256, 256)
net = sym.dropout(net, rate=0.5)
net = sym.conv2d(net, channels=num_classes, kernel_size=(1, 1))
net = sym.relu(net)
net = sym.global_avg_pool2d(net)
return sym.softmax(net, axis=1)
def get_workload(batch_size=1, num_classes=1000, version='1.0',
image_shape=(3, 224, 224), dtype="float32", **kwargs):
"""Get benchmark workload for SqueezeNet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
version : str, optional
"1.0" or "1.1" of SqueezeNet
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
kwargs : dict
Extra arguments
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes, version=version, **kwargs)
return create_workload(net, batch_size, image_shape, dtype)
import numpy as np
import nnvm
import tvm
from tvm.contrib import graph_runtime
from nnvm.testing.config import ctx_list
from model_zoo import c2_squeezenet, c2_resnet50, c2_vgg19
from caffe2.python import workspace
def get_tvm_output(model,
input_data,
target,
ctx,
output_shape,
output_dtype='float32'):
""" Generic function to execute and get tvm output"""
sym, params = nnvm.frontend.from_caffe2(model.init_net, model.predict_net)
# supporting multiple inputs in caffe2 in a bit tricky,
# because the input names can appear at the beginning or end of model.predict_net.external_input
assert isinstance(input_data, np.ndarray)
# here we use the first input blob to the first op to get the input name
input_names = model.predict_net.op[0].input[0]
shape_dict = {input_names: input_data.shape}
dtype_dict = {input_names: input_data.dtype}
graph, lib, params = nnvm.compiler.build(
sym, target, shape=shape_dict, dtype=dtype_dict, params=params)
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
if isinstance(output_shape, list) and isinstance(output_dtype, list):
tvm_output_list = []
for i, s in enumerate(output_shape):
tvm_output = m.get_output(i, tvm.nd.empty((s), output_dtype[i]))
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
else:
tvm_output = m.get_output(0, tvm.nd.empty((output_shape),
output_dtype))
return tvm_output.asnumpy()
def get_caffe2_output(model, x, dtype='float32'):
workspace.RunNetOnce(model.init_net)
input_blob = model.predict_net.op[0].input[0]
workspace.FeedBlob(input_blob, x.astype(dtype))
workspace.RunNetOnce(model.predict_net)
output_blob = model.predict_net.external_output[0]
c2_output = workspace.FetchBlob(output_blob)
return c2_output
def verify_caffe2_forward_impl(model, data_shape, out_shape):
dtype = 'float32'
data = np.random.uniform(size=data_shape).astype(dtype)
c2_out = get_caffe2_output(model, data, dtype)
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, data, target, ctx, out_shape, dtype)
tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
def verify_squeezenet1_1():
verify_caffe2_forward_impl(c2_squeezenet, (1, 3, 224, 224),
(1, 1000, 1, 1))
def verify_resnet50():
verify_caffe2_forward_impl(c2_resnet50, (1, 3, 224, 224),
(1, 1000))
def verify_vgg19():
verify_caffe2_forward_impl(c2_vgg19, (1, 3, 224, 224), (1, 1000))
if __name__ == '__main__':
verify_squeezenet1_1()
verify_resnet50()
verify_vgg19()
"""Test graph equality of caffe2 models."""
import nnvm
from nnvm.compiler import graph_util, graph_attr
from model_zoo import c2_squeezenet, squeezenet
def compare_graph(init, predict, nnvm_sym, ishape):
caffe2_sym, params = nnvm.frontend.from_caffe2(init, predict)
g1 = nnvm.graph.create(caffe2_sym)
g2 = nnvm.graph.create(nnvm_sym)
input_name = predict.external_input[0]
ishapes = {input_name: ishape}
graph_attr.set_shape_inputs(g1, ishapes)
graph_attr.set_shape_inputs(g2, ishapes)
g1 = g1.apply("InferShape").apply("SimplifyInference")
g2 = g2.apply("InferShape").apply("SimplifyInference")
graph_util.check_graph_equal(g1, g2)
def test_squeeze_net():
symbol, params = squeezenet.get_workload(version='1.1')
compare_graph(c2_squeezenet.init_net, c2_squeezenet.predict_net, symbol, ishape=(1, 3, 224, 224))
if __name__ == '__main__':
test_squeeze_net()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment