Commit fdf795a0 by Siva, committed by Tianqi Chen

[FRONTEND][TENSORFLOW] GPU support for tensorflow models. (#1718)

parent ae5a28db
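In short, the frontend gains a layout argument so a TensorFlow NHWC graph can be converted to NCHW and compiled for CUDA. A minimal usage sketch based on the tutorial changes in this commit; graph_def, shape_dict and dtype_dict are assumed to have been prepared as in the tutorial below:

# --- illustrative sketch, not part of the commit ---
import nnvm.compiler
import nnvm.frontend
import tvm
from tvm.contrib import graph_runtime

# Ask the frontend for NCHW, which suits the CUDA schedules.
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout="NCHW")

graph, lib, params = nnvm.compiler.build(sym, target='cuda', target_host='llvm',
                                         shape=shape_dict, dtype=dtype_dict,
                                         params=params)

ctx = tvm.gpu(0)
m = graph_runtime.create(graph, lib, ctx)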
@@ -35,6 +35,7 @@ class AttrCvt(object):
        self._ignores.append('use_cudnn_on_gpu')
        self._ignores.append('_node_name')
        self._ignores.append('is_training')
+        self._ignores.append('_target_layout')
        # Retain the names
        try:
            attrs['name'] = attrs['_node_name']
@@ -121,6 +122,9 @@ def _pooling(name):
    def _impl(inputs, attr, params):
        attr['data_format'] = attr['data_format'].decode("utf-8")
+        flip_layout = False
+        input_shape = attr['_input_shapes'][inputs[0]][0]
        if attr['data_format'] == 'NHWC':
            attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
@@ -129,11 +133,17 @@ def _pooling(name):
        else:
            raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
+        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
+            tmp_shape = attr['_input_shapes'][inputs[0]][0]
+            input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
+            inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
+            attr['data_format'] = "NCHW"
+            flip_layout = True
        # Fix strides
        attr['strides'] = (attr['strides'][1], attr['strides'][2])
        # Fix padding
-        input_shapes = attr['_input_shapes'][inputs[0]]
        attr['padding'] = attr['padding'].decode("utf-8")
        if attr['padding'] == 'VALID':
@@ -142,11 +152,11 @@ def _pooling(name):
            stride_h, stride_w = attr['strides']
            kernel_h, kernel_w = attr['kernel_shape']
            if attr['data_format'] == 'NHWC':
-                in_h = input_shapes[0][1]
-                in_w = input_shapes[0][2]
+                in_h = input_shape[1]
+                in_w = input_shape[2]
            else:
-                in_h = input_shapes[0][2]
-                in_w = input_shapes[0][3]
+                in_h = input_shape[2]
+                in_w = input_shape[3]
            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
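For reference, a hedged sketch of the 'SAME' padding arithmetic that a helper like _get_pad_pair is presumably doing here; its body is not part of this diff, so the formulation below is an assumption based on TensorFlow's documented SAME rule (output size is ceil(input / stride), with any odd padding going on the bottom/right):

# --- illustrative sketch, not part of the commit ---
import math

def same_pad_pair(in_size, kernel, stride):
    # TensorFlow 'SAME': out = ceil(in / stride); pad just enough to reach it.
    out_size = math.ceil(in_size / stride)
    pad_total = max((out_size - 1) * stride + kernel - in_size, 0)
    pad_before = pad_total // 2
    return [pad_before, pad_total - pad_before]

print(same_pad_pair(224, 3, 2))   # [0, 1]
print(same_pad_pair(224, 7, 2))   # [2, 3]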
@@ -158,7 +168,7 @@ def _pooling(name):
        if name == "avg_pool":
            attr['count_include_pad'] = False
-        return AttrCvt(
+        out = AttrCvt(
            op_name=_dimension_picker(name),
            transforms={
                'kernel_shape':'pool_size',
@@ -166,33 +176,53 @@ def _pooling(name):
            ignores=['ksize'],
            extras={'ceil_mode': False},
            custom_check=_dimension_constraint())(inputs, attr)
+        if flip_layout:
+            out = _sym.transpose(out, axes=(0, 2, 3, 1))
+        return out
    return _impl
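The pattern above (and repeated for convolution below) is: when the user asks for NCHW but the TensorFlow graph is NHWC, transpose the input into NCHW, run the operator in NCHW, then transpose the result back so downstream NHWC operators are unaffected. A standalone numpy sketch of the two permutations involved, with numpy used purely for illustration:

# --- illustrative sketch, not part of the commit ---
import numpy as np

nhwc = np.zeros((1, 224, 224, 3))           # batch, height, width, channels
nchw = nhwc.transpose(0, 3, 1, 2)           # NHWC -> NCHW, axes (0, 3, 1, 2)
assert nchw.shape == (1, 3, 224, 224)

back = nchw.transpose(0, 2, 3, 1)           # NCHW -> NHWC, axes (0, 2, 3, 1)
assert back.shape == nhwc.shape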
def _conv(opname):
    def _impl(inputs, attr, params):
        attr['data_format'] = attr['data_format'].decode("utf-8")
-        input_shapes = attr['_input_shapes'][inputs[0]]
-        # Extract kernel shape from params
-        conv_param_weights = params[inputs[1].list_output_names()[0]]
+        flip_layout = False
+        input_shape = attr['_input_shapes'][inputs[0]][0]
+        weights_shape = params[inputs[1].list_output_names()[0]].shape
+        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
+            input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
+            inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
+            if opname == 'conv':
+                weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
+                inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
+            else:
+                weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)]
+                inputs[1] = _sym.transpose(inputs[1], axes=(2, 3, 0, 1))
+            attr['data_format'] = "NCHW"
+            flip_layout = True
        if attr['data_format'] == 'NHWC':
-            kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
-            attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
+            kernel_h, kernel_w, _, depth_mult = weights_shape
+            attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
            if opname == 'conv':
-                attr['channels'] = conv_param_weights.shape[3]
+                attr['channels'] = weights_shape[3]
            else:
-                attr['channels'] = input_shapes[0][3] * depth_mult
+                attr['channels'] = input_shape[3] * depth_mult
            if 'dilations' in attr:
                attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
        elif attr['data_format'] == 'NCHW':
-            depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
-            attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
+            depth_mult, _, kernel_h, kernel_w = weights_shape
+            attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
            if opname == 'conv':
-                attr['channels'] = conv_param_weights.shape[1]
+                attr['channels'] = weights_shape[0]
            else:
-                attr['channels'] = input_shapes[0][1] * depth_mult
+                attr['channels'] = input_shape[0] * depth_mult
+            if attr['channels'] < 0:
+                attr['channels'] *= -1
            if 'dilations' in attr:
                attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
@@ -215,11 +245,11 @@ def _conv(opname):
            stride_h, stride_w = attr['strides']
            kernel_h, kernel_w = attr['kernel_shape']
            if attr['data_format'] == 'NHWC':
-                in_h = input_shapes[0][1]
-                in_w = input_shapes[0][2]
+                in_h = input_shape[1]
+                in_w = input_shape[2]
            else:
-                in_h = input_shapes[0][2]
-                in_w = input_shapes[0][3]
+                in_h = input_shape[2]
+                in_w = input_shape[3]
            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
@@ -248,7 +278,7 @@ def _conv(opname):
        else:
            attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
-        return AttrCvt(
+        out = AttrCvt(
            op_name=_dimension_picker('conv'),
            transforms={
                'kernel_shape': 'kernel_size',
@@ -257,6 +287,11 @@ def _conv(opname):
                'group': ('groups', 1)},
            extras={'use_bias': len(inputs) == 3},
            custom_check=_dimension_constraint())(inputs, attr)
+        if flip_layout:
+            out = _sym.transpose(out, axes=(0, 2, 3, 1))
+        return out
    return _impl
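For the kernel, TensorFlow stores regular conv2d weights as HWIO (height, width, in-channels, out-channels) and depthwise weights with the channel and multiplier dimensions last, so the NCHW path above permutes them towards the OIHW-style ordering that the NCHW convolution expects. A small numpy sketch of the two permutations used above; the shapes are made-up examples:

# --- illustrative sketch, not part of the commit ---
import numpy as np

hwio = np.zeros((3, 3, 16, 32))            # regular conv weights: H, W, I, O
oihw = hwio.transpose(3, 2, 0, 1)          # axes (3, 2, 0, 1) -> O, I, H, W
assert oihw.shape == (32, 16, 3, 3)

hwcm = np.zeros((3, 3, 16, 1))             # depthwise weights: H, W, channels, multiplier
cmhw = hwcm.transpose(2, 3, 0, 1)          # axes (2, 3, 0, 1) -> channels, multiplier, H, W
assert cmhw.shape == (16, 1, 3, 3)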
def _decode_image():
@@ -305,7 +340,7 @@ def _matmul():
    def _impl(inputs, attr, params):
        channels = _infer_channels(inputs[1], params, not attr['transpose_b'])
        if attr['transpose_a']:
-            inputs[0] = _sym.transpose(inputs[0], axis=(1, 0))
+            inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
        if not attr['transpose_b']:
            inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
        return AttrCvt(op_name="dense",
@@ -948,7 +983,7 @@ class GraphProto(object):
        self._num_param = 0
        self._num_rnn_layer = False
-    def from_tensorflow(self, graph):
+    def from_tensorflow(self, graph, layout="NHWC"):
        """Construct nnvm nodes from tensorflow graph definition - GraphDef.
        Follow the tensorflow graph definition to parse and convert it to NNVM.
@@ -1036,6 +1071,9 @@ class GraphProto(object):
                # Pass the node name too in attr
                attr["_node_name"] = node.name
+                # Pass the target layout
+                attr["_target_layout"] = layout
                #ToDo: Some of the tensorflow operators internaly maintain
                #execution layers and its output name will the layer number along with
                #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
@@ -1265,7 +1303,7 @@ class GraphProto(object):
        return inputs

-def from_tensorflow(graph):
+def from_tensorflow(graph, layout="NHWC"):
    """ Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
    The companion parameters will be handled automatically.
@@ -1283,5 +1321,5 @@ def from_tensorflow(graph):
        Dict of converted parameters stored in tvm.ndarray format
    """
    g = GraphProto()
-    sym, params = g.from_tensorflow(graph)
+    sym, params = g.from_tensorflow(graph, layout)
    return sym, params
@@ -26,11 +26,15 @@ import nnvm.testing.tf
#######################################################################
# Generic run functions for TVM & tensorflow
# ------------------------------------------
-def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype):
+def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype, target='llvm'):
    """ Generic function to compile on nnvm and execute on tvm """
-    sym, params = nnvm.frontend.from_tensorflow(graph_def)
-    target = 'llvm'
+    layout = None
+    if target == "cuda":
+        layout = "NCHW"
+    sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)
+    target_host = 'llvm'
    if isinstance(input_data, list):
        shape_dict = {}
        dtype_dict = {}
@@ -41,10 +45,10 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype)
        shape_dict = {input_node: input_data.shape}
        dtype_dict = {input_node: input_data.dtype}
-    graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
+    graph, lib, params = nnvm.compiler.build(sym, target=target, target_host=target_host, shape=shape_dict,
                                              dtype=dtype_dict, params=params)
-    ctx = tvm.cpu(0)
+    ctx = tvm.context(target, 0)
    from tvm.contrib import graph_runtime
    m = graph_runtime.create(graph, lib, ctx)
    # set inputs
@@ -106,9 +110,17 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False)
            )
        tf_output = run_tf_graph(sess, in_data, in_name, out_name)
-        tvm_output = run_tvm_graph(final_graph_def, in_data,
-                                   in_node, tf_output.shape, tf_output.dtype)
-        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+        for device in ["llvm", "cuda"]:
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                continue
+            tvm_output = run_tvm_graph(final_graph_def, in_data,
+                                       in_node, tf_output.shape, tf_output.dtype, target=device)
+            np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
        sess.close()
#######################################################################
......
@@ -50,6 +50,16 @@ map_proto_url = os.path.join(repo_base, map_proto)
lable_map = 'imagenet_synset_to_human_label_map.txt'
lable_map_url = os.path.join(repo_base, lable_map)
+# Target settings
+# Use these commented settings to build for cuda.
+#target = 'cuda'
+#target_host = 'llvm'
+#layout = "NCHW"
+#ctx = tvm.gpu(0)
+target = 'llvm'
+target_host = 'llvm'
+layout = None
+ctx = tvm.cpu(0)
######################################################################
# Download required files
@@ -99,7 +109,7 @@ x = np.array(image)
# Results:
# sym: nnvm graph for given tensorflow protobuf.
# params: params converted from tensorflow params (tensor protobuf).
-sym, params = nnvm.frontend.from_tensorflow(graph_def)
+sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)
print ("Tensorflow protobuf imported as nnvm graph")
######################################################################
@@ -113,18 +123,16 @@ print ("Tensorflow protobuf imported as nnvm graph")
# lib: target library which can be deployed on target with tvm runtime.
import nnvm.compiler
-target = 'llvm'
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
-graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype=dtype_dict, params=params)
+graph, lib, params = nnvm.compiler.build(sym, shape=shape_dict, target=target, target_host=target_host, dtype=dtype_dict, params=params)
######################################################################
# Execute the portable graph on TVM
# ---------------------------------
-# Now we can try deploying the NNVM compiled model on cpu target.
+# Now we can try deploying the NNVM compiled model on target.
from tvm.contrib import graph_runtime
-ctx = tvm.cpu(0)
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
......
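The remainder of the tutorial is not shown in this diff. For context only, a generic sketch (not the elided code) of how a module produced by graph_runtime.create is typically fed and queried; the input tensor name comes from the tutorial's shape_dict, while the output index and shape are assumptions:

# --- illustrative sketch, not part of the commit ---
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
m.run()
# output_shape of the model is assumed to be known here
tvm_output = m.get_output(0, tvm.nd.empty(output_shape, 'float32'))
predictions = tvm_output.asnumpy()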