Commit 770ac84e by Alexey Romanov, committed by Tianqi Chen

[Relay][Frontend] Simplify parameter handling in Tensorflow frontend (#2993)

parent 5999f7a6
@@ -63,7 +63,7 @@ def _get_relay_op(op_name):
return op
class AttrCvt(object):
"""Common attribute conveter. An AttrConverter instance is a callable:
"""Common attribute converter. An AttrConverter instance is a callable:
```
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs)
@@ -222,17 +222,37 @@ def _dimension_constraint():
return False
return _dim_check, "Only 2d kernel supported."
def _infer_channels(inputs, params, transpose=False):
"""A hack for getting 'channles' or 'units' since tensorflow don't provide
def _infer_channels(node, params, transpose=False):
"""A hack for getting 'channels' or 'units' since tensorflow don't provide
these attributes. We check the shape of weights provided to get the number.
"""
out_type = ir_pass.infer_type(inputs)
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
out_shape = _infer_shape(node, params)
channels = out_shape[0] if not transpose else out_shape[1]
return channels
def _infer_out_shapes(inputs, params):
"""A method to get the output shape of intermediate nodes in the relay graph."""
return [_infer_shape(inputs, params)]
def _infer_shape(node, params=None):
"""A method to get the output shape of an intermediate node in the relay graph."""
out_type = ir_pass.infer_type(node)
return get_const_tuple(out_type.checked_type.shape)
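# For context, _infer_shape simply runs Relay type inference on the node and reads the
# checked shape back. A minimal sketch under the pre-0.6 relay.ir_pass API this file uses
# (the pooling op is an arbitrary example, not taken from the commit):
```
from tvm import relay
from tvm.relay import ir_pass

x = relay.var("x", shape=(1, 3, 224, 224))
y = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2))
out_type = ir_pass.infer_type(y)
print(out_type.checked_type.shape)  # [1, 3, 112, 112] for the default NCHW layout
```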
def _get_param(params, input_node):
return params.pop(input_node.name_hint).asnumpy()
def _get_num_param(params, input_node):
return _get_param(params, input_node)[0]
def _get_list_param(params, input_node):
return _get_param(params, input_node).tolist()
def _get_tuple_param(params, input_node):
return tuple(_get_param(params, input_node))
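# The four helpers above centralize the params.pop(node.name_hint).asnumpy() idiom that the
# individual converters used to repeat. A self-contained sketch of the contract they assume,
# using hypothetical stand-ins (plain numpy instead of tvm.nd.NDArray, a fake node instead of
# a Relay var):
```
import numpy as np

class FakeVar:                            # hypothetical stand-in for a Relay var
    def __init__(self, name_hint):
        self.name_hint = name_hint

params = {"Reshape/shape": np.array([2, -1])}
node = FakeVar("Reshape/shape")

value = params.pop(node.name_hint)        # what _get_param returns (already numpy here)
print(value[0])                           # _get_num_param  -> 2
print(value.tolist())                     # _get_list_param -> [2, -1]
print(tuple(value))                       # _get_tuple_param -> (2, -1)
```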
def _rsqrt():
def _impl(inputs, attr, *args):
def _impl(inputs, attr, params):
inputs.append(tvm.relay.const(-0.5, attr['T'].name))
return AttrCvt(op_name="power")(inputs, attr)
return _impl
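# The Rsqrt conversion rests on the identity rsqrt(x) = x ** -0.5: the constant exponent is
# appended as a second input and the pair is mapped to Relay's power op. A quick numpy check
# of that identity (illustrative only, not part of the commit):
```
import numpy as np

x = np.array([1.0, 4.0, 9.0], dtype="float32")
assert np.allclose(1.0 / np.sqrt(x), np.power(x, -0.5))
```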
@@ -243,16 +263,15 @@ def _argx(func, func_name):
try:
# In Tensorflow, `axis` argument is a Tensor, not attribute. We
# support the case where it inputs from a scalar constant.
axis_input_name = inputs[1].name_hint
axis_input_vlaue = [params[axis_input_name].asnumpy()[0]]
axis_input_value = [_get_num_param(params, inputs[1])]
except (IndexError, KeyError):
raise TypeError( \
"Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
return func(inputs[0], axis=axis_input_vlaue, keepdims=False)
return func(inputs[0], axis=axis_input_value, keepdims=False)
return _impl
def _elemwise(name):
def _impl(inputs, attr, *args):
def _impl(inputs, attr, params):
assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
return _get_relay_op(name)(*inputs)
return _impl
@@ -472,7 +491,7 @@ def _cast():
def _expand_dims():
def _impl(inputs, attr, params):
dim_input = inputs.pop(1)
axis = params.pop(_get_name_hint(dim_input)).asnumpy()[0]
axis = _get_num_param(params, dim_input)
return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr)
return _impl
@@ -527,21 +546,19 @@ def _identity():
def _concatV2():
def _impl(inputs, attr, params):
pop_node = inputs.pop(len(inputs)-1)
axis = params[pop_node.name_hint]
params.pop(pop_node.name_hint)
axis = int(_get_num_param(params, pop_node))
return AttrCvt(
op_name="concatenate", ignores=['T', 'N', 'Tidx'],
extras={'axis': int(axis.asnumpy()[0])})([inputs], attr)
extras={'axis': axis})([inputs], attr)
return _impl
def _concat():
def _impl(inputs, attr, params):
pop_node = inputs.pop(0)
axis = params[pop_node.name_hint]
params.pop(pop_node.name_hint)
axis = int(_get_num_param(params, pop_node))
return AttrCvt(
op_name="concatenate", ignores=['N'],
extras={'axis': int(axis.asnumpy()[0])})([inputs], attr)
extras={'axis': axis})([inputs], attr)
return _impl
def _pack():
@@ -565,8 +582,8 @@ def _tile():
def _slice():
def _impl(inputs, attr, params):
begin = params.pop(_get_name_hint(inputs[1])).asnumpy().tolist()
size = params.pop(_get_name_hint(inputs[2])).asnumpy().tolist()
begin = _get_list_param(params, inputs[1])
size = _get_list_param(params, inputs[2])
data_shape = attr['_input_shapes'][inputs[0]]
data_dim = len(data_shape)
end = size
@@ -581,24 +598,18 @@ def _slice():
def _reshape():
def _impl(inputs, attr, params):
pop_node = inputs.pop(1)
try:
pop_node = inputs[1]
shape_arg = params.pop(pop_node.name_hint)
inputs.pop(1)
return AttrCvt(
op_name="reshape",
extras={'newshape':tuple(shape_arg.asnumpy())},
ignores=['Tshape'])(inputs, attr)
shape_arg = _get_tuple_param(params, pop_node)
except AttributeError:
# Shape operator is already pruned, hence
# try to infer shape by precompute prune if possible.
params_new = _infer_value(inputs[1], params)
inputs.pop(1)
return AttrCvt(
op_name="reshape",
extras={'newshape':tuple(params_new.asnumpy().astype('int64').flatten())},
ignores=['Tshape'])(inputs, attr)
params_new = _infer_value(pop_node, params)
shape_arg = tuple(params_new.asnumpy().astype('int64').flatten())
return AttrCvt(
op_name="reshape",
extras={'newshape': shape_arg},
ignores=['Tshape'])(inputs, attr)
return _impl
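# The except AttributeError branch covers graphs where the target shape is produced by other
# ops (so the popped node has no name_hint entry in params) and must be recovered with
# _infer_value. A hypothetical TF 1.x snippet, in the style of the tests below, that would
# exercise that path:
```
import tensorflow as tf

with tf.Graph().as_default():
    x = tf.placeholder(shape=[2, 3, 4], dtype="float32", name="x")
    # The new shape is computed from tf.shape(x) rather than stored as a constant,
    # so the converter cannot pop it from `params` and falls back to _infer_value().
    y = tf.reshape(x, tf.stack([tf.shape(x)[0], -1]), name="reshape_dynamic")
```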
@@ -737,9 +748,10 @@ def _fill():
if -1 in output_shape:
output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist()
fill_arg = params.pop(inputs.pop(1).name_hint)
return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name),
output_shape, attr['T'].name)
fill_arg = _get_num_param(params, inputs.pop(1))
dtype = attr['T'].name
return _op.full(tvm.relay.const(fill_arg, dtype),
output_shape, dtype)
return _impl
def _lrn():
@@ -757,9 +769,7 @@ def _lrn():
def _sum():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()
# convert to tuple for preventing invalid parameter format error
axis = tuple(axis)
axis = _get_tuple_param(params, inputs[1])
return AttrCvt(
op_name='sum',
extras={'axis': axis},
@@ -786,25 +796,17 @@ def _square():
def _gather():
"GatherV2, Gather"
def _impl(inputs, attr, params):
axis = 0
if len(inputs) > 2:
axis = params[inputs.pop(2).name_hint].asnumpy()[0]
new_input = []
new_input.append(inputs.pop(0))
new_input.append(inputs.pop(0))
axis = _get_num_param(params, inputs.pop(2))
else:
axis = 0
new_input = inputs[0:2]
return AttrCvt(op_name="take",
extras={'axis': tvm.const(axis, 'int32')},
ignores=['Tindices', 'Tparams', 'validate_indices', \
ignores=['Tindices', 'Tparams', 'validate_indices',
'Taxis', '_class'])(new_input, attr)
return _impl
def _infer_out_shapes(inputs, params):
"""A method to get the output shape of an intermediate node in the relay graph."""
out_type = ir_pass.infer_type(inputs)
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
return out_shapes
def _stridedSlice():
def _impl(inputs, attr, params):
"""Strided Slice.
@@ -812,9 +814,9 @@ def _stridedSlice():
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
tensorflow/core/util/strided_slice_op.cc#L147-L368
"""
begin = params.pop(inputs[1].name_hint).asnumpy().tolist()
end = params.pop(inputs[2].name_hint).asnumpy().tolist()
stride = params.pop(inputs[3].name_hint).asnumpy().tolist()
begin = _get_list_param(params, inputs[1])
end = _get_list_param(params, inputs[2])
stride = _get_list_param(params, inputs[3])
begin_mask = int(attr.get('begin_mask', 0))
end_mask = int(attr.get('end_mask', 0))
ellipsis_mask = int(attr.get('ellipsis_mask', 0))
@@ -889,7 +891,7 @@ def _stridedSlice():
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
out_shape = _infer_out_shapes(out, params)[0]
out_shape = _infer_shape(out, params)
if not fshape_indices:
fshape_indices = range(len(out_shape))
@@ -910,19 +912,14 @@ def _stridedSlice():
def _pad(name):
def _impl(inputs, attr, params):
padlist_key = inputs[1].name_hint
if padlist_key in params:
padlist = params.pop(padlist_key).asnumpy()
else:
raise tvm.error.OpAttributeRequired(
'Attribute {} not found in operator Pad.'.format(padlist_key))
paddings = tuple([tuple(l) for l in padlist])
padlist = _get_param(params, inputs[1])
paddings = tuple(tuple(l) for l in padlist)
attr['pad_width'] = paddings
attr['pad_value'] = 0
new_inputs = [inputs[0]]
if name == 'PadV2':
constant_values = params.pop(inputs[2].name_hint).asnumpy()
attr['pad_value'] = constant_values[0]
constant_values = _get_num_param(params, inputs[2])
attr['pad_value'] = constant_values
return AttrCvt(
op_name='pad',
ignores=['Tpaddings'],)(new_inputs, attr)
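# For reference, TensorFlow's paddings input is a [rank, 2] tensor of (before, after) pairs,
# and the handler turns it row by row into the nested tuple used for pad_width. A small numpy
# illustration of that conversion (values are arbitrary):
```
import numpy as np

padlist = np.array([[1, 1], [2, 0]])          # pad axis 0 by 1 on both sides, axis 1 by 2 before
pad_width = tuple(tuple(l) for l in padlist)
print(pad_width)                              # ((1, 1), (2, 0))
```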
@@ -932,10 +929,9 @@ def _transpose():
def _impl(inputs, attr, params):
# If perm is not specified, axes is left empty,
# otherwise its value is get from params
param_name = _get_name_hint(inputs[1])
if param_name in params:
axes = tuple(params.get(param_name).asnumpy())
else:
try:
axes = _get_list_param(params, inputs[1])
except (IndexError, KeyError):
axes = None
return _op.transpose(inputs[0], axes=axes)
return _impl
@@ -947,7 +943,7 @@ def _where():
def _reverse_v2():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()[0]
axis = _get_num_param(params, inputs[1])
return AttrCvt(
op_name="reverse",
ignores=['Tidx'],
@@ -968,9 +964,9 @@ def _rank():
def _range():
def _impl(inputs, attr, params):
start = params.pop(inputs[0].name_hint).asnumpy()[0]
limit = params.pop(inputs[1].name_hint).asnumpy()[0]
delta = params.pop(inputs[2].name_hint).asnumpy()[0]
start = _get_num_param(params, inputs[0])
limit = _get_num_param(params, inputs[1])
delta = _get_num_param(params, inputs[2])
name = attr["_node_name"]
params[name] = tvm.nd.array([start, limit, delta])
@@ -981,25 +977,27 @@ def _range():
def _elu():
def _impl(inputs, attr, params):
alpha = tvm.relay.const(-1.0, attr['T'].name)
return alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \
dtype = attr['T'].name
alpha = tvm.relay.const(-1.0, dtype)
return alpha * _op.nn.relu(tvm.relay.const(1, dtype) \
- _op.exp(inputs[0])) + _op.nn.relu(inputs[0])
return _impl
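# The two-relu expression above is an algebraic rewrite of elu(x) = x for x > 0 and
# exp(x) - 1 otherwise. A quick numpy check of that identity (illustrative only):
```
import numpy as np

x = np.array([-2.0, -0.5, 0.0, 1.5], dtype="float32")
relu = lambda v: np.maximum(v, 0.0)
composed = -1.0 * relu(1.0 - np.exp(x)) + relu(x)
reference = np.where(x > 0, x, np.exp(x) - 1.0)
assert np.allclose(composed, reference)
```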
def _selu():
def _impl(inputs, attr, params):
alpha = tvm.relay.const(-1.6732632423543772848170429916717, attr['T'].name)
gamma = tvm.relay.const(1.0507009873554804934193349852946, attr['T'].name)
return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \
dtype = attr['T'].name
alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype)
gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype)
return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype) \
- _op.exp(inputs[0])) + _op.nn.relu(inputs[0]))
return _impl
def _mean():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint)
axis = _get_tuple_param(params, inputs[1])
return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'],
transforms={'keep_dims': 'keepdims'},
extras={'axis': tuple(axis.asnumpy())})([inputs[0]], attr)
extras={'axis': axis})([inputs[0]], attr)
return _impl
def _broadcast(name):
@@ -1025,8 +1023,7 @@ def _split(has_size_vector):
if has_size_vector:
input_node_index = 0
input_axis_index = 2
size_splits_input_name = _get_name_hint(inputs[1])
size_splits = params[size_splits_input_name].asnumpy()
size_splits = _get_param(params, inputs[1])
section_beginnings = np.cumsum(size_splits)[:-1]
indices_or_sections = tuple(section_beginnings)
else:
@@ -1034,8 +1031,7 @@ def _split(has_size_vector):
input_axis_index = 0
indices_or_sections = attr['num_split']
input_node = inputs[input_node_index]
axis_input_name = _get_name_hint(inputs[input_axis_index])
axis_input_value = params[axis_input_name].asnumpy()[0]
axis_input_value = _get_num_param(params, inputs[input_axis_index])
except (IndexError, KeyError):
raise TypeError( \
"Unsupported argument for split: `axis` and `num_or_size_splits` " \
@@ -1105,8 +1101,8 @@ def _space_to_batch_nd():
def _impl(inputs, attr, params):
input_node = inputs[0]
input_shape = attr['_input_shapes'][input_node]
block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist()
paddings = params.pop(inputs[2].name_hint).asnumpy().tolist()
block_shape = _get_list_param(params, inputs[1])
paddings = _get_list_param(params, inputs[2])
N = len(input_shape)
M = len(block_shape)
batch = input_shape[0]
@@ -1127,7 +1123,7 @@ def _space_to_batch_nd():
axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \
list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
permuted_reshaped_padded_shape = _infer_out_shapes(permuted_reshaped_padded, params)[0]
permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, params)
# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# producing an output tensor of shape:
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
@@ -1144,8 +1140,8 @@ def _batch_to_space_nd():
def _impl(inputs, attr, params):
input_node = inputs[0]
input_shape = attr['_input_shapes'][input_node]
block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist()
crops = params.pop(inputs[2].name_hint).asnumpy().tolist()
block_shape = _get_list_param(params, inputs[1])
crops = _get_list_param(params, inputs[2])
M = len(block_shape)
batch = input_shape[0]
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
@@ -1170,7 +1166,7 @@ def _batch_to_space_nd():
# [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
# ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
# input_shape[M+1], ..., input_shape[N-1]]
reshaped_permuted_shape = _infer_out_shapes(reshaped_permuted, params)[0]
reshaped_permuted_shape = _infer_shape(reshaped_permuted, params)
cropped = reshaped_permuted
for axis in range(1, M+1):
crop = crops[axis - 1]
@@ -1971,23 +1967,17 @@ class GraphProto(object):
# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
out_shapes = []
for node_item in self._nodes[node.name]:
out_type = ir_pass.infer_type(node_item)
out_shapes.append(get_const_tuple(out_type.checked_type.shape))
out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]]
self._output_shapes[node.name] = out_shapes
if self._output_shapes[node.name] and shape and node.name in shape:
assert self._output_shapes[node.name] == list(shape[node.name])
# Infer shapes if passed explicitely
# Infer shapes if passed explicitly
node_output = self._nodes[node.name]
if shape and (not self._output_shapes[node.name][0]
or -1 in self._output_shapes[node.name][0]):
out_shapes = []
for node_item in node_output:
out_type = ir_pass.infer_type(node_item)
out_shapes.append(get_const_tuple(out_type.checked_type.shape))
out_shapes = [_infer_shape(node_item) for node_item in node_output]
self._output_shapes[node.name] = out_shapes
out = []
@@ -56,31 +56,23 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
layout = None
if target == "cuda":
layout = "NCHW"
target_host = 'llvm'
if isinstance(input_data, list):
shape_dict = {}
dtype_dict = {}
for i, e in enumerate(input_node):
shape_dict[e] = input_data[i].shape
dtype_dict[e] = input_data[i].dtype
else:
shape_dict = {input_node: input_data.shape}
dtype_dict = {input_node: input_data.dtype}
target_host = None
shape_dict = {e: i.shape for e, i in zip(input_node, input_data)}
sym, params = relay.frontend.from_tensorflow(graph_def,
layout=layout,
shape=shape_dict,
outputs=out_names)
with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build(sym, target, params=params)
graph, lib, params = relay.build(sym, target, target_host, params)
ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
for i, e in enumerate(input_node):
m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
for e, i in zip(input_node, input_data):
m.set_input(e, tvm.nd.array(i))
m.set_input(**params)
# execute
@@ -88,10 +80,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
# get outputs
assert out_names is None or num_output == len(out_names), (
"out_names: {} num_output: {}".format(out_names, num_output))
tvm_output_list = []
for i in range(0, num_output):
tvm_output = m.get_output(i)
tvm_output_list.append(tvm_output.asnumpy())
tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)]
return tvm_output_list
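# The zip/range comprehensions are drop-in replacements for the earlier index-based loops;
# for instance, the shape dictionary comes out identical (illustrative values):
```
import numpy as np

input_node = ["A", "B"]
input_data = [np.zeros((1, 3), "float32"), np.zeros((3, 6), "float32")]
shape_dict = {e: i.shape for e, i in zip(input_node, input_data)}
print(shape_dict)   # {'A': (1, 3), 'B': (3, 6)}, same mapping the old enumerate() loop built
```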
def run_tf_graph(sess, input_data, input_node, output_node):
@@ -100,13 +89,9 @@ def run_tf_graph(sess, input_data, input_node, output_node):
input_node = convert_to_list(input_node)
output_node = convert_to_list(output_node)
tensor = [0] * len(output_node)
for i in range(len(output_node)):
tensor[i] = sess.graph.get_tensor_by_name(output_node[i])
tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node]
input_dict = {}
for i, e in enumerate(input_node):
input_dict[e] = input_data[i]
input_dict = {e: input_data[i] for i, e in enumerate(input_node)}
output_data = sess.run(tensor, input_dict)
return output_data
@@ -115,17 +100,15 @@ def run_tf_graph(sess, input_data, input_node, output_node):
def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
no_gpu=False, opt_level=3):
"""Generic function to generate and compare tensorflow and TVM output"""
def name_without_num(name):
return name.split(':')[0] if ":" in name else name
out_name = convert_to_list(out_name)
out_node = [0]*len(out_name)
for i in range(len(out_name)):
out_node[i] = out_name[i].split(':')[0] if ":" in out_name[i] else out_name[i]
out_node = [name_without_num(name) for name in out_name]
in_data = convert_to_list(in_data)
in_name = convert_to_list(in_name)
in_node = [0]*len(in_name)
for i in range(len(in_name)):
in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
in_node = [name_without_num(name) for name in in_name]
with tf.Session() as sess:
if init_global_variables:
sess.run(variables.global_variables_initializer())
@@ -578,6 +561,38 @@ def test_forward_variable():
#######################################################################
# MatMul
# ------
def _test_matmul(i, j, k, dtype, outer=None):
""" One iteration of matmul """
A_shape_init = [i, j]
B_shape_init = [j, k]
for transpose_a in [False, True]:
for transpose_b in [False, True]:
outer = outer or []
A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init)
B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init)
with tf.Graph().as_default():
A = tf.placeholder(shape=A_shape, dtype=dtype, name='A')
B = tf.placeholder(shape=B_shape, dtype=dtype, name='B')
result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b)
A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
def test_forward_matmul():
""" Matmul op test"""
_test_matmul(1, 3, 6, 'int32')
_test_matmul(5, 3, 1, 'float64')
# TODO non-empty outer requires BatchMatMul (BatchMatMulV2 for some cases?) support
#######################################################################
# StridedSlice
# ------------
@@ -1785,3 +1800,6 @@ if __name__ == '__main__':
test_forward_rel_ops()
test_forward_logical()
test_where()
test_forward_matmul()
# TODO missing tests: rank, range
\ No newline at end of file
@@ -151,11 +151,7 @@ def get_const_tuple(in_tuple):
out_tuple : tuple of int
The output.
"""
out_tuple = ()
for elem in in_tuple:
value = get_const_int(elem)
out_tuple = out_tuple + (value, )
return out_tuple
return tuple(get_const_int(elem) for elem in in_tuple)
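# A short usage sketch of the rewritten helper (the import path is an assumption based on
# where these utilities normally live; tvm.const builds the IntImm nodes that get_const_int
# unwraps):
```
import tvm
from topi.util import get_const_tuple   # assumed location of the helper edited above

shape = (tvm.const(2, "int32"), tvm.const(3, "int32"), 4)
assert get_const_tuple(shape) == (2, 3, 4)
```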
def get_float_tuple(in_tuple):
@@ -171,11 +167,7 @@ def get_float_tuple(in_tuple):
out_tuple : tuple of float
The output.
"""
out_tuple = ()
for elem in in_tuple:
value = get_const_float(elem)
out_tuple = out_tuple + (value, )
return out_tuple
return tuple(get_const_float(elem) for elem in in_tuple)
def simplify(expr):