Commit 770ac84e by Alexey Romanov Committed by Tianqi Chen

[Relay][Frontend] Simplify parameter handling in Tensorflow frontend (#2993)

parent 5999f7a6
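Summary of the change: the ad-hoc `params.pop(node.name_hint).asnumpy()` idiom repeated across the converters is replaced by a small set of shared helpers (`_get_param`, `_get_num_param`, `_get_list_param`, `_get_tuple_param`) and a single `_infer_shape` utility. A rough before/after sketch of the pattern (these lines are illustrative, not part of the diff):

    # before: each converter unpacked constant inputs by hand
    axis = params.pop(inputs[1].name_hint).asnumpy()[0]
    # after: shared helpers hide the pop / asnumpy / indexing boilerplate
    axis = _get_num_param(params, inputs[1])        # scalar
    begin = _get_list_param(params, inputs[1])      # python list
    newshape = _get_tuple_param(params, inputs[1])  # tuple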
...@@ -63,7 +63,7 @@ def _get_relay_op(op_name): ...@@ -63,7 +63,7 @@ def _get_relay_op(op_name):
return op return op
class AttrCvt(object): class AttrCvt(object):
"""Common attribute conveter. An AttrConverter instance is a callable: """Common attribute converter. An AttrConverter instance is a callable:
``` ```
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs) new_op_name, new_attr = attr_converter(attrs)
...@@ -222,17 +222,37 @@ def _dimension_constraint(): ...@@ -222,17 +222,37 @@ def _dimension_constraint():
return False return False
return _dim_check, "Only 2d kernel supported." return _dim_check, "Only 2d kernel supported."
def _infer_channels(inputs, params, transpose=False): def _infer_channels(node, params, transpose=False):
"""A hack for getting 'channles' or 'units' since tensorflow don't provide """A hack for getting 'channels' or 'units' since tensorflow don't provide
these attributes. We check the shape of weights provided to get the number. these attributes. We check the shape of weights provided to get the number.
""" """
- out_type = ir_pass.infer_type(inputs)
- out_shapes = [get_const_tuple(out_type.checked_type.shape)]
- channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
+ out_shape = _infer_shape(node, params)
+ channels = out_shape[0] if not transpose else out_shape[1]
return channels return channels
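Concretely (the weight node and its shape below are illustrative, not from the diff): if type inference gives a weight node the shape (64, 128), then

    _infer_channels(weight_node, params)                  # -> 64
    _infer_channels(weight_node, params, transpose=True)  # -> 128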
def _infer_out_shapes(inputs, params):
"""A method to get the output shape of intermediate nodes in the relay graph."""
return [_infer_shape(inputs, params)]
def _infer_shape(node, params=None):
"""A method to get the output shape of an intermediate node in the relay graph."""
out_type = ir_pass.infer_type(node)
return get_const_tuple(out_type.checked_type.shape)
def _get_param(params, input_node):
return params.pop(input_node.name_hint).asnumpy()
def _get_num_param(params, input_node):
return _get_param(params, input_node)[0]
def _get_list_param(params, input_node):
return _get_param(params, input_node).tolist()
def _get_tuple_param(params, input_node):
return tuple(_get_param(params, input_node))
def _rsqrt(): def _rsqrt():
def _impl(inputs, attr, *args): def _impl(inputs, attr, params):
inputs.append(tvm.relay.const(-0.5, attr['T'].name)) inputs.append(tvm.relay.const(-0.5, attr['T'].name))
return AttrCvt(op_name="power")(inputs, attr) return AttrCvt(op_name="power")(inputs, attr)
return _impl return _impl
...@@ -243,16 +263,15 @@ def _argx(func, func_name): ...@@ -243,16 +263,15 @@ def _argx(func, func_name):
try: try:
# In Tensorflow, `axis` argument is a Tensor, not attribute. We # In Tensorflow, `axis` argument is a Tensor, not attribute. We
# support the case where it inputs from a scalar constant. # support the case where it inputs from a scalar constant.
- axis_input_name = inputs[1].name_hint
- axis_input_vlaue = [params[axis_input_name].asnumpy()[0]]
+ axis_input_value = [_get_num_param(params, inputs[1])]
except (IndexError, KeyError): except (IndexError, KeyError):
raise TypeError( \ raise TypeError( \
"Unsupported argument for `{}` : `axis` should be a constant".format(func_name)) "Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
return func(inputs[0], axis=axis_input_vlaue, keepdims=False) return func(inputs[0], axis=axis_input_value, keepdims=False)
return _impl return _impl
def _elemwise(name): def _elemwise(name):
def _impl(inputs, attr, *args): def _impl(inputs, attr, params):
assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs)) assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
return _get_relay_op(name)(*inputs) return _get_relay_op(name)(*inputs)
return _impl return _impl
...@@ -472,7 +491,7 @@ def _cast(): ...@@ -472,7 +491,7 @@ def _cast():
def _expand_dims(): def _expand_dims():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
dim_input = inputs.pop(1) dim_input = inputs.pop(1)
axis = params.pop(_get_name_hint(dim_input)).asnumpy()[0] axis = _get_num_param(params, dim_input)
return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'], return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr) extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr)
return _impl return _impl
...@@ -527,21 +546,19 @@ def _identity(): ...@@ -527,21 +546,19 @@ def _identity():
def _concatV2(): def _concatV2():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
pop_node = inputs.pop(len(inputs)-1) pop_node = inputs.pop(len(inputs)-1)
- axis = params[pop_node.name_hint]
- params.pop(pop_node.name_hint)
+ axis = int(_get_num_param(params, pop_node))
return AttrCvt( return AttrCvt(
op_name="concatenate", ignores=['T', 'N', 'Tidx'], op_name="concatenate", ignores=['T', 'N', 'Tidx'],
extras={'axis': int(axis.asnumpy()[0])})([inputs], attr) extras={'axis': axis})([inputs], attr)
return _impl return _impl
def _concat(): def _concat():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
pop_node = inputs.pop(0) pop_node = inputs.pop(0)
- axis = params[pop_node.name_hint]
- params.pop(pop_node.name_hint)
+ axis = int(_get_num_param(params, pop_node))
return AttrCvt( return AttrCvt(
op_name="concatenate", ignores=['N'], op_name="concatenate", ignores=['N'],
extras={'axis': int(axis.asnumpy()[0])})([inputs], attr) extras={'axis': axis})([inputs], attr)
return _impl return _impl
def _pack(): def _pack():
...@@ -565,8 +582,8 @@ def _tile(): ...@@ -565,8 +582,8 @@ def _tile():
def _slice(): def _slice():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
begin = params.pop(_get_name_hint(inputs[1])).asnumpy().tolist() begin = _get_list_param(params, inputs[1])
size = params.pop(_get_name_hint(inputs[2])).asnumpy().tolist() size = _get_list_param(params, inputs[2])
data_shape = attr['_input_shapes'][inputs[0]] data_shape = attr['_input_shapes'][inputs[0]]
data_dim = len(data_shape) data_dim = len(data_shape)
end = size end = size
...@@ -581,24 +598,18 @@ def _slice(): ...@@ -581,24 +598,18 @@ def _slice():
def _reshape(): def _reshape():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
+     pop_node = inputs.pop(1)
      try:
-         pop_node = inputs[1]
-         shape_arg = params.pop(pop_node.name_hint)
-         inputs.pop(1)
-         return AttrCvt(
-             op_name="reshape",
-             extras={'newshape':tuple(shape_arg.asnumpy())},
-             ignores=['Tshape'])(inputs, attr)
+         shape_arg = _get_tuple_param(params, pop_node)
      except AttributeError:
          # Shape operator is already pruned, hence
          # try to infer shape by precompute prune if possible.
-         params_new = _infer_value(inputs[1], params)
-         inputs.pop(1)
+         params_new = _infer_value(pop_node, params)
+         shape_arg = tuple(params_new.asnumpy().astype('int64').flatten())
      return AttrCvt(
          op_name="reshape",
-         extras={'newshape':tuple(params_new.asnumpy().astype('int64').flatten())},
+         extras={'newshape': shape_arg},
          ignores=['Tshape'])(inputs, attr)
return _impl return _impl
...@@ -737,9 +748,10 @@ def _fill(): ...@@ -737,9 +748,10 @@ def _fill():
if -1 in output_shape: if -1 in output_shape:
output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist() output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist()
- fill_arg = params.pop(inputs.pop(1).name_hint)
- return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name),
-                 output_shape, attr['T'].name)
+ fill_arg = _get_num_param(params, inputs.pop(1))
+ dtype = attr['T'].name
+ return _op.full(tvm.relay.const(fill_arg, dtype),
+                 output_shape, dtype)
return _impl return _impl
def _lrn(): def _lrn():
...@@ -757,9 +769,7 @@ def _lrn(): ...@@ -757,9 +769,7 @@ def _lrn():
def _sum(): def _sum():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
- axis = params.pop(inputs[1].name_hint).asnumpy()
- # convert to tuple for preventing invalid parameter format error
- axis = tuple(axis)
+ axis = _get_tuple_param(params, inputs[1])
return AttrCvt( return AttrCvt(
op_name='sum', op_name='sum',
extras={'axis': axis}, extras={'axis': axis},
...@@ -786,25 +796,17 @@ def _square(): ...@@ -786,25 +796,17 @@ def _square():
def _gather(): def _gather():
"GatherV2, Gather" "GatherV2, Gather"
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
- axis = 0
  if len(inputs) > 2:
-     axis = params[inputs.pop(2).name_hint].asnumpy()[0]
- new_input = []
- new_input.append(inputs.pop(0))
- new_input.append(inputs.pop(0))
+     axis = _get_num_param(params, inputs.pop(2))
+ else:
+     axis = 0
+ new_input = inputs[0:2]
return AttrCvt(op_name="take", return AttrCvt(op_name="take",
extras={'axis': tvm.const(axis, 'int32')}, extras={'axis': tvm.const(axis, 'int32')},
ignores=['Tindices', 'Tparams', 'validate_indices', \ ignores=['Tindices', 'Tparams', 'validate_indices',
'Taxis', '_class'])(new_input, attr) 'Taxis', '_class'])(new_input, attr)
return _impl return _impl
def _infer_out_shapes(inputs, params):
"""A method to get the output shape of an intermediate node in the relay graph."""
out_type = ir_pass.infer_type(inputs)
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
return out_shapes
def _stridedSlice(): def _stridedSlice():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
"""Strided Slice. """Strided Slice.
...@@ -812,9 +814,9 @@ def _stridedSlice(): ...@@ -812,9 +814,9 @@ def _stridedSlice():
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/ Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
tensorflow/core/util/strided_slice_op.cc#L147-L368 tensorflow/core/util/strided_slice_op.cc#L147-L368
""" """
begin = params.pop(inputs[1].name_hint).asnumpy().tolist() begin = _get_list_param(params, inputs[1])
end = params.pop(inputs[2].name_hint).asnumpy().tolist() end = _get_list_param(params, inputs[2])
stride = params.pop(inputs[3].name_hint).asnumpy().tolist() stride = _get_list_param(params, inputs[3])
begin_mask = int(attr.get('begin_mask', 0)) begin_mask = int(attr.get('begin_mask', 0))
end_mask = int(attr.get('end_mask', 0)) end_mask = int(attr.get('end_mask', 0))
ellipsis_mask = int(attr.get('ellipsis_mask', 0)) ellipsis_mask = int(attr.get('ellipsis_mask', 0))
...@@ -889,7 +891,7 @@ def _stridedSlice(): ...@@ -889,7 +891,7 @@ def _stridedSlice():
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
out_shape = _infer_out_shapes(out, params)[0] out_shape = _infer_shape(out, params)
if not fshape_indices: if not fshape_indices:
fshape_indices = range(len(out_shape)) fshape_indices = range(len(out_shape))
...@@ -910,19 +912,14 @@ def _stridedSlice(): ...@@ -910,19 +912,14 @@ def _stridedSlice():
def _pad(name): def _pad(name):
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
- padlist_key = inputs[1].name_hint
- if padlist_key in params:
-     padlist = params.pop(padlist_key).asnumpy()
- else:
-     raise tvm.error.OpAttributeRequired(
-         'Attribute {} not found in operator Pad.'.format(padlist_key))
- paddings = tuple([tuple(l) for l in padlist])
+ padlist = _get_param(params, inputs[1])
+ paddings = tuple(tuple(l) for l in padlist)
attr['pad_width'] = paddings attr['pad_width'] = paddings
attr['pad_value'] = 0 attr['pad_value'] = 0
new_inputs = [inputs[0]] new_inputs = [inputs[0]]
if name == 'PadV2': if name == 'PadV2':
constant_values = params.pop(inputs[2].name_hint).asnumpy() constant_values = _get_num_param(params, inputs[2])
attr['pad_value'] = constant_values[0] attr['pad_value'] = constant_values
return AttrCvt( return AttrCvt(
op_name='pad', op_name='pad',
ignores=['Tpaddings'],)(new_inputs, attr) ignores=['Tpaddings'],)(new_inputs, attr)
...@@ -932,10 +929,9 @@ def _transpose(): ...@@ -932,10 +929,9 @@ def _transpose():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
# If perm is not specified, axes is left empty, # If perm is not specified, axes is left empty,
# otherwise its value is get from params # otherwise its value is get from params
- param_name = _get_name_hint(inputs[1])
- if param_name in params:
-     axes = tuple(params.get(param_name).asnumpy())
- else:
+ try:
+     axes = _get_list_param(params, inputs[1])
+ except (IndexError, KeyError):
      axes = None
return _op.transpose(inputs[0], axes=axes) return _op.transpose(inputs[0], axes=axes)
return _impl return _impl
...@@ -947,7 +943,7 @@ def _where(): ...@@ -947,7 +943,7 @@ def _where():
def _reverse_v2(): def _reverse_v2():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()[0] axis = _get_num_param(params, inputs[1])
return AttrCvt( return AttrCvt(
op_name="reverse", op_name="reverse",
ignores=['Tidx'], ignores=['Tidx'],
...@@ -968,9 +964,9 @@ def _rank(): ...@@ -968,9 +964,9 @@ def _rank():
def _range(): def _range():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
start = params.pop(inputs[0].name_hint).asnumpy()[0] start = _get_num_param(params, inputs[0])
limit = params.pop(inputs[1].name_hint).asnumpy()[0] limit = _get_num_param(params, inputs[1])
delta = params.pop(inputs[2].name_hint).asnumpy()[0] delta = _get_num_param(params, inputs[2])
name = attr["_node_name"] name = attr["_node_name"]
params[name] = tvm.nd.array([start, limit, delta]) params[name] = tvm.nd.array([start, limit, delta])
...@@ -981,25 +977,27 @@ def _range(): ...@@ -981,25 +977,27 @@ def _range():
def _elu(): def _elu():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
- alpha = tvm.relay.const(-1.0, attr['T'].name)
- return alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \
+ dtype = attr['T'].name
+ alpha = tvm.relay.const(-1.0, dtype)
+ return alpha * _op.nn.relu(tvm.relay.const(1, dtype) \
      - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])
return _impl return _impl
def _selu(): def _selu():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
- alpha = tvm.relay.const(-1.6732632423543772848170429916717, attr['T'].name)
- gamma = tvm.relay.const(1.0507009873554804934193349852946, attr['T'].name)
- return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \
+ dtype = attr['T'].name
+ alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype)
+ gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype)
+ return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype) \
      - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]))
return _impl return _impl
def _mean(): def _mean():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint) axis = _get_tuple_param(params, inputs[1])
return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'], return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'],
transforms={'keep_dims': 'keepdims'}, transforms={'keep_dims': 'keepdims'},
extras={'axis': tuple(axis.asnumpy())})([inputs[0]], attr) extras={'axis': axis})([inputs[0]], attr)
return _impl return _impl
def _broadcast(name): def _broadcast(name):
...@@ -1025,8 +1023,7 @@ def _split(has_size_vector): ...@@ -1025,8 +1023,7 @@ def _split(has_size_vector):
if has_size_vector: if has_size_vector:
input_node_index = 0 input_node_index = 0
input_axis_index = 2 input_axis_index = 2
- size_splits_input_name = _get_name_hint(inputs[1])
- size_splits = params[size_splits_input_name].asnumpy()
+ size_splits = _get_param(params, inputs[1])
section_beginnings = np.cumsum(size_splits)[:-1] section_beginnings = np.cumsum(size_splits)[:-1]
indices_or_sections = tuple(section_beginnings) indices_or_sections = tuple(section_beginnings)
else: else:
...@@ -1034,8 +1031,7 @@ def _split(has_size_vector): ...@@ -1034,8 +1031,7 @@ def _split(has_size_vector):
input_axis_index = 0 input_axis_index = 0
indices_or_sections = attr['num_split'] indices_or_sections = attr['num_split']
input_node = inputs[input_node_index] input_node = inputs[input_node_index]
- axis_input_name = _get_name_hint(inputs[input_axis_index])
- axis_input_value = params[axis_input_name].asnumpy()[0]
+ axis_input_value = _get_num_param(params, inputs[input_axis_index])
except (IndexError, KeyError): except (IndexError, KeyError):
raise TypeError( \ raise TypeError( \
"Unsupported argument for split: `axis` and `num_or_size_splits` " \ "Unsupported argument for split: `axis` and `num_or_size_splits` " \
...@@ -1105,8 +1101,8 @@ def _space_to_batch_nd(): ...@@ -1105,8 +1101,8 @@ def _space_to_batch_nd():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
input_node = inputs[0] input_node = inputs[0]
input_shape = attr['_input_shapes'][input_node] input_shape = attr['_input_shapes'][input_node]
block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist() block_shape = _get_list_param(params, inputs[1])
paddings = params.pop(inputs[2].name_hint).asnumpy().tolist() paddings = _get_list_param(params, inputs[2])
N = len(input_shape) N = len(input_shape)
M = len(block_shape) M = len(block_shape)
batch = input_shape[0] batch = input_shape[0]
...@@ -1127,7 +1123,7 @@ def _space_to_batch_nd(): ...@@ -1127,7 +1123,7 @@ def _space_to_batch_nd():
axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \ axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \
list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length)) list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes) permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
permuted_reshaped_padded_shape = _infer_out_shapes(permuted_reshaped_padded, params)[0] permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, params)
# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension, # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# producing an output tensor of shape: # producing an output tensor of shape:
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ..., # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
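A worked instance of the `axes` computation in this hunk (assuming a 4-D NHWC input, so M = 2 and one remaining dimension):

    # axes = [2*0+2, 2*1+2] + [0] + [2*0+1, 2*1+1] + list(range(5, 6))
    #      = [2, 4, 0, 1, 3, 5]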
...@@ -1144,8 +1140,8 @@ def _batch_to_space_nd(): ...@@ -1144,8 +1140,8 @@ def _batch_to_space_nd():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
input_node = inputs[0] input_node = inputs[0]
input_shape = attr['_input_shapes'][input_node] input_shape = attr['_input_shapes'][input_node]
block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist() block_shape = _get_list_param(params, inputs[1])
crops = params.pop(inputs[2].name_hint).asnumpy().tolist() crops = _get_list_param(params, inputs[2])
M = len(block_shape) M = len(block_shape)
batch = input_shape[0] batch = input_shape[0]
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d: # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
...@@ -1170,7 +1166,7 @@ def _batch_to_space_nd(): ...@@ -1170,7 +1166,7 @@ def _batch_to_space_nd():
# [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
# ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], # ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
# input_shape[M+1], ..., input_shape[N-1]] # input_shape[M+1], ..., input_shape[N-1]]
reshaped_permuted_shape = _infer_out_shapes(reshaped_permuted, params)[0] reshaped_permuted_shape = _infer_shape(reshaped_permuted, params)
cropped = reshaped_permuted cropped = reshaped_permuted
for axis in range(1, M+1): for axis in range(1, M+1):
crop = crops[axis - 1] crop = crops[axis - 1]
...@@ -1971,23 +1967,17 @@ class GraphProto(object): ...@@ -1971,23 +1967,17 @@ class GraphProto(object):
# Infer shapes even without specifying "add_shapes=True" # Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]: if output_shapes == [None]:
- out_shapes = []
- for node_item in self._nodes[node.name]:
-     out_type = ir_pass.infer_type(node_item)
-     out_shapes.append(get_const_tuple(out_type.checked_type.shape))
+ out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]]
self._output_shapes[node.name] = out_shapes self._output_shapes[node.name] = out_shapes
if self._output_shapes[node.name] and shape and node.name in shape: if self._output_shapes[node.name] and shape and node.name in shape:
assert self._output_shapes[node.name] == list(shape[node.name]) assert self._output_shapes[node.name] == list(shape[node.name])
# Infer shapes if passed explicitely # Infer shapes if passed explicitly
node_output = self._nodes[node.name] node_output = self._nodes[node.name]
if shape and (not self._output_shapes[node.name][0] if shape and (not self._output_shapes[node.name][0]
or -1 in self._output_shapes[node.name][0]): or -1 in self._output_shapes[node.name][0]):
- out_shapes = []
- for node_item in node_output:
-     out_type = ir_pass.infer_type(node_item)
-     out_shapes.append(get_const_tuple(out_type.checked_type.shape))
+ out_shapes = [_infer_shape(node_item) for node_item in node_output]
self._output_shapes[node.name] = out_shapes self._output_shapes[node.name] = out_shapes
out = [] out = []
......
...@@ -56,31 +56,23 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, ...@@ -56,31 +56,23 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
layout = None layout = None
if target == "cuda": if target == "cuda":
layout = "NCHW" layout = "NCHW"
target_host = 'llvm' target_host = None
- if isinstance(input_data, list):
-     shape_dict = {}
-     dtype_dict = {}
-     for i, e in enumerate(input_node):
-         shape_dict[e] = input_data[i].shape
-         dtype_dict[e] = input_data[i].dtype
- else:
-     shape_dict = {input_node: input_data.shape}
-     dtype_dict = {input_node: input_data.dtype}
+ shape_dict = {e: i.shape for e, i in zip(input_node, input_data)}
sym, params = relay.frontend.from_tensorflow(graph_def, sym, params = relay.frontend.from_tensorflow(graph_def,
layout=layout, layout=layout,
shape=shape_dict, shape=shape_dict,
outputs=out_names) outputs=out_names)
with relay.build_config(opt_level=opt_level): with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build(sym, target, params=params) graph, lib, params = relay.build(sym, target, target_host, params)
ctx = tvm.context(target, 0) ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
# set inputs # set inputs
for i, e in enumerate(input_node): for e, i in zip(input_node, input_data):
m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) m.set_input(e, tvm.nd.array(i))
m.set_input(**params) m.set_input(**params)
# execute # execute
...@@ -88,10 +80,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, ...@@ -88,10 +80,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
# get outputs # get outputs
assert out_names is None or num_output == len(out_names), ( assert out_names is None or num_output == len(out_names), (
"out_names: {} num_output: {}".format(out_names, num_output)) "out_names: {} num_output: {}".format(out_names, num_output))
- tvm_output_list = []
- for i in range(0, num_output):
-     tvm_output = m.get_output(i)
-     tvm_output_list.append(tvm_output.asnumpy())
+ tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)]
return tvm_output_list return tvm_output_list
def run_tf_graph(sess, input_data, input_node, output_node): def run_tf_graph(sess, input_data, input_node, output_node):
...@@ -100,13 +89,9 @@ def run_tf_graph(sess, input_data, input_node, output_node): ...@@ -100,13 +89,9 @@ def run_tf_graph(sess, input_data, input_node, output_node):
input_node = convert_to_list(input_node) input_node = convert_to_list(input_node)
output_node = convert_to_list(output_node) output_node = convert_to_list(output_node)
- tensor = [0] * len(output_node)
- for i in range(len(output_node)):
-     tensor[i] = sess.graph.get_tensor_by_name(output_node[i])
+ tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node]

- input_dict = {}
- for i, e in enumerate(input_node):
-     input_dict[e] = input_data[i]
+ input_dict = {e: input_data[i] for i, e in enumerate(input_node)}
output_data = sess.run(tensor, input_dict) output_data = sess.run(tensor, input_dict)
return output_data return output_data
...@@ -115,17 +100,15 @@ def run_tf_graph(sess, input_data, input_node, output_node): ...@@ -115,17 +100,15 @@ def run_tf_graph(sess, input_data, input_node, output_node):
def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
no_gpu=False, opt_level=3): no_gpu=False, opt_level=3):
"""Generic function to generate and compare tensorflow and TVM output""" """Generic function to generate and compare tensorflow and TVM output"""
+ def name_without_num(name):
+     return name.split(':')[0] if ":" in name else name

  out_name = convert_to_list(out_name)
- out_node = [0]*len(out_name)
- for i in range(len(out_name)):
-     out_node[i] = out_name[i].split(':')[0] if ":" in out_name[i] else out_name[i]
+ out_node = [name_without_num(name) for name in out_name]

  in_data = convert_to_list(in_data)
  in_name = convert_to_list(in_name)
- in_node = [0]*len(in_name)
- for i in range(len(in_name)):
-     in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
+ in_node = [name_without_num(name) for name in in_name]
with tf.Session() as sess: with tf.Session() as sess:
if init_global_variables: if init_global_variables:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
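A small illustration of the `name_without_num` helper added above (example names are hypothetical):

    name_without_num('MatMul:0')     # -> 'MatMul'
    name_without_num('Placeholder')  # -> 'Placeholder' (no ':' present, returned unchanged)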
...@@ -578,6 +561,38 @@ def test_forward_variable(): ...@@ -578,6 +561,38 @@ def test_forward_variable():
####################################################################### #######################################################################
# MatMul
# ------
def _test_matmul(i, j, k, dtype, outer=None):
""" One iteration of matmul """
A_shape_init = [i, j]
B_shape_init = [j, k]
for transpose_a in [False, True]:
for transpose_b in [False, True]:
outer = outer or []
A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init)
B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init)
with tf.Graph().as_default():
A = tf.placeholder(shape=A_shape, dtype=dtype, name='A')
B = tf.placeholder(shape=B_shape, dtype=dtype, name='B')
result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b)
A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
def test_forward_matmul():
""" Matmul op test"""
_test_matmul(1, 3, 6, 'int32')
_test_matmul(5, 3, 1, 'float64')
# TODO non-empty outer requires BatchMatMul (BatchMatMulV2 for some cases?) support
#######################################################################
# StridedSlice # StridedSlice
# ------------ # ------------
...@@ -1785,3 +1800,6 @@ if __name__ == '__main__': ...@@ -1785,3 +1800,6 @@ if __name__ == '__main__':
test_forward_rel_ops() test_forward_rel_ops()
test_forward_logical() test_forward_logical()
test_where() test_where()
test_forward_matmul()
# TODO missing tests: rank, range
\ No newline at end of file
...@@ -151,11 +151,7 @@ def get_const_tuple(in_tuple): ...@@ -151,11 +151,7 @@ def get_const_tuple(in_tuple):
out_tuple : tuple of int out_tuple : tuple of int
The output. The output.
""" """
- out_tuple = ()
- for elem in in_tuple:
-     value = get_const_int(elem)
-     out_tuple = out_tuple + (value, )
- return out_tuple
+ return tuple(get_const_int(elem) for elem in in_tuple)
def get_float_tuple(in_tuple): def get_float_tuple(in_tuple):
...@@ -171,11 +167,7 @@ def get_float_tuple(in_tuple): ...@@ -171,11 +167,7 @@ def get_float_tuple(in_tuple):
out_tuple : tuple of float out_tuple : tuple of float
The output. The output.
""" """
- out_tuple = ()
- for elem in in_tuple:
-     value = get_const_float(elem)
-     out_tuple = out_tuple + (value, )
- return out_tuple
+ return tuple(get_const_float(elem) for elem in in_tuple)
def simplify(expr): def simplify(expr):
......