Commit 881a78b3 by Yizhi Liu Committed by Haichen Shen

[Relay][Frontend] CoreML Support (#2476)

* [Relay][Frontend] Add CoreML Support

* pip install six in CI

* remove triggering nnvm coreml test

* set opt_level=2 for nnvm coreml test case
parent 5a30a22c
# install libraries for python package on ubuntu
pip2 install nose pylint numpy nose-timer cython decorator scipy tornado typing antlr4-python2-runtime attrs
pip3 install nose pylint numpy nose-timer cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs
pip2 install nose pylint six numpy nose-timer cython decorator scipy tornado typing antlr4-python2-runtime attrs
pip3 install nose pylint six numpy nose-timer cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs
......@@ -68,14 +68,15 @@ def ConvolutionLayerParams(op, insym, symtab):
else:
pos = [insym, weights]
if op.isDeconvolution:
ret = _sym.conv2d_transpose(*pos, **params)
else:
ret = _sym.conv2d(*pos, **params)
# consume padding layer
if symtab.in_padding:
params['padding'] = [sum(x) for x in zip(params.get('padding', [0, 0]), symtab.paddings)]
symtab.clear_padding()
if op.isDeconvolution:
ret = _sym.conv2d_transpose(*pos, **params)
else:
ret = _sym.conv2d(*pos, **params)
return ret
def BatchnormLayerParams(op, insym, symtab):
......
import urllib
from six.moves import urllib
import os
from PIL import Image
import numpy as np
......@@ -7,7 +7,7 @@ def download(url, path, overwrite=False):
if os.path.exists(path) and not overwrite:
return
print('Downloading {} to {}.'.format(url, path))
urllib.URLopener().retrieve(url, path)
urllib.request.urlretrieve(url, path)
def get_mobilenet():
url = 'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
......
......@@ -15,9 +15,9 @@ import coremltools as cm
import model_zoo
def get_tvm_output(symbol, x, params, target, ctx,
out_shape=(1000,), input_name='image', dtype='float32'):
out_shape=(1, 1000), input_name='image', dtype='float32'):
shape_dict = {input_name : x.shape}
with nnvm.compiler.build_config(opt_level=3):
with nnvm.compiler.build_config(opt_level=2):
graph, lib, params = nnvm.compiler.build(symbol, target, shape_dict, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
......@@ -28,7 +28,7 @@ def get_tvm_output(symbol, x, params, target, ctx,
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
return out.asnumpy()
def test_model_checkonly(model_file, model_name=''):
def run_model_checkonly(model_file, model_name=''):
model = cm.models.MLModel(model_file)
sym, params = nnvm.frontend.from_coreml(model)
x = model_zoo.get_cat_image()
......@@ -38,11 +38,11 @@ def test_model_checkonly(model_file, model_name=''):
def test_mobilenet_checkonly():
model_file = model_zoo.get_mobilenet()
test_model_checkonly(model_file, 'mobilenet')
run_model_checkonly(model_file, 'mobilenet')
def test_resnet50_checkonly():
model_file = model_zoo.get_resnet50()
test_model_checkonly(model_file, 'resnet50')
run_model_checkonly(model_file, 'resnet50')
def run_tvm_graph(graph_def, input_data, input_name, output_shape, output_dtype='float32'):
""" Generic function to compile on nnvm and execute on tvm """
......
......@@ -231,7 +231,7 @@ class Function(Expr):
_make.Function, params, body, ret_type, type_params, attrs)
def __call__(self, *args):
"""Invoke the gobal function.
"""Invoke the global function.
Parameters
----------
......
......@@ -11,3 +11,4 @@ from .mxnet import from_mxnet
from .keras import from_keras
from .onnx import from_onnx
from .tflite import from_tflite
from .coreml import from_coreml
......@@ -240,6 +240,7 @@ class ExprTable(object):
self.exprs = {}
self.params = {}
self.const_ctr = 1
self.in_padding = False
def new_const(self, value, shape=None, dtype="float32"):
name = "_param_%d" % (self.const_ctr)
......@@ -257,6 +258,13 @@ class ExprTable(object):
assert isinstance(expr, _expr.Expr)
self.exprs[name] = expr
def set_padding(self, paddings):
self.paddings = paddings
self.in_padding = True
def clear_padding(self):
self.in_padding = False
class AttrCvt(object):
"""Common attribute conveter. An AttrConverter instance is a callable:
......
......@@ -625,7 +625,7 @@ def keras_op_to_relay(inexpr, keras_layer, outname, etab):
etab.set_expr(name, out)
def from_keras(model, shape_dict):
def from_keras(model, shape=None):
"""Convert keras model to relay Function.
Parameters
......@@ -633,8 +633,8 @@ def from_keras(model, shape_dict):
model : keras.engine.training.Model
The keras model to be converted.
shape_dict : dict of str to int list/tuple
Input shapes of the model.
shape: dict of str to int list/tuple
Input shapes of the model, optional
Returns
-------
......@@ -642,7 +642,7 @@ def from_keras(model, shape_dict):
Compatible relay Function.
params : dict of str to tvm.NDArray
The parameter dict to be used by relay.
The parameter dict to be used by Relay.
"""
try:
import keras
......@@ -659,8 +659,8 @@ def from_keras(model, shape_dict):
for keras_layer in model.layers:
if isinstance(keras_layer, keras.engine.InputLayer):
input_name = keras_layer.name
shape = shape_dict[input_name] if input_name in shape_dict else None
etab.set_expr(input_name, _expr.var(input_name, shape=shape))
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, _expr.var(input_name, shape=input_shape))
else:
inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \
......
from six.moves import urllib
import os
from PIL import Image
import numpy as np
def download(url, path, overwrite=False):
if os.path.exists(path) and not overwrite:
return
print('Downloading {} to {}.'.format(url, path))
urllib.request.urlretrieve(url, path)
def get_mobilenet():
url = 'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
dst = 'mobilenet.mlmodel'
real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
download(url, real_dst)
return os.path.abspath(real_dst)
def get_resnet50():
url = 'https://docs-assets.developer.apple.com/coreml/models/Resnet50.mlmodel'
dst = 'resnet50.mlmodel'
real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
download(url, real_dst)
return os.path.abspath(real_dst)
def get_cat_image():
url = 'https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png'
dst = 'cat.png'
real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
download(url, real_dst)
img = Image.open(real_dst).resize((224, 224))
img = np.transpose(img, (2, 0, 1))[np.newaxis, :]
return np.asarray(img)
\ No newline at end of file
......@@ -27,6 +27,9 @@ python3 -m nose -v nnvm/tests/python/frontend/keras || exit -1
echo "Running nnvm Tensorflow frontend test..."
python3 -m nose -v nnvm/tests/python/frontend/tensorflow || exit -1
echo "Running nnvm CoreML frontend test..."
python3 -m nose -v nnvm/tests/python/frontend/coreml || exit -1
echo "Running relay MXNet frontend test..."
python3 -m nose -v tests/python/frontend/mxnet || exit -1
......@@ -36,6 +39,9 @@ python3 -m nose -v tests/python/frontend/keras || exit -1
echo "Running relay ONNX frondend test..."
python3 -m nose -v tests/python/frontend/onnx || exit -1
echo "Running relay CoreML frondend test..."
python3 -m nose -v tests/python/frontend/coreml || exit -1
echo "Running nnvm to relay frontend test..."
python3 -m nose -v tests/python/frontend/nnvm_to_relay || exit -1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment