Commit fdf795a0 by Siva Committed by Tianqi Chen

[FRONTEND][TENSORFLOW] GPU support for tensorflow models. (#1718)

parent ae5a28db
......@@ -35,6 +35,7 @@ class AttrCvt(object):
self._ignores.append('use_cudnn_on_gpu')
self._ignores.append('_node_name')
self._ignores.append('is_training')
self._ignores.append('_target_layout')
# Retain the names
try:
attrs['name'] = attrs['_node_name']
......@@ -121,6 +122,9 @@ def _pooling(name):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False
input_shape = attr['_input_shapes'][inputs[0]][0]
if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
......@@ -129,11 +133,17 @@ def _pooling(name):
else:
raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
tmp_shape = attr['_input_shapes'][inputs[0]][0]
input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
attr['data_format'] = "NCHW"
flip_layout = True
# Fix strides
attr['strides'] = (attr['strides'][1], attr['strides'][2])
# Fix padding
input_shapes = attr['_input_shapes'][inputs[0]]
attr['padding'] = attr['padding'].decode("utf-8")
if attr['padding'] == 'VALID':
......@@ -142,11 +152,11 @@ def _pooling(name):
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']
if attr['data_format'] == 'NHWC':
in_h = input_shapes[0][1]
in_w = input_shapes[0][2]
in_h = input_shape[1]
in_w = input_shape[2]
else:
in_h = input_shapes[0][2]
in_w = input_shapes[0][3]
in_h = input_shape[2]
in_w = input_shape[3]
pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
......@@ -158,7 +168,7 @@ def _pooling(name):
if name == "avg_pool":
attr['count_include_pad'] = False
return AttrCvt(
out = AttrCvt(
op_name=_dimension_picker(name),
transforms={
'kernel_shape':'pool_size',
......@@ -166,33 +176,53 @@ def _pooling(name):
ignores=['ksize'],
extras={'ceil_mode': False},
custom_check=_dimension_constraint())(inputs, attr)
if flip_layout:
out = _sym.transpose(out, axes=(0, 2, 3, 1))
return out
return _impl
def _conv(opname):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
input_shapes = attr['_input_shapes'][inputs[0]]
flip_layout = False
# Extract kernel shape from params
conv_param_weights = params[inputs[1].list_output_names()[0]]
input_shape = attr['_input_shapes'][inputs[0]][0]
weights_shape = params[inputs[1].list_output_names()[0]].shape
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
if opname == 'conv':
weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
else:
weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)]
inputs[1] = _sym.transpose(inputs[1], axes=(2, 3, 0, 1))
attr['data_format'] = "NCHW"
flip_layout = True
if attr['data_format'] == 'NHWC':
kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
kernel_h, kernel_w, _, depth_mult = weights_shape
attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
if opname == 'conv':
attr['channels'] = conv_param_weights.shape[3]
attr['channels'] = weights_shape[3]
else:
attr['channels'] = input_shapes[0][3] * depth_mult
attr['channels'] = input_shape[3] * depth_mult
if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
elif attr['data_format'] == 'NCHW':
depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
depth_mult, _, kernel_h, kernel_w = weights_shape
attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
if opname == 'conv':
attr['channels'] = conv_param_weights.shape[1]
attr['channels'] = weights_shape[0]
else:
attr['channels'] = input_shapes[0][1] * depth_mult
attr['channels'] = input_shape[0] * depth_mult
if attr['channels'] < 0:
attr['channels'] *= -1
if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
......@@ -215,11 +245,11 @@ def _conv(opname):
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']
if attr['data_format'] == 'NHWC':
in_h = input_shapes[0][1]
in_w = input_shapes[0][2]
in_h = input_shape[1]
in_w = input_shape[2]
else:
in_h = input_shapes[0][2]
in_w = input_shapes[0][3]
in_h = input_shape[2]
in_w = input_shape[3]
pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
......@@ -248,7 +278,7 @@ def _conv(opname):
else:
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
return AttrCvt(
out = AttrCvt(
op_name=_dimension_picker('conv'),
transforms={
'kernel_shape': 'kernel_size',
......@@ -257,6 +287,11 @@ def _conv(opname):
'group': ('groups', 1)},
extras={'use_bias': len(inputs) == 3},
custom_check=_dimension_constraint())(inputs, attr)
if flip_layout:
out = _sym.transpose(out, axes=(0, 2, 3, 1))
return out
return _impl
def _decode_image():
......@@ -305,7 +340,7 @@ def _matmul():
def _impl(inputs, attr, params):
channels = _infer_channels(inputs[1], params, not attr['transpose_b'])
if attr['transpose_a']:
inputs[0] = _sym.transpose(inputs[0], axis(1, 0))
inputs[0] = _sym.transpose(inputs[0], axes(1, 0))
if not attr['transpose_b']:
inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
return AttrCvt(op_name="dense",
......@@ -948,7 +983,7 @@ class GraphProto(object):
self._num_param = 0
self._num_rnn_layer = False
def from_tensorflow(self, graph):
def from_tensorflow(self, graph, layout="NHWC"):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
Follow the tensorflow graph definition to parse and convert it to NNVM.
......@@ -1036,6 +1071,9 @@ class GraphProto(object):
# Pass the node name too in attr
attr["_node_name"] = node.name
# Pass the target layout
attr["_target_layout"] = layout
#ToDo: Some of the tensorflow operators internaly maintain
#execution layers and its output name will the layer number along with
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
......@@ -1265,7 +1303,7 @@ class GraphProto(object):
return inputs
def from_tensorflow(graph):
def from_tensorflow(graph, layout="NHWC"):
""" Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
The companion parameters will be handled automatically.
......@@ -1283,5 +1321,5 @@ def from_tensorflow(graph):
Dict of converted parameters stored in tvm.ndarray format
"""
g = GraphProto()
sym, params = g.from_tensorflow(graph)
sym, params = g.from_tensorflow(graph, layout)
return sym, params
......@@ -26,11 +26,15 @@ import nnvm.testing.tf
#######################################################################
# Generic run functions for TVM & tensorflow
# ------------------------------------------
def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype):
def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype, target='llvm'):
""" Generic function to compile on nnvm and execute on tvm """
sym, params = nnvm.frontend.from_tensorflow(graph_def)
target = 'llvm'
layout = None
if target == "cuda":
layout = "NCHW"
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)
target_host = 'llvm'
if isinstance(input_data, list):
shape_dict = {}
dtype_dict = {}
......@@ -41,10 +45,10 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype)
shape_dict = {input_node: input_data.shape}
dtype_dict = {input_node: input_data.dtype}
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
graph, lib, params = nnvm.compiler.build(sym, target=target, target_host=target_host, shape=shape_dict,
dtype=dtype_dict, params=params)
ctx = tvm.cpu(0)
ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
......@@ -106,9 +110,17 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False)
)
tf_output = run_tf_graph(sess, in_data, in_name, out_name)
tvm_output = run_tvm_graph(final_graph_def, in_data,
in_node, tf_output.shape, tf_output.dtype)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
for device in ["llvm", "cuda"]:
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
continue
tvm_output = run_tvm_graph(final_graph_def, in_data,
in_node, tf_output.shape, tf_output.dtype, target=device)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
sess.close()
#######################################################################
......
......@@ -50,6 +50,16 @@ map_proto_url = os.path.join(repo_base, map_proto)
lable_map = 'imagenet_synset_to_human_label_map.txt'
lable_map_url = os.path.join(repo_base, lable_map)
# Target settings
# Use these commented settings to build for cuda.
#target = 'cuda'
#target_host = 'llvm'
#layout = "NCHW"
#ctx = tvm.gpu(0)
target = 'llvm'
target_host = 'llvm'
layout = None
ctx = tvm.cpu(0)
######################################################################
# Download required files
......@@ -99,7 +109,7 @@ x = np.array(image)
# Results:
# sym: nnvm graph for given tensorflow protobuf.
# params: params converted from tensorflow params (tensor protobuf).
sym, params = nnvm.frontend.from_tensorflow(graph_def)
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)
print ("Tensorflow protobuf imported as nnvm graph")
######################################################################
......@@ -113,18 +123,16 @@ print ("Tensorflow protobuf imported as nnvm graph")
# lib: target library which can be deployed on target with tvm runtime.
import nnvm.compiler
target = 'llvm'
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype=dtype_dict, params=params)
graph, lib, params = nnvm.compiler.build(sym, shape=shape_dict, target=target, target_host=target_host, dtype=dtype_dict, params=params)
######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now we can try deploying the NNVM compiled model on cpu target.
# Now we can try deploying the NNVM compiled model on target.
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment