Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
dc7ab96b
Commit
dc7ab96b
authored
Sep 26, 2017
by
Xingjian Shi
Committed by
Tianqi Chen
Sep 26, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TOPI] add squeeze (#494)
* add squeeze * should be squeeze
parent
fd864c51
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
81 additions
and
1 deletions
+81
-1
topi/python/topi/transform.py
+52
-1
topi/tests/python/test_topi_transform.py
+29
-0
No files found.
topi/python/topi/transform.py
View file @
dc7ab96b
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
import
tvm
import
tvm
from
.
import
tag
from
.
import
tag
from
.util
import
ravel_index
,
unravel_index
,
get_const_int
from
.util
import
ravel_index
,
unravel_index
,
get_const_int
,
get_const_tuple
@tvm.tag_scope
(
tag
=
tag
.
BROADCAST
)
@tvm.tag_scope
(
tag
=
tag
.
BROADCAST
)
def
expand_dims
(
a
,
axis
,
num_newaxis
=
1
):
def
expand_dims
(
a
,
axis
,
num_newaxis
=
1
):
...
@@ -78,6 +78,57 @@ def reshape(a, newshape):
...
@@ -78,6 +78,57 @@ def reshape(a, newshape):
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
def
squeeze
(
a
,
axis
=
None
):
"""Remove single-dimensional entries from the shape of an array.
Parameters
----------
a : tvm.Tensor
axis : None or int or tuple of ints, optional
Selects a subset of the single-dimensional entries in the shape.
If an axis is selected with shape entry greater than one, an error is raised.
Returns
-------
squeezed : tvm.Tensor
"""
a_ndim
=
len
(
a
.
shape
)
a_shape
=
get_const_tuple
(
a
.
shape
)
if
axis
is
None
:
axis
=
[]
for
i
,
ele
in
enumerate
(
a_shape
):
if
ele
==
1
:
axis
.
append
(
i
)
else
:
if
isinstance
(
axis
,
int
):
axis
=
axis
+
a_ndim
if
axis
<
0
else
axis
assert
a_shape
[
axis
]
==
1
axis
=
[
axis
]
else
:
axis
=
[
ele
+
a_ndim
if
ele
<
0
else
ele
for
ele
in
axis
]
for
ele
in
axis
:
assert
a_shape
[
ele
]
==
1
out_shape
=
[]
search_axis
=
set
(
axis
)
for
i
,
a_dim
in
enumerate
(
a_shape
):
if
i
not
in
search_axis
:
out_shape
.
append
(
a_dim
)
def
_compute
(
*
indices
):
real_indices
=
[]
flag
=
0
for
i
in
range
(
a_ndim
):
if
i
not
in
search_axis
:
real_indices
.
append
(
indices
[
i
-
flag
])
else
:
real_indices
.
append
(
0
)
flag
+=
1
return
a
(
*
real_indices
)
return
tvm
.
compute
(
out_shape
,
_compute
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
def
concatenate
(
a_tuple
,
axis
=
0
):
def
concatenate
(
a_tuple
,
axis
=
0
):
"""Join a sequence of arrays along an existing axis.
"""Join a sequence of arrays along an existing axis.
...
...
topi/tests/python/test_topi_transform.py
View file @
dc7ab96b
...
@@ -69,6 +69,28 @@ def verify_reshape(src_shape, dst_shape):
...
@@ -69,6 +69,28 @@ def verify_reshape(src_shape, dst_shape):
check_device
(
"metal"
)
check_device
(
"metal"
)
def
verify_squeeze
(
src_shape
,
axis
):
A
=
tvm
.
placeholder
(
shape
=
src_shape
,
name
=
"A"
)
B
=
topi
.
squeeze
(
A
,
axis
=
axis
)
s
=
topi
.
cuda
.
schedule_injective
(
B
)
def
check_device
(
device
):
if
not
tvm
.
module
.
enabled
(
device
):
print
(
"Skip because
%
s is not enabled"
%
device
)
return
ctx
=
tvm
.
gpu
(
0
)
if
device
==
"cuda"
else
tvm
.
cl
(
0
)
foo
=
tvm
.
build
(
s
,
[
A
,
B
],
device
,
name
=
"squeeze"
)
data_npy
=
np
.
random
.
normal
(
size
=
src_shape
)
.
astype
(
A
.
dtype
)
out_npy
=
np
.
squeeze
(
data_npy
,
axis
=
axis
)
data_nd
=
tvm
.
nd
.
array
(
data_npy
,
ctx
)
out_nd
=
tvm
.
nd
.
empty
(
out_npy
.
shape
,
ctx
=
ctx
,
dtype
=
B
.
dtype
)
foo
(
data_nd
,
out_nd
)
np
.
testing
.
assert_allclose
(
out_nd
.
asnumpy
(),
out_npy
)
check_device
(
"cuda"
)
check_device
(
"opencl"
)
check_device
(
"metal"
)
def
verify_concatenate
(
shapes
,
axis
):
def
verify_concatenate
(
shapes
,
axis
):
tensor_l
=
[]
tensor_l
=
[]
for
i
,
shape
in
enumerate
(
shapes
):
for
i
,
shape
in
enumerate
(
shapes
):
...
@@ -133,6 +155,12 @@ def test_reshape():
...
@@ -133,6 +155,12 @@ def test_reshape():
verify_reshape
((
16
,
),
(
2
,
2
,
2
,
2
))
verify_reshape
((
16
,
),
(
2
,
2
,
2
,
2
))
def
test_squeeze
():
verify_squeeze
((
1
,
2
,
3
,
4
),
0
)
verify_squeeze
((
1
,
2
,
1
,
4
),
None
)
verify_squeeze
((
1
,
1
,
1
,
4
),
(
1
,
2
))
def
test_concatenate
():
def
test_concatenate
():
verify_concatenate
([(
2
,
3
,
4
),
(
2
,
2
,
4
),
(
2
,
5
,
4
)],
1
)
verify_concatenate
([(
2
,
3
,
4
),
(
2
,
2
,
4
),
(
2
,
5
,
4
)],
1
)
verify_concatenate
([(
1
,
2
,
4
),
(
1
,
2
,
3
),
(
1
,
2
,
7
),
(
1
,
2
,
8
),
(
1
,
2
,
1
)],
-
1
)
verify_concatenate
([(
1
,
2
,
4
),
(
1
,
2
,
3
),
(
1
,
2
,
7
),
(
1
,
2
,
8
),
(
1
,
2
,
1
)],
-
1
)
...
@@ -152,6 +180,7 @@ if __name__ == "__main__":
...
@@ -152,6 +180,7 @@ if __name__ == "__main__":
test_tranpose
()
test_tranpose
()
test_expand_dims
()
test_expand_dims
()
test_reshape
()
test_reshape
()
test_squeeze
()
test_concatenate
()
test_concatenate
()
test_split
()
test_split
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment