Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
58888b21
Commit
58888b21
authored
May 19, 2018
by
Pariksheet Pinjari
Committed by
Tianqi Chen
May 18, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TOPI] add take (#1158)
parent
bd988658
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
217 additions
and
0 deletions
+217
-0
topi/include/topi/transform.h
+81
-0
topi/python/topi/transform.py
+25
-0
topi/src/topi.cc
+10
-0
topi/tests/python/test_topi_transform.py
+50
-0
topi/tests/python_cpp/test_topi_transform.py
+51
-0
No files found.
topi/include/topi/transform.h
View file @
58888b21
...
...
@@ -398,5 +398,86 @@ inline Array<Tensor> split_sections(const Tensor& x,
return
split
(
x
,
split_indices
,
axis
,
name
,
tag
);
}
/*!
* \brief Take elements from an flattened input array when axis is None.
*
* \param a The source array.
* \param indices The indices of the values to extract.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor whose op member is the take operation
*/
inline
Tensor
take
(
const
Tensor
&
a
,
const
Tensor
&
indices
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
kInjective
)
{
Array
<
Expr
>
a_shape
=
a
->
shape
;
Array
<
Expr
>
out_shape
;
for
(
size_t
j
=
0
;
j
<
indices
->
shape
.
size
();
++
j
)
{
out_shape
.
push_back
(
indices
->
shape
[
j
]);
}
return
compute
(
out_shape
,
[
&
](
const
Array
<
Var
>&
out_index
)
{
Array
<
Expr
>
indices_position
;
for
(
size_t
j
=
0
;
j
<
indices
->
shape
.
size
();
++
j
)
{
indices_position
.
push_back
(
out_index
[
j
]);
}
return
a
(
UnavelIndex
(
indices
(
indices_position
),
a_shape
));
},
name
,
tag
);
}
/*!
* \brief Take elements from an array along an axis.
*
* \param a The source array.
* \param indices The indices of the values to extract.
* \param axis The axis over which to select values. By default,
* the flattened input array is used.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor whose op member is the take operation
*/
inline
Tensor
take
(
const
Tensor
&
a
,
const
Tensor
&
indices
,
int
axis
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
kInjective
)
{
if
(
axis
<
0
)
{
axis
+=
static_cast
<
int
>
(
a
->
shape
.
size
());
}
CHECK_LT
(
axis
,
a
->
shape
.
size
())
<<
"axis out of bounds"
;
int
indices_len
=
static_cast
<
int
>
(
indices
->
shape
.
size
());
Array
<
Expr
>
out_shape
;
for
(
size_t
i
=
0
;
i
<
a
->
shape
.
size
();
++
i
)
{
if
(
axis
==
static_cast
<
int
>
(
i
))
{
for
(
size_t
j
=
0
;
j
<
indices
->
shape
.
size
();
++
j
)
{
out_shape
.
push_back
(
indices
->
shape
[
j
]);
}
}
else
{
out_shape
.
push_back
(
a
->
shape
[
i
]);
}
}
return
compute
(
out_shape
,
[
&
](
const
Array
<
Var
>&
out_index
)
{
Array
<
Expr
>
indices_position
;
for
(
size_t
j
=
axis
;
j
<
static_cast
<
size_t
>
(
axis
+
indices_len
);
++
j
)
{
indices_position
.
push_back
(
out_index
[
j
]);
}
Array
<
Expr
>
real_indices
;
for
(
size_t
j
=
0
;
j
<
static_cast
<
size_t
>
(
axis
);
++
j
)
{
real_indices
.
push_back
(
out_index
[
j
]);
}
real_indices
.
push_back
(
indices
(
indices_position
));
for
(
size_t
j
=
axis
+
indices_len
;
j
<
out_index
.
size
();
++
j
)
{
real_indices
.
push_back
(
out_index
[
j
]);
}
return
a
(
real_indices
);
},
name
,
tag
);
}
}
// namespace topi
#endif // TOPI_TRANSFORM_H_
topi/python/topi/transform.py
View file @
58888b21
...
...
@@ -286,3 +286,28 @@ def split(ary, indices_or_sections, axis=0):
lambda
*
indices
:
_compute
(
begin_id
,
*
indices
),
name
=
"s
%
d"
%
i
)
for
i
,
(
out_shape
,
begin_id
)
in
enumerate
(
zip
(
out_shapes
,
begin_ids
))]
# pylint: enable=cell-var-from-loop
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
def
take
(
a
,
indices
,
axis
=
None
):
"""Take elements from an array along an axis.
Parameters
----------
a : tvm.Tensor
The source array.
indices : tvm.Tensor
The indices of the values to extract.
axis : int, optional
The axis over which to select values. By default,
the flattened input array is used.
Returns
-------
ret : tvm.Tensor
"""
if
axis
is
None
:
return
cpp
.
take
(
a
,
indices
)
return
cpp
.
take
(
a
,
indices
,
int
(
axis
))
topi/src/topi.cc
View file @
58888b21
...
...
@@ -270,6 +270,16 @@ TVM_REGISTER_GLOBAL("topi.split")
}
});
TVM_REGISTER_GLOBAL
(
"topi.take"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
if
(
args
.
size
()
==
2
)
{
*
rv
=
take
(
args
[
0
],
args
[
1
]);
}
else
{
int
axis
=
args
[
2
];
*
rv
=
take
(
args
[
0
],
args
[
1
],
axis
);
}
});
/* Ops from nn/batch_norm.h */
TVM_REGISTER_GLOBAL
(
"topi.nn.batch_norm_inference"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
...
...
topi/tests/python/test_topi_transform.py
View file @
58888b21
...
...
@@ -207,6 +207,46 @@ def verify_flip(in_shape, axis):
for
device
in
[
"llvm"
,
"cuda"
,
"opencl"
]:
check_device
(
device
)
def
verify_take
(
src_shape
,
indices_src
,
axis
=
None
):
src_dtype
=
"float32"
indices_dtype
=
"int32"
indices_src
=
np
.
array
(
indices_src
,
dtype
=
indices_dtype
)
A
=
tvm
.
placeholder
(
shape
=
src_shape
,
dtype
=
src_dtype
,
name
=
"A"
)
indices
=
tvm
.
placeholder
(
shape
=
indices_src
.
shape
,
dtype
=
indices_dtype
,
name
=
"indices"
)
if
axis
is
None
:
out_tensor
=
topi
.
take
(
a
=
A
,
indices
=
indices
)
else
:
out_tensor
=
topi
.
take
(
a
=
A
,
indices
=
indices
,
axis
=
axis
)
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_injective
(
out_tensor
)
foo
=
tvm
.
build
(
s
,
[
A
]
+
[
indices
]
+
[
out_tensor
]
,
device
,
name
=
"take"
)
shape_size
=
1
for
i
in
range
(
len
(
src_shape
)):
shape_size
=
shape_size
*
src_shape
[
i
]
data_npy
=
np
.
arange
(
shape_size
,
dtype
=
src_dtype
)
.
reshape
((
src_shape
))
if
axis
is
None
:
out_npys
=
np
.
take
(
data_npy
,
indices_src
)
else
:
out_npys
=
np
.
take
(
data_npy
,
indices_src
,
axis
=
axis
)
data_nd
=
tvm
.
nd
.
array
(
data_npy
,
ctx
)
indices_nd
=
tvm
.
nd
.
array
(
indices_src
,
ctx
)
out_nd
=
tvm
.
nd
.
empty
(
out_npys
.
shape
,
ctx
=
ctx
,
dtype
=
src_dtype
)
foo
(
data_nd
,
indices_nd
,
out_nd
)
np
.
testing
.
assert_allclose
(
out_nd
.
asnumpy
(),
out_npys
)
for
device
in
[
"llvm"
,
"opencl"
]:
check_device
(
device
)
def
test_expand_dims
():
verify_expand_dims
((
3
,
10
),
(
3
,
10
,
1
,
1
),
2
,
2
)
verify_expand_dims
((
3
,
10
),
(
1
,
3
,
10
),
-
3
,
1
)
...
...
@@ -262,6 +302,15 @@ def test_expand_like():
verify_expand_like
((
3
,
4
),
(
3
,
5
,
4
),
[
1
])
verify_expand_like
((
5
,
7
),
(
5
,
6
,
7
,
8
),
[
1
,
3
])
def
test_take
():
verify_take
((
4
,),
[
1
])
verify_take
((
4
,),
[[
0
,
1
,
2
,
3
]])
verify_take
((
3
,
3
,
3
),
[[
11
,
25
]])
verify_take
((
4
,),
[[
0
,
1
],[
2
,
3
]])
verify_take
((
4
,),
[
1
],
0
)
verify_take
((
2
,
2
),
[[[
1
,
0
],[
0
,
1
]]],
0
)
verify_take
((
2
,
2
),
[[[
1
,
0
],[
0
,
1
]]],
1
)
verify_take
((
4
,
3
,
5
,
6
),
[[
2
,
1
,
0
,
0
]],
-
2
)
if
__name__
==
"__main__"
:
test_concatenate
()
...
...
@@ -272,3 +321,4 @@ if __name__ == "__main__":
test_split
()
test_flip
()
test_expand_like
()
test_take
()
topi/tests/python_cpp/test_topi_transform.py
View file @
58888b21
...
...
@@ -167,6 +167,45 @@ def verify_split(src_shape, indices_or_sections, axis):
for
device
in
[
"llvm"
,
"nvptx"
,
"cuda"
,
"opencl"
,
"metal"
,
"rocm"
]:
check_device
(
device
)
def
verify_take
(
src_shape
,
indices_src
,
axis
=
None
):
src_dtype
=
"float32"
indices_dtype
=
"int32"
indices_src
=
np
.
array
(
indices_src
,
dtype
=
indices_dtype
)
A
=
tvm
.
placeholder
(
shape
=
src_shape
,
dtype
=
src_dtype
,
name
=
"A"
)
indices
=
tvm
.
placeholder
(
shape
=
indices_src
.
shape
,
dtype
=
indices_dtype
,
name
=
"indices"
)
if
axis
is
None
:
out_tensor
=
topi
.
cpp
.
take
(
A
,
indices
)
else
:
out_tensor
=
topi
.
cpp
.
take
(
A
,
indices
,
axis
)
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_injective
(
out_tensor
)
foo
=
tvm
.
build
(
s
,
[
A
]
+
[
indices
]
+
[
out_tensor
]
,
device
,
name
=
"take"
)
shape_size
=
1
for
i
in
range
(
len
(
src_shape
)):
shape_size
=
shape_size
*
src_shape
[
i
]
data_npy
=
np
.
arange
(
shape_size
,
dtype
=
src_dtype
)
.
reshape
((
src_shape
))
if
axis
is
None
:
out_npys
=
np
.
take
(
data_npy
,
indices_src
)
else
:
out_npys
=
np
.
take
(
data_npy
,
indices_src
,
axis
=
axis
)
data_nd
=
tvm
.
nd
.
array
(
data_npy
,
ctx
)
indices_nd
=
tvm
.
nd
.
array
(
indices_src
,
ctx
)
out_nd
=
tvm
.
nd
.
empty
(
out_npys
.
shape
,
ctx
=
ctx
,
dtype
=
src_dtype
)
foo
(
data_nd
,
indices_nd
,
out_nd
)
np
.
testing
.
assert_allclose
(
out_nd
.
asnumpy
(),
out_npys
)
for
device
in
[
"llvm"
,
"opencl"
]:
check_device
(
device
)
def
test_expand_dims
():
verify_expand_dims
((
3
,
10
),
(
3
,
10
,
1
,
1
),
2
,
2
)
...
...
@@ -209,6 +248,16 @@ def test_split():
verify_split
((
2
,
12
,
3
),
[
2
,
4
],
1
)
verify_split
((
10
,
12
,
24
),
[
5
,
7
,
9
],
-
1
)
def
test_take
():
verify_take
((
4
,),
[
1
])
verify_take
((
4
,),
[[
0
,
1
,
2
,
3
]])
verify_take
((
3
,
3
,
3
),
[[
11
,
25
]])
verify_take
((
4
,),
[[
0
,
1
],[
2
,
3
]])
verify_take
((
4
,),
[
1
],
0
)
verify_take
((
2
,
2
),
[[[
1
,
0
],[
0
,
1
]]],
0
)
verify_take
((
2
,
2
),
[[[
1
,
0
],[
0
,
1
]]],
1
)
verify_take
((
4
,
3
,
5
,
6
),
[[
2
,
1
,
0
,
0
]],
-
2
)
if
__name__
==
"__main__"
:
test_concatenate
()
test_tranpose
()
...
...
@@ -216,3 +265,4 @@ if __name__ == "__main__":
test_reshape
()
test_squeeze
()
test_split
()
test_take
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment