Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
03a29da7
Commit
03a29da7
authored
Nov 12, 2019
by
Wei Chen
Committed by
Zhi
Nov 12, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Op][TF] Complete tensor array unstack with all ranks support (#4309)
parent
e6806115
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
148 additions
and
3 deletions
+148
-3
python/tvm/relay/frontend/tensorflow.py
+8
-3
python/tvm/relay/prelude.py
+120
-0
tests/python/frontend/tensorflow/test_forward.py
+20
-0
No files found.
python/tvm/relay/frontend/tensorflow.py
View file @
03a29da7
...
...
@@ -40,6 +40,7 @@ from .common import infer_type as _infer_type
from
.common
import
infer_shape
as
_infer_shape
from
.common
import
infer_channels
as
_infer_channels
from
.common
import
infer_value
as
_infer_value
from
.common
import
infer_value_simulated
as
_infer_value_simulated
__all__
=
[
'from_tensorflow'
]
...
...
@@ -1079,9 +1080,13 @@ def _rank():
def
_range
():
def
_impl
(
inputs
,
attr
,
params
):
start
=
_get_param
(
params
,
inputs
[
0
])[
0
]
limit
=
_get_param
(
params
,
inputs
[
1
])[
0
]
\
if
hasattr
(
inputs
[
1
],
"name_hint"
)
or
isinstance
(
inputs
[
1
],
_expr
.
Constant
)
\
else
params
.
pop
(
'Rank'
)
.
asnumpy
()[
0
]
if
hasattr
(
inputs
[
1
],
"name_hint"
)
or
isinstance
(
inputs
[
1
],
_expr
.
Constant
):
limit
=
_get_param
(
params
,
inputs
[
1
])[
0
]
else
:
if
any
([
'Rank'
in
param
for
param
in
params
]):
limit
=
params
.
pop
(
'Rank'
)
.
asnumpy
()[
0
]
else
:
limit
=
_infer_value_simulated
(
inputs
[
1
],
params
)
.
asnumpy
()[
0
]
delta
=
_get_param
(
params
,
inputs
[
2
])[
0
]
dtype
=
attr
[
'Tidx'
]
.
name
if
'Tidx'
in
attr
else
str
(
start
.
dtype
)
return
AttrCvt
(
...
...
python/tvm/relay/prelude.py
View file @
03a29da7
...
...
@@ -336,6 +336,122 @@ class TensorArrayOps(object):
Function
([
tensor2
],
helper_var
(
const
(
0
),
ndim
,
tensor2
),
self
.
prelude
.
l
(
self
.
get_var
(
'tensor_t'
)()),
[])
def
define_tensor_array_unstack_tensor3
(
self
):
"""Defines a function to unstack the values of a tensor_t with rank 3 in a tensor array.
tensor_array_unstack_tensor3(t) : tensor_t -> list[tensor_t]
"""
helper_name
=
self
.
get_name
(
"tensor_array_unstack_tensor3_helper"
)
helper_var
=
GlobalVar
(
helper_name
)
setattr
(
self
.
prelude
,
helper_name
,
helper_var
)
tensor
=
Var
(
"t"
,
TensorType
([
Any
(),
Any
(),
Any
()],
self
.
dtype
))
up
=
Var
(
"up"
,
scalar_type
(
'int32'
))
i
=
Var
(
"i"
,
scalar_type
(
'int32'
))
helper_body
=
If
(
equal
(
i
,
up
),
self
.
prelude
.
nil
(),
self
.
prelude
.
cons
(
self
.
get_var
(
'tensor2'
)(
op
.
take
(
tensor
,
i
,
axis
=
0
)),
helper_var
(
add
(
i
,
const
(
1
)),
up
,
tensor
)))
self
.
prelude
.
mod
[
helper_var
]
=
\
Function
([
i
,
up
,
tensor
],
helper_body
,
self
.
prelude
.
l
(
self
.
get_var
(
'tensor_t'
)()),
[])
tensor_array_unstack_tensor3_name
=
self
.
get_name
(
"tensor_array_unstack_tensor3"
)
tensor_array_unstack_tensor3_var
=
GlobalVar
(
tensor_array_unstack_tensor3_name
)
setattr
(
self
.
prelude
,
tensor_array_unstack_tensor3_name
,
tensor_array_unstack_tensor3_var
)
tensor3
=
Var
(
"tensor"
,
TensorType
([
Any
(),
Any
(),
Any
()],
self
.
dtype
))
shape
=
op
.
shape_of
(
tensor3
)
ndim
=
op
.
take
(
shape
,
const
(
0
))
self
.
prelude
.
mod
[
tensor_array_unstack_tensor3_var
]
=
\
Function
([
tensor3
],
helper_var
(
const
(
0
),
ndim
,
tensor3
),
self
.
prelude
.
l
(
self
.
get_var
(
'tensor_t'
)()),
[])
def
define_tensor_array_unstack_tensor4
(
self
):
"""Defines a function to unstack the values of a tensor_t with rank 4 in a tensor array.
tensor_array_unstack_tensor4(t) : tensor_t -> list[tensor_t]
"""
helper_name
=
self
.
get_name
(
"tensor_array_unstack_tensor4_helper"
)
helper_var
=
GlobalVar
(
helper_name
)
setattr
(
self
.
prelude
,
helper_name
,
helper_var
)
tensor
=
Var
(
"t"
,
TensorType
([
Any
(),
Any
(),
Any
(),
Any
()],
self
.
dtype
))
up
=
Var
(
"up"
,
scalar_type
(
'int32'
))
i
=
Var
(
"i"
,
scalar_type
(
'int32'
))
helper_body
=
If
(
equal
(
i
,
up
),
self
.
prelude
.
nil
(),
self
.
prelude
.
cons
(
self
.
get_var
(
'tensor3'
)(
op
.
take
(
tensor
,
i
,
axis
=
0
)),
helper_var
(
add
(
i
,
const
(
1
)),
up
,
tensor
)))
self
.
prelude
.
mod
[
helper_var
]
=
\
Function
([
i
,
up
,
tensor
],
helper_body
,
self
.
prelude
.
l
(
self
.
get_var
(
'tensor_t'
)()),
[])
tensor_array_unstack_tensor4_name
=
self
.
get_name
(
"tensor_array_unstack_tensor4"
)
tensor_array_unstack_tensor4_var
=
GlobalVar
(
tensor_array_unstack_tensor4_name
)
setattr
(
self
.
prelude
,
tensor_array_unstack_tensor4_name
,
tensor_array_unstack_tensor4_var
)
tensor4
=
Var
(
"tensor"
,
TensorType
([
Any
(),
Any
(),
Any
(),
Any
()],
self
.
dtype
))
shape
=
op
.
shape_of
(
tensor4
)
ndim
=
op
.
take
(
shape
,
const
(
0
))
self
.
prelude
.
mod
[
tensor_array_unstack_tensor4_var
]
=
\
Function
([
tensor4
],
helper_var
(
const
(
0
),
ndim
,
tensor4
),
self
.
prelude
.
l
(
self
.
get_var
(
'tensor_t'
)()),
[])
def
define_tensor_array_unstack_tensor5
(
self
):
"""Defines a function to unstack the values of a tensor_t with rank 5 in a tensor array.
tensor_array_unstack_tensor5(t) : tensor_t -> list[tensor_t]
"""
helper_name
=
self
.
get_name
(
"tensor_array_unstack_tensor5_helper"
)
helper_var
=
GlobalVar
(
helper_name
)
setattr
(
self
.
prelude
,
helper_name
,
helper_var
)
tensor
=
Var
(
"t"
,
TensorType
([
Any
(),
Any
(),
Any
(),
Any
(),
Any
()],
self
.
dtype
))
up
=
Var
(
"up"
,
scalar_type
(
'int32'
))
i
=
Var
(
"i"
,
scalar_type
(
'int32'
))
helper_body
=
If
(
equal
(
i
,
up
),
self
.
prelude
.
nil
(),
self
.
prelude
.
cons
(
self
.
get_var
(
'tensor4'
)(
op
.
take
(
tensor
,
i
,
axis
=
0
)),
helper_var
(
add
(
i
,
const
(
1
)),
up
,
tensor
)))
self
.
prelude
.
mod
[
helper_var
]
=
\
Function
([
i
,
up
,
tensor
],
helper_body
,
self
.
prelude
.
l
(
self
.
get_var
(
'tensor_t'
)()),
[])
tensor_array_unstack_tensor5_name
=
self
.
get_name
(
"tensor_array_unstack_tensor5"
)
tensor_array_unstack_tensor5_var
=
GlobalVar
(
tensor_array_unstack_tensor5_name
)
setattr
(
self
.
prelude
,
tensor_array_unstack_tensor5_name
,
tensor_array_unstack_tensor5_var
)
tensor5
=
Var
(
"tensor"
,
TensorType
([
Any
(),
Any
(),
Any
(),
Any
(),
Any
()],
self
.
dtype
))
shape
=
op
.
shape_of
(
tensor5
)
ndim
=
op
.
take
(
shape
,
const
(
0
))
self
.
prelude
.
mod
[
tensor_array_unstack_tensor5_var
]
=
\
Function
([
tensor5
],
helper_var
(
const
(
0
),
ndim
,
tensor5
),
self
.
prelude
.
l
(
self
.
get_var
(
'tensor_t'
)()),
[])
def
define_tensor_array_unstack_tensor6
(
self
):
"""Defines a function to unstack the values of a tensor_t with rank 6 in a tensor array.
tensor_array_unstack_tensor6(t) : tensor_t -> list[tensor_t]
"""
helper_name
=
self
.
get_name
(
"tensor_array_unstack_tensor6_helper"
)
helper_var
=
GlobalVar
(
helper_name
)
setattr
(
self
.
prelude
,
helper_name
,
helper_var
)
tensor
=
Var
(
"t"
,
TensorType
([
Any
(),
Any
(),
Any
(),
Any
(),
Any
(),
Any
()],
self
.
dtype
))
up
=
Var
(
"up"
,
scalar_type
(
'int32'
))
i
=
Var
(
"i"
,
scalar_type
(
'int32'
))
helper_body
=
If
(
equal
(
i
,
up
),
self
.
prelude
.
nil
(),
self
.
prelude
.
cons
(
self
.
get_var
(
'tensor5'
)(
op
.
take
(
tensor
,
i
,
axis
=
0
)),
helper_var
(
add
(
i
,
const
(
1
)),
up
,
tensor
)))
self
.
prelude
.
mod
[
helper_var
]
=
\
Function
([
i
,
up
,
tensor
],
helper_body
,
self
.
prelude
.
l
(
self
.
get_var
(
'tensor_t'
)()),
[])
tensor_array_unstack_tensor6_name
=
self
.
get_name
(
"tensor_array_unstack_tensor6"
)
tensor_array_unstack_tensor6_var
=
GlobalVar
(
tensor_array_unstack_tensor6_name
)
setattr
(
self
.
prelude
,
tensor_array_unstack_tensor6_name
,
tensor_array_unstack_tensor6_var
)
tensor6
=
Var
(
"tensor"
,
TensorType
([
Any
(),
Any
(),
Any
(),
Any
(),
Any
(),
Any
()],
self
.
dtype
))
shape
=
op
.
shape_of
(
tensor6
)
ndim
=
op
.
take
(
shape
,
const
(
0
))
self
.
prelude
.
mod
[
tensor_array_unstack_tensor6_var
]
=
\
Function
([
tensor6
],
helper_var
(
const
(
0
),
ndim
,
tensor6
),
self
.
prelude
.
l
(
self
.
get_var
(
'tensor_t'
)()),
[])
def
define_tensor_array_scatter
(
self
):
"""Defines a function to scatter the values of a tensor_t in indices of a tensor array.
tensor_array_scatter(ta, indices, value) :
...
...
@@ -516,6 +632,10 @@ class TensorArrayOps(object):
self
.
define_tensor_array_write
()
self
.
define_tensor_array_unstack_tensor1
()
self
.
define_tensor_array_unstack_tensor2
()
self
.
define_tensor_array_unstack_tensor3
()
self
.
define_tensor_array_unstack_tensor4
()
self
.
define_tensor_array_unstack_tensor5
()
self
.
define_tensor_array_unstack_tensor6
()
self
.
define_tensor_array_scatter
()
self
.
define_tensor_array_split
()
self
.
define_tensor_array_concat
()
...
...
tests/python/frontend/tensorflow/test_forward.py
View file @
03a29da7
...
...
@@ -763,6 +763,26 @@ def test_tensor_array_size():
for
dtype
in
tf_dtypes
.
keys
():
run
(
dtype
)
def
test_tensor_array_unstack
():
def
run
(
dtype_str
,
input_shape
):
with
tf
.
Graph
()
.
as_default
():
dtype
=
tf_dtypes
[
dtype_str
]
t
=
tf
.
constant
(
np
.
random
.
choice
([
0
,
1
,
2
,
3
],
size
=
input_shape
)
.
astype
(
dtype
.
name
))
ta1
=
tf
.
TensorArray
(
dtype
=
dtype
,
infer_shape
=
False
,
size
=
input_shape
[
0
])
ta2
=
ta1
.
unstack
(
t
)
out0
=
ta2
.
size
()
out1
=
ta2
.
read
(
0
)
compare_tf_with_tvm
([],
[],
'TensorArraySizeV3:0'
,
mode
=
'debug'
)
compare_tf_with_tvm
([],
[],
'TensorArrayReadV3:0'
,
mode
=
'debug'
)
for
dtype
in
tf_dtypes
.
keys
():
run
(
dtype
,
(
5
,))
run
(
dtype
,
(
5
,
5
))
run
(
dtype
,
(
5
,
5
,
5
))
run
(
dtype
,
(
5
,
5
,
5
,
5
))
run
(
dtype
,
(
5
,
5
,
5
,
5
,
5
))
run
(
dtype
,
(
5
,
5
,
5
,
5
,
5
,
5
))
#######################################################################
# ConcatV2
# --------
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment