Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
2c41fd2f
Commit
2c41fd2f
authored
Jun 11, 2019
by
hlu1
Committed by
Haichen Shen
Jun 11, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Topi] Fast mode in take op (#3325)
parent
d4ca627a
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
43 additions
and
5 deletions
+43
-5
include/tvm/relay/attrs/transform.h
+2
-1
python/tvm/relay/op/transform.py
+2
-1
tests/python/relay/test_op_level3.py
+5
-1
topi/include/topi/transform.h
+26
-0
topi/python/topi/transform.py
+1
-0
topi/tests/python/test_topi_transform.py
+7
-2
No files found.
include/tvm/relay/attrs/transform.h
View file @
2c41fd2f
...
...
@@ -101,7 +101,8 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
TVM_ATTR_FIELD
(
mode
).
set_default
(
"clip"
)
.
describe
(
"Specify how out-of-bound indices will behave."
"clip - clip to the range (default)"
"wrap - wrap around the indices"
);
"wrap - wrap around the indices"
"fast - no clip or wrap around (user must make sure indices are in-bound)"
);
}
};
...
...
python/tvm/relay/op/transform.py
View file @
2c41fd2f
...
...
@@ -218,9 +218,10 @@ def take(data, indices, axis=None, mode="clip"):
the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave [clip, wrap].
Specifies how out-of-bound indices will behave [clip, wrap
, fast
].
clip: clip to the range (default).
wrap: wrap around the indices.
fast: no clip or wrap around (user must make sure indices are in-bound).
Returns
-------
...
...
tests/python/relay/test_op_level3.py
View file @
2c41fd2f
...
...
@@ -269,7 +269,8 @@ def test_take():
func
=
relay
.
Function
([
x
,
indices
],
z
)
x_data
=
np
.
random
.
uniform
(
low
=-
1
,
high
=
1
,
size
=
src_shape
)
.
astype
(
src_dtype
)
ref_res
=
np
.
take
(
x_data
,
indices
=
indices_src
,
axis
=
axis
,
mode
=
mode
)
np_mode
=
"raise"
if
mode
==
"fast"
else
mode
ref_res
=
np
.
take
(
x_data
,
indices
=
indices_src
,
axis
=
axis
,
mode
=
np_mode
)
for
target
,
ctx
in
ctx_list
():
for
kind
in
[
"graph"
,
"debug"
]:
...
...
@@ -291,6 +292,9 @@ def test_take():
verify_take
((
3
,
4
),
[
-
1
,
2
],
axis
=
0
,
mode
=
"wrap"
)
verify_take
((
3
,
4
),
[
-
1
,
2
],
axis
=
1
)
verify_take
((
3
,
4
),
[
-
1
,
2
],
axis
=
1
,
mode
=
"wrap"
)
verify_take
((
3
,
3
,
3
),
[[
11
,
25
]],
mode
=
"fast"
)
verify_take
((
3
,
4
),
[
0
,
2
],
axis
=
0
,
mode
=
"fast"
)
verify_take
((
3
,
4
),
[
0
,
2
],
axis
=
1
,
mode
=
"fast"
)
def
test_split_infer_type
():
...
...
topi/include/topi/transform.h
View file @
2c41fd2f
...
...
@@ -641,6 +641,13 @@ inline Tensor take(const Tensor& a,
auto
idx
=
tvm
::
min
(
tvm
::
max
(
0
,
indices
(
out_index
)),
a_size
-
1
);
return
a
(
UnravelIndex
(
idx
,
a_shape
));
},
name
,
tag
);
}
else
if
(
mode
==
"fast"
)
{
LOG
(
WARNING
)
<<
"Fast mode segfaults when there are out-of-bounds indices. "
"Make sure input indices are in bound"
;
return
compute
(
out_shape
,
[
&
](
const
Array
<
Var
>&
out_index
)
{
return
a
(
UnravelIndex
(
indices
(
out_index
),
a_shape
));
},
name
,
tag
);
}
else
{
// mode == "wrap"
return
compute
(
out_shape
,
[
&
](
const
Array
<
Var
>&
out_index
)
{
...
...
@@ -706,6 +713,25 @@ inline Tensor take(const Tensor& a,
}
return
a
(
real_indices
);
},
name
,
tag
);
}
else
if
(
mode
==
"fast"
)
{
LOG
(
WARNING
)
<<
"Fast mode segfaults when there are out-of-bounds indices. "
"Make sure input indices are in bound"
;
return
compute
(
out_shape
,
[
&
](
const
Array
<
Var
>&
out_index
)
{
Array
<
Expr
>
indices_position
;
for
(
size_t
j
=
axis
;
j
<
static_cast
<
size_t
>
(
axis
+
indices_len
);
++
j
)
{
indices_position
.
push_back
(
out_index
[
j
]);
}
Array
<
Expr
>
real_indices
;
for
(
size_t
j
=
0
;
j
<
static_cast
<
size_t
>
(
axis
);
++
j
)
{
real_indices
.
push_back
(
out_index
[
j
]);
}
real_indices
.
push_back
(
indices
(
indices_position
));
for
(
size_t
j
=
axis
+
indices_len
;
j
<
out_index
.
size
();
++
j
)
{
real_indices
.
push_back
(
out_index
[
j
]);
}
return
a
(
real_indices
);
},
name
,
tag
);
}
else
{
// mode == "wrap"
return
compute
(
out_shape
,
[
&
](
const
Array
<
Var
>&
out_index
)
{
...
...
topi/python/topi/transform.py
View file @
2c41fd2f
...
...
@@ -265,6 +265,7 @@ def take(a, indices, axis=None, mode="clip"):
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
fast - no clip or wrap around (user must make sure indices are in-bound)
Returns
-------
...
...
topi/tests/python/test_topi_transform.py
View file @
2c41fd2f
...
...
@@ -275,9 +275,11 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"):
data_npy
=
np
.
arange
(
shape_size
,
dtype
=
src_dtype
)
.
reshape
((
src_shape
))
if
axis
is
None
:
out_npys
=
np
.
take
(
data_npy
,
indices_src
,
mode
=
mode
)
np_mode
=
"raise"
if
mode
==
"fast"
else
mode
out_npys
=
np
.
take
(
data_npy
,
indices_src
,
mode
=
np_mode
)
else
:
out_npys
=
np
.
take
(
data_npy
,
indices_src
,
axis
=
axis
,
mode
=
mode
)
np_mode
=
"raise"
if
mode
==
"fast"
else
mode
out_npys
=
np
.
take
(
data_npy
,
indices_src
,
axis
=
axis
,
mode
=
np_mode
)
data_nd
=
tvm
.
nd
.
array
(
data_npy
,
ctx
)
indices_nd
=
tvm
.
nd
.
array
(
indices_src
,
ctx
)
out_nd
=
tvm
.
nd
.
empty
(
out_npys
.
shape
,
ctx
=
ctx
,
dtype
=
src_dtype
)
...
...
@@ -521,6 +523,9 @@ def test_take():
verify_take
((
3
,
4
),
[
-
1
,
2
],
axis
=
0
,
mode
=
"wrap"
)
verify_take
((
3
,
4
),
[
-
1
,
2
],
axis
=
1
)
verify_take
((
3
,
4
),
[
-
1
,
2
],
axis
=
1
,
mode
=
"wrap"
)
verify_take
((
3
,
3
,
3
),
[[
11
,
25
]],
mode
=
"fast"
)
verify_take
((
3
,
4
),
[
0
,
2
],
axis
=
0
,
mode
=
"fast"
)
verify_take
((
3
,
4
),
[
0
,
2
],
axis
=
1
,
mode
=
"fast"
)
def
test_gather_nd
():
for
indices_dtype
in
[
'int32'
,
'float32'
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment