Commit 046a3ed9 authored by Siva, committed by Tianqi Chen

[FRONTEND][TENSORFLOW] NCHW layout support (Resnet V1/V2). (#1743)

parent 160e4107
@@ -110,11 +110,6 @@ def _elemwise(name):
     def _impl(inputs, attr, *args):
         assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs))
         op_name = _math_name_picker(name)(attr)
-        axis = int(attr.get('axis', 0))
-        conv_ops = ["conv2d", "conv2d_transpose"]
-        if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops:
-            # TODO: remove hard coded infershape
-            inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2)
         return get_nnvm_op(op_name)(*inputs)
     return _impl
@@ -128,8 +123,10 @@ def _pooling(name):
         if attr['data_format'] == 'NHWC':
             attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
+            attr['strides'] = (attr['strides'][1], attr['strides'][2])
         elif attr['data_format'] == 'NCHW':
             attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3])
+            attr['strides'] = (attr['strides'][2], attr['strides'][3])
         else:
             raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
@@ -140,9 +137,6 @@ def _pooling(name):
             attr['data_format'] = "NCHW"
             flip_layout = True

-        # Fix strides
-        attr['strides'] = (attr['strides'][1], attr['strides'][2])
-
         # Fix padding
         attr['padding'] = attr['padding'].decode("utf-8")
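As a side note on the stride handling above: TensorFlow packs `ksize` and `strides` as length-4 lists ordered like the `data_format` string, so the height/width pair sits at indices 1-2 for NHWC and 2-3 for NCHW. A minimal standalone sketch of that selection (plain Python; the helper name is illustrative, not part of the frontend):

# Standalone sketch of the ksize/strides selection performed in _pooling above.
# Assumes TensorFlow packs ksize/strides as length-4 lists in data_format order.
def pool_attrs_for_layout(ksize, strides, data_format):
    """Return (kernel_shape, strides) as HW pairs for the given layout."""
    if data_format == 'NHWC':
        return (ksize[1], ksize[2]), (strides[1], strides[2])
    if data_format == 'NCHW':
        return (ksize[2], ksize[3]), (strides[2], strides[3])
    raise TypeError("Unsupported data_format type : {}".format(data_format))

# Example: a 2x2 max pool with stride 2 expressed in either layout
assert pool_attrs_for_layout([1, 2, 2, 1], [1, 2, 2, 1], 'NHWC') == ((2, 2), (2, 2))
assert pool_attrs_for_layout([1, 1, 2, 2], [1, 1, 2, 2], 'NCHW') == ((2, 2), (2, 2))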
@@ -188,8 +182,15 @@ def _conv(opname):
         attr['data_format'] = attr['data_format'].decode("utf-8")
         flip_layout = False

+        # NCHW Layout require weights transpose
+        if attr['data_format'] == 'NCHW':
+            tmp_shape = attr['_input_shapes'][inputs[1]][0]
+            tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
+            inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
+            attr['_input_shapes'][inputs[1]] = [tmp_shape]
+
         input_shape = attr['_input_shapes'][inputs[0]][0]
-        weights_shape = params[inputs[1].list_output_names()[0]].shape
+        weights_shape = attr['_input_shapes'][inputs[1]][0]

         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
             input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
@@ -202,6 +203,7 @@ def _conv(opname):
                 inputs[1] = _sym.transpose(inputs[1], axes=(2, 3, 0, 1))

             attr['data_format'] = "NCHW"
+            attr['strides'] = [attr['strides'][ii] for ii in (0, 3, 1, 2)]
             flip_layout = True

         if attr['data_format'] == 'NHWC':
@@ -214,6 +216,7 @@ def _conv(opname):
             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
+            attr['strides'] = (attr['strides'][1], attr['strides'][2])

         elif attr['data_format'] == 'NCHW':
             depth_mult, _, kernel_h, kernel_w = weights_shape
             attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
@@ -226,6 +229,7 @@ def _conv(opname):
             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
+            attr['strides'] = (attr['strides'][2], attr['strides'][3])

         else:
             raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
@@ -233,9 +237,6 @@ def _conv(opname):
         if opname == 'depthwise':
             attr['groups'] = attr['channels']

-        # Fix strides
-        attr['strides'] = (attr['strides'][1], attr['strides'][2])
-
         # Fix padding
         attr['padding'] = attr['padding'].decode("utf-8")
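The weight transpose added above relies on TensorFlow storing Conv2D weights in HWIO order (height, width, in_channels, out_channels); for an NCHW graph they are permuted to OIHW with axes (3, 2, 0, 1), so the kernel height/width end up at positions 2 and 3, which is why the 'NCHW' branch reads kernel_shape from weights_shape[2:4]. A small NumPy sketch of the same permutation (shapes are illustrative only):

# NumPy-only sketch of the HWIO -> OIHW weight transpose that the NCHW path
# performs with _sym.transpose(..., axes=(3, 2, 0, 1)).
import numpy as np

hwio = np.zeros((3, 3, 16, 32), dtype='float32')  # TF filter layout: H, W, I, O
oihw = np.transpose(hwio, (3, 2, 0, 1))            # NCHW kernel layout: O, I, H, W
assert oihw.shape == (32, 16, 3, 3)

# kernel_shape is then taken from positions 2 and 3 (H, W) of the new shape.
assert oihw.shape[2:4] == (3, 3)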
@@ -416,12 +417,27 @@ def _fused_batch_norm():
     def _impl(inputs, attr, params):
         # Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
         # NNVM:       (data, gamma, beta, moving_mean, moving_varience)
-        return AttrCvt(
-            op_name='batch_norm',
-            transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
-            extras={'axis': 3}, # Fix axis
-            ignores=['data_format'],
-            disables=['momentum'])(inputs, attr)
+        axis = 3
+        need_cast = False
+
+        if 'data_format' in attr:
+            attr['data_format'] = attr['data_format'].decode("utf-8")
+            if attr['data_format'] == 'NCHW':
+                axis = 1
+        if 'U' in attr:
+            need_cast = True
+            inputs[0] = _sym.cast(inputs[0], dtype=attr['U'].name)
+
+        out = AttrCvt(op_name='batch_norm',
+                      transforms={'scale_after_normalization':'scale',
+                                  'variance_epsilon':'epsilon'},
+                      extras={'axis': axis},
+                      ignores=['data_format', 'U'],
+                      disables=['momentum'])(inputs, attr)
+
+        if need_cast:
+            out = _sym.cast(out, dtype=attr['T'].name)
+        return out
     return _impl

 def _batch_norm():
@@ -432,10 +448,16 @@ def _batch_norm():
         # (data, gamma, beta, moving_mean, moving_var)
         new_inputs = [inputs[0], inputs[4], inputs[3], inputs[1], inputs[2]]
+
+        axis = 3
+        if 'data_format' in attr:
+            attr['data_format'] = attr['data_format'].decode("utf-8")
+            if attr['data_format'] == 'NCHW':
+                axis = 1
         return AttrCvt(
             op_name='batch_norm',
             transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
-            extras={'axis': 3}, # Fix axis
+            extras={'axis': axis},
             ignores=['data_format'],
             disables=['momentum'])(new_inputs, attr)
     return _impl
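Both batch-norm converters now pick the channel axis from `data_format`: 3 for NHWC, 1 for NCHW. FusedBatchNormV2 additionally carries a 'U' compute dtype, so the data is cast to that dtype before the op and the result cast back to 'T' afterwards. A minimal sketch of the axis selection, assuming the attribute arrives as a byte string as in the code above (the helper name is illustrative, not part of the frontend):

# Minimal sketch of the channel-axis selection used by both converters above.
def batch_norm_axis(attr):
    data_format = attr.get('data_format', b'NHWC')
    if isinstance(data_format, bytes):
        data_format = data_format.decode('utf-8')
    return 1 if data_format == 'NCHW' else 3

assert batch_norm_axis({'data_format': b'NCHW'}) == 1
assert batch_norm_axis({}) == 3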
@@ -729,6 +751,14 @@ def _selu():
         return gamma * (-alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0]))
     return _impl

+def _mean():
+    def _impl(inputs, attr, params):
+        axis = params.pop(inputs[1].list_output_names()[0])
+        return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'],
+                       transforms={'keep_dims': 'keepdims'},
+                       extras={'axis': tuple(axis.asnumpy())})(inputs[0], attr)
+    return _impl
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
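The Mean converter works because TensorFlow's Mean op receives its reduction axes as a second, constant input; the converter pops that constant out of `params` and turns it into a static `axis` attribute, mapping `keep_dims` to `keepdims`. A NumPy sketch of what the converted node computes (shapes follow the test case added later in this diff):

# NumPy sketch of the converted Mean node: the second TF input (a constant)
# supplies the reduction axes, keep_dims maps to keepdims.
import numpy as np

data = np.random.uniform(size=(10, 8, 16, 32)).astype('float32')
axes = (1, 2)                                   # value of the popped constant input
out = np.mean(data, axis=axes, keepdims=True)   # mean(data, axis=axis, keepdims=keep_dims)
assert out.shape == (10, 1, 1, 32)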
@@ -773,6 +803,7 @@ _convert_map = {
     'Rsqrt' : _rsqrt(),
     'Squeeze' : _squeeze(),
     'FusedBatchNorm' : _fused_batch_norm(),
+    'FusedBatchNormV2' : _fused_batch_norm(),
     'Relu6' : _relu6(),
     'DepthwiseConv2dNative' : _conv('depthwise'),
     'Shape' : _shape(),
@@ -787,6 +818,7 @@ _convert_map = {
     'Rank' : _rank(),
     'Transpose' : _transpose(),
     'Tanh' : AttrCvt('tanh'),
+    'Mean' : _mean(),
 }

 # _convert_map_rnn defines maps of rnn operator name to
...
@@ -88,7 +88,7 @@ def run_tf_graph(sess, input_data, input_node, output_node):
     return output_data

-def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False):
+def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False):
     """Generic function to generate and compare tensorflow and TVM output"""
     out_node = out_name.split(':')[0] if ":" in out_name else out_name
@@ -116,6 +116,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             continue
+        if no_gpu and device == 'cuda':
+            continue

         tvm_output = run_tvm_graph(final_graph_def, in_data,
                                    in_node, tf_output.shape, tf_output.dtype, target=device)
@@ -123,10 +125,20 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False
         sess.close()

+def is_gpu_available():
+    from tensorflow.python.client import device_lib
+    local_device_protos = device_lib.list_local_devices()
+    gpu_list = [x.name for x in local_device_protos if x.device_type == 'GPU']
+    if len(gpu_list) > 0:
+        print("Tensorflow GPU:", gpu_list)
+        return True
+    else:
+        return False
+
 #######################################################################
 # Pooling
 # -------
-def _test_pooling(input_shape, **kwargs):
+def _test_pooling_iteration(input_shape, **kwargs):
     """ One iteration of pool operation with given shapes and attributes """

     x = -np.arange(
@@ -143,61 +155,45 @@ def _test_pooling(input_shape, **kwargs):
     compare_tf_with_tvm(x, 'Placeholder:0', out_name)

+def _test_pooling(input_shape, **kwargs):
+    _test_pooling_iteration(input_shape, **kwargs)
+
+    if is_gpu_available():
+        input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
+        kwargs['data_layout'] = 'NCHW'
+        _test_pooling_iteration(input_shape, **kwargs)
+
 def test_forward_pooling():
     """ Pooling """
-    _test_pooling(input_shape=[2, 9, 10, 2],
-                  window_shape=[1, 1],
-                  padding='SAME',
-                  pooling_type='MAX',
-                  dilation_rate=[1, 1],
-                  strides=[1, 1])
-    _test_pooling(input_shape=[2, 9, 10, 2],
-                  window_shape=[1, 1],
-                  padding='SAME',
-                  pooling_type='AVG',
-                  dilation_rate=[1, 1],
-                  strides=[1, 1])
-    _test_pooling(input_shape=[2, 10, 9, 2],
-                  window_shape=[1, 1],
-                  padding='SAME',
-                  pooling_type='MAX',
-                  dilation_rate=[1, 1],
-                  strides=[1, 1])
-    _test_pooling(input_shape=[2, 10, 9, 2],
-                  window_shape=[1, 1],
-                  padding='SAME',
-                  pooling_type='AVG',
-                  dilation_rate=[1, 1],
-                  strides=[1, 1])
-    _test_pooling(input_shape=[2, 9, 10, 2],
-                  window_shape=[2, 1],
-                  padding='SAME',
-                  pooling_type='MAX',
-                  dilation_rate=[1, 1],
-                  strides=[1, 1])
-    _test_pooling(input_shape=[2, 9, 10, 2],
-                  window_shape=[2, 1],
-                  padding='SAME',
-                  pooling_type='AVG',
-                  dilation_rate=[1, 1],
-                  strides=[2, 1])
-    _test_pooling(input_shape=[2, 10, 9, 2],
-                  window_shape=[2, 3],
-                  padding='SAME',
-                  pooling_type='MAX',
-                  dilation_rate=[1, 1],
-                  strides=[2, 1])
-    _test_pooling(input_shape=[2, 10, 9, 2],
-                  window_shape=[2, 3],
-                  padding='SAME',
-                  pooling_type='AVG',
-                  dilation_rate=[1, 1],
-                  strides=[1, 2])
+    for pool_type in ['AVG', 'MAX']:
+        _test_pooling(input_shape=[2, 9, 10, 2],
+                      window_shape=[1, 1],
+                      padding='SAME',
+                      pooling_type=pool_type,
+                      dilation_rate=[1, 1],
+                      strides=[1, 1])
+
+        _test_pooling(input_shape=[2, 10, 9, 2],
+                      window_shape=[1, 1],
+                      padding='SAME',
+                      pooling_type=pool_type,
+                      dilation_rate=[1, 1],
+                      strides=[1, 1])
+
+        _test_pooling(input_shape=[2, 9, 10, 2],
+                      window_shape=[2, 1],
+                      padding='SAME',
+                      pooling_type=pool_type,
+                      dilation_rate=[1, 1],
+                      strides=[1, 1])
+
+        _test_pooling(input_shape=[2, 10, 9, 2],
+                      window_shape=[2, 3],
+                      padding='SAME',
+                      pooling_type=pool_type,
+                      dilation_rate=[1, 1],
+                      strides=[2, 1])

 #######################################################################
 # Convolution
@@ -234,6 +230,12 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
                         'Placeholder:0', 'Conv2D:0')

 def test_forward_convolution():
+    if is_gpu_available():
+        _test_convolution([4, 176, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NCHW')
+        _test_convolution([4, 19, 17, 17], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NCHW')
+        _test_convolution([4, 124, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NCHW')
+        _test_convolution([4, 12, 17, 17], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NCHW')
+
     _test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
     _test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
     _test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC')
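The new NCHW cases mirror the NHWC cases in the same test, with the channel dimension moved forward under the (0, 3, 1, 2) permutation, while the HWIO filter shapes stay unchanged. A quick NumPy-free check (illustrative only):

# Quick check that the NCHW input shapes above are the (0, 3, 1, 2)
# permutation of the corresponding NHWC shapes.
nhwc = [4, 8, 8, 176]
nchw = [nhwc[ii] for ii in (0, 3, 1, 2)]
assert nchw == [4, 176, 8, 8]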
@@ -712,6 +714,25 @@ def test_forward_mobilenet():
     np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)

 #######################################################################
+# ResnetV2
+# ---------
+def test_forward_resnetv2():
+    '''test resnet model'''
+    if is_gpu_available():
+        with tf.Graph().as_default():
+            graph_def = nnvm.testing.tf.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
+            # Call the utility to import the graph definition into default graph.
+            graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+
+            data = np.random.uniform(size=(128, 224, 224, 3)).astype('float32')
+            out_node = 'ArgMax'
+
+            with tf.Session() as sess:
+                tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0')
+                tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32')
+                np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
+
+#######################################################################
 # PTB
 # ---
 dir(tf.contrib)
@@ -947,37 +968,69 @@ def test_forward_tanh():
     compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Tanh:0')

 #######################################################################
+# Mean
+# ----
+def test_forward_mean():
+    def check_mean(ishape, **kwargs):
+        inp_array = np.random.uniform(size=ishape).astype(np.float32)
+        with tf.Graph().as_default():
+            in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
+            tf.keras.backend.mean(in1, **kwargs)
+            compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Mean:0', no_gpu=True)
+
+    check_mean((10, 8, 16, 32))
+    check_mean((10, 8, 16, 32), axis=(2,3))
+    check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True)
+
+#######################################################################
 # Main
 # ----
 if __name__ == '__main__':
+    # Transforms
     test_forward_transpose()
-    test_forward_convolution()
-    test_forward_pooling()
     test_forward_reshape()
     test_forward_squeeze()
+    test_forward_pack()
+    test_forward_resize_bilinear()
+    test_forward_pad()
+    test_forward_gather()
+    #test_forward_stridedslice()
+
+    # Activations
     test_forward_sigmoid()
+    test_forward_relu()
+    test_forward_leaky_relu()
+    test_forward_elu()
+    test_forward_selu()
+    test_forward_tanh()
+
+    # Reductions
     test_forward_argminmax()
     test_forward_reduce()
+    test_forward_mean()
+
+    # NN
+    test_forward_convolution()
+    test_forward_pooling()
     if tf.__version__ == '1.4.1':
         _test_forward_concat_v2()
+    test_forward_lrn()
+    test_forward_l2_normalize()
+
+    # General
     test_forward_multi_input()
-    test_forward_pack()
+    test_forward_variable()
+
+    # End to End
     test_forward_inception_v3()
     test_forward_inception_v1()
     test_forward_mobilenet()
-    test_forward_variable()
-    test_forward_resize_bilinear()
-    test_forward_pad()
-    #test_forward_lstm()
-    #test_forward_stridedslice()
-    test_forward_gather()
+    test_forward_resnetv2()
     test_forward_ptb()
-    test_forward_lrn()
-    test_forward_l2_normalize()
+
+    # RNN
+    #test_forward_lstm()
+
+    # Elementwise
     test_forward_ceil()
     test_forward_floor()
-    test_forward_relu()
-    test_forward_leaky_relu()
-    test_forward_elu()
-    test_forward_selu()
-    test_forward_tanh()