Commit 373a8caa by Siva Committed by Tianqi Chen

[NNVM][TENSORFLOW] Mobilenet support. (#1335)

parent ca2ad6d4
......@@ -35,6 +35,11 @@ class AttrCvt(object):
self._ignores.append('use_cudnn_on_gpu')
self._ignores.append('_node_name')
self._ignores.append('is_training')
# Retain the names
try:
attrs['name'] = attrs['_node_name']
except KeyError:
pass
return AttrConvert(self._op_name, self._transforms, self._excludes,
self._disables, self._ignores, self._extras,
self._custom_check)(inputs, attrs, *args)
......@@ -405,13 +410,19 @@ def _concat():
def _reshape():
def _impl(inputs, attr, params):
pop_node = inputs.pop(1)
shape_arg = params[pop_node.list_output_names()[0]]
params.pop(pop_node.list_output_names()[0])
return AttrCvt(
op_name="reshape",
extras={'shape':tuple(shape_arg.asnumpy())},
ignores=['Tshape'])(inputs, attr)
try:
pop_node = inputs[1]
shape_arg = params.pop(pop_node.list_output_names()[0])
inputs.pop(1)
return AttrCvt(
op_name="reshape",
extras={'shape':tuple(shape_arg.asnumpy())},
ignores=['Tshape'])(inputs, attr)
except KeyError:
return AttrCvt(
op_name="reshape_like",
ignores=['Tshape'])(inputs, attr)
return _impl
def _bias_add():
......@@ -427,6 +438,18 @@ def _squeeze():
ignores=['T'])(inputs, attr)
return _impl
def _fused_batch_norm():
def _impl(inputs, attr, params):
# Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
# NNVM: (data, gamma, beta, moving_mean, moving_varience)
return AttrCvt(
op_name='batch_norm',
transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
extras={'axis': 3}, # Fix axis
ignores=['data_format'],
disables=['momentum'])(inputs, attr)
return _impl
def _batch_norm():
def _impl(inputs, attr, params):
# Rearrange inputs from
......@@ -445,19 +468,14 @@ def _batch_norm():
def _relu6():
def _impl(inputs, attr, params):
return _sym.clip(inputs[0], a_min=0, a_max=6)
return _sym.clip(inputs[0], a_min=0, a_max=6, name=attr['_node_name'])
return _impl
def _shape():
def _impl(inputs, attr, params):
input_shapes = attr['_input_shapes'][inputs[0]]
# Fix the -1 dimensions to 1
input_shapes[0] = [1 if x == -1 else x for x in input_shapes[0]]
params[attr['_node_name']] = tvm.nd.array(input_shapes[0])
return _sym.Variable(name=attr['_node_name'],
shape=params[attr['_node_name']].shape)
# Result of this operator is prominently used by reshape operator.
# Just pass the input as it is so that reshape_like can be used there.
return inputs[0]
return _impl
# compatible operators that do NOT require any conversion.
......@@ -491,7 +509,7 @@ _convert_map = {
'Add' : _elemwise('add'),
'Rsqrt' : _rsqrt(),
'Squeeze' : _squeeze(),
'FusedBatchNorm' : _batch_norm(),
'FusedBatchNorm' : _fused_batch_norm(),
'Relu6' : _relu6(),
'DepthwiseConv2dNative' : _depthwise_conv(),
'Shape' : _shape(),
......
......@@ -153,6 +153,35 @@ def read_normalized_tensor_from_image_file(file_name,
np_array = normalized.eval()
return np_array
def get_workload(model_path):
""" Import workload from frozen protobuf
Parameters
----------
model_path: str
model_path on remote repository to download from.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for mobilenet.
"""
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
model_name = os.path.basename(model_path)
model_url = os.path.join(repo_base, model_path)
from mxnet.gluon.utils import download
download(model_url, model_name)
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return graph_def
def get_workload_inception_v3():
""" Import Inception V3 workload from frozen protobuf
......@@ -168,23 +197,15 @@ def get_workload_inception_v3():
"""
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV3/'
model_name = 'inception_v3_2016_08_28_frozen-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
model_path = 'InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb'
image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)
from mxnet.gluon.utils import download
download(model_url, model_name)
download(image_url, image_name)
normalized = read_normalized_tensor_from_image_file(os.path.join("./", image_name))
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return (normalized, graph_def)
return (normalized, get_workload(model_path))
def get_workload_inception_v1():
""" Import Inception V1 workload from frozen protobuf
......@@ -203,13 +224,11 @@ def get_workload_inception_v1():
"""
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
model_path = 'InceptionV1/classify_image_graph_def-with_shapes.pb'
image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)
from mxnet.gluon.utils import download
download(model_url, model_name)
download(image_url, image_name)
if not tf.gfile.Exists(os.path.join("./", image_name)):
......@@ -221,9 +240,20 @@ def get_workload_inception_v1():
tvm_data = Image.open(os.path.join("./", image_name)).resize((299, 299))
tvm_data = np.array(tvm_data)
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return (image_data, tvm_data, graph_def)
return (image_data, tvm_data, get_workload(model_path))
def get_workload_mobilenet():
""" Import mobilenet workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for mobilenet.
"""
return get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb")
......@@ -407,6 +407,29 @@ def test_forward_inception_v1():
np.testing.assert_allclose(tf_output, tvm_output, rtol=2e-2, atol=2e-2)
#######################################################################
# Mobilenet
# ---------
def test_forward_mobilenet():
'''test mobilenet model'''
with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload_mobilenet()
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
out_node = 'MobilenetV1/Predictions/Reshape_1'
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
out_shape = tf_output.shape
tvm_output = run_tvm_graph(graph_def, data, 'input', out_shape, 'float32')
top_tvm = np.squeeze(tvm_output).argsort()[-10:][::-1]
top_tf = np.squeeze(tf_output).argsort()[-10:][::-1]
np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
#######################################################################
# Main
# ----
if __name__ == '__main__':
......@@ -419,3 +442,4 @@ if __name__ == '__main__':
test_forward_multi_input()
test_forward_inception_v3()
test_forward_inception_v1()
test_forward_mobilenet()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment