Commit fae3f60c by lixiaoquan Committed by Yizhi Liu

[FRONTEND][TENSORFLOW] Add Transpose support. (#1665)

* [FRONTEND][TENSORFLOW] Add Transpose support.

* [FRONTEND][TENSORFLOW] Get parameter from inputs and fix document style.

* [FRONTEND][TENSORFLOW] Handle the case that perm is not specified.

* [FRONTEND][TENSORFLOW] Convert Rank and Range to param.

* [FRONTEND][TENSORFLOW] Fix a pylint issue.

* [FRONTEND][TENSORFLOW] Implement Rank and Range as normal op.
parent cda8cb24
...@@ -650,6 +650,35 @@ def _pad(name): ...@@ -650,6 +650,35 @@ def _pad(name):
ignores=['Tpaddings'],)(new_inputs, attr) ignores=['Tpaddings'],)(new_inputs, attr)
return _impl return _impl
def _transpose():
def _impl(inputs, attr, params):
# If perm is not specified, axes is left empty,
# otherwise its value is get from params
param_name = inputs[1].list_output_names()[0]
axes = params.get(param_name, tvm.nd.array([])).asnumpy()
return _sym.transpose(inputs[0], axes=tuple(axes))
return _impl
def _rank():
def _impl(inputs, attr, params):
input_shapes = attr['_input_shapes'][inputs[0]]
assert len(inputs) == 1
name = attr["_node_name"]
params[name] = tvm.nd.array([len(input_shapes[0])])
return _sym.Variable(name=name, shape=params[name].shape)
return _impl
def _range():
def _impl(inputs, attr, params):
start = params.pop(inputs[0].list_output_names()[0]).asnumpy()[0]
limit = params.pop(inputs[1].list_output_names()[0]).asnumpy()[0]
delta = params.pop(inputs[2].list_output_names()[0]).asnumpy()[0]
name = attr["_node_name"]
params[name] = tvm.nd.array([start, limit, delta])
return _sym.Variable(name=name, shape=params[name].shape)
return _impl
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -700,6 +729,9 @@ _convert_map = { ...@@ -700,6 +729,9 @@ _convert_map = {
'LRN' : _lrn(), 'LRN' : _lrn(),
'Pad' : _pad('Pad'), 'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'), 'PadV2' : _pad('PadV2'),
'Range' : _range(),
'Rank' : _rank(),
'Transpose' : _transpose(),
} }
# _convert_map_rnn defines maps of rnn operator name to # _convert_map_rnn defines maps of rnn operator name to
......
...@@ -853,11 +853,34 @@ def _test_l2_normalize(ishape, eps, axis): ...@@ -853,11 +853,34 @@ def _test_l2_normalize(ishape, eps, axis):
def test_forward_l2_normalize(): def test_forward_l2_normalize():
_test_l2_normalize((1, 3, 20, 20), 0.001, (0,)) _test_l2_normalize((1, 3, 20, 20), 0.001, (0,))
#######################################################################
# transpose
# ---------
def _test_forward_transpose(ishape, axes=None):
input = np.random.uniform(size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=input.shape, dtype=input.dtype, name="transpose_data")
if axes is None:
tf.transpose(in1)
else:
tf.transpose(in1, perm=axes)
compare_tf_with_tvm(input, 'transpose_data:0', 'transpose:0')
def test_forward_transpose():
_test_forward_transpose((2, 3, 4))
_test_forward_transpose((7, 8, 8, 10))
_test_forward_transpose((2, 3, 4), (1, 2, 0))
_test_forward_transpose((2, 3, 4), (0, 1, 2))
_test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
####################################################################### #######################################################################
# Main # Main
# ---- # ----
if __name__ == '__main__': if __name__ == '__main__':
test_forward_transpose()
test_forward_convolution() test_forward_convolution()
test_forward_pooling() test_forward_pooling()
test_forward_reshape() test_forward_reshape()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment