Commit 4468c576, authored Oct 06, 2017 by Xingjian Shi; committed by Tianqi Chen, Oct 06, 2017
[TOPI] add argmax, argmin (#515)
* add argmax argmin
* remove coder saver
parent 46657ed1

Showing 4 changed files with 215 additions and 43 deletions:
  topi/python/topi/cuda/reduction.py     +36  -12
  topi/python/topi/reduction.py          +129 -26
  topi/python/topi/tag.py                +1   -0
  topi/tests/python/test_topi_reduce.py  +49  -5
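In short, the commit routes argmax/argmin through the existing generic comm_reduce machinery by carrying an (index, value) pair through the reduction, and teaches the CUDA reduce schedule to handle that two-output form. A minimal end-to-end sketch of the new ops (my illustration, mirroring the new test below; it assumes a TVM build of this vintage with the CUDA backend enabled):

    import numpy as np
    import tvm
    import topi

    # argmax over axis 1 of a (32, 128) float32 tensor
    A = tvm.placeholder(shape=(32, 128), name="A", dtype="float32")
    B = topi.argmax(A, axis=1, keepdims=False)   # new op added by this commit
    s = topi.cuda.schedule_reduce(B)             # schedule extended for comm_reduce_idx

    foo = tvm.build(s, [A, B], "cuda", name="argmax")
    ctx = tvm.gpu(0)
    a_np = np.random.uniform(size=(32, 128)).astype(np.float32)
    a = tvm.nd.array(a_np, ctx=ctx)
    b = tvm.nd.empty(shape=(32,), ctx=ctx, dtype="int32")   # argmax yields int32 indices
    foo(a, b)
    np.testing.assert_allclose(b.asnumpy(), a_np.argmax(axis=1))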
topi/python/topi/cuda/reduction.py

@@ -4,14 +4,17 @@ from __future__ import absolute_import as _abs
 import tvm
 from .. import tag
 
-def _schedule_reduce(op, sch):
-    data_in = op.input_tensors[0]
-    data_out = op.output(0)
+def _schedule_reduce(op, sch, is_idx_reduce=False):
+    if is_idx_reduce:
+        data_out = op.input_tensors[0]
+    else:
+        data_in = op.input_tensors[0]
+        data_out = op.output(0)
     assert len(sch[data_out].op.reduce_axis) > 0, "reduce_axis must be bigger than zero!"
     if len(sch[data_out].op.axis) > 0:
         all_reduce = False
-        num_thread = 16
+        num_thread = 32
         block_x = tvm.thread_axis("blockIdx.x")
         thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
         thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
@@ -24,21 +27,38 @@ def _schedule_reduce(op, sch):
     fused_reduce = sch[data_out].fuse(*[sch[data_out].op.reduce_axis[i]
                                         for i in range(len(sch[data_out].op.reduce_axis))])
     ko, ki = sch[data_out].split(fused_reduce, factor=num_thread)
-    data_out_rf = sch.rfactor(data_out, ki)
-    sch[data_out_rf].compute_at(sch[data_out], sch[data_out].op.reduce_axis[0])
+    if is_idx_reduce:
+        data_out_rf, _ = sch.rfactor(data_out, ki)
+    else:
+        data_out_rf = sch.rfactor(data_out, ki)
+    tx = sch[data_out].op.reduce_axis[0]
+    sch[data_out].bind(tx, thread_x)
+    sch[data_out_rf].compute_at(sch[data_out], tx)
+    if is_idx_reduce:
+        real_output = op.output(0)
+        temp_idx_input = data_out.op.output(0)
+        temp_val_input = data_out.op.output(1)
+    else:
+        real_output = data_out
     if not all_reduce:
         # Fuse and split the axis
-        fused_outer = sch[data_out].fuse(*[sch[data_out].op.axis[i]
-                                           for i in range(len(sch[data_out].op.axis))])
-        bx, outer_in = sch[data_out].split(fused_outer, factor=num_thread)
+        fused_outer = sch[real_output].fuse(*[sch[real_output].op.axis[i]
+                                              for i in range(len(sch[real_output].op.axis))])
+        bx, outer_in = sch[real_output].split(fused_outer, factor=num_thread)
 
         # Bind the axes to threads and blocks
-        sch[data_out].bind(sch[data_out].op.reduce_axis[0], thread_x)
-        sch[data_out].set_store_predicate(thread_x.equal(0))
-        sch[data_out].bind(outer_in, thread_y)
-        sch[data_out].bind(bx, block_x)
+        sch[real_output].bind(outer_in, thread_y)
+        sch[real_output].bind(bx, block_x)
+        if is_idx_reduce:
+            sch[temp_idx_input].compute_at(sch[real_output], outer_in)
+            sch[temp_val_input].compute_at(sch[real_output], outer_in)
     else:
-        sch[data_out].bind(sch[data_out].op.reduce_axis[0], thread_x)
+        if is_idx_reduce:
+            sch[temp_idx_input].compute_at(sch[real_output],
+                                           sch[real_output].op.axis[0])
+            sch[temp_val_input].compute_at(sch[real_output],
+                                           sch[real_output].op.axis[0])
+    sch[real_output].set_store_predicate(thread_x.equal(0))
     return sch
@@ -73,9 +93,13 @@ def schedule_reduce(outs):
         if tag.is_broadcast(operator.tag):
             raise RuntimeError("Not yet support ewise after reduce")
         elif operator.tag == 'comm_reduce':
-            _schedule_reduce(operator, sch)
+            _schedule_reduce(operator, sch, is_idx_reduce=False)
             for tensor in operator.input_tensors:
                 traverse_before_reduce(tensor.op)
+        elif operator.tag == 'comm_reduce_idx':
+            _schedule_reduce(operator, sch, is_idx_reduce=True)
+            for tensor in operator.input_tensors[0].op.input_tensors:
+                traverse_before_reduce(tensor.op)
         else:
             raise RuntimeError("Unsupported operator: %s" % operator.tag)
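The point of the schedule changes: an index reduction tagged 'comm_reduce_idx' consists of two coupled stages, a two-output temporary reduction carrying the running (index, value) pair and a final stage that selects only the index. _schedule_reduce therefore operates on the temporary (op.input_tensors[0]), rfactors it across threadIdx.x, and computes both temporary outputs at the real output stage. One way to inspect the result (a sketch, not part of the commit; it assumes tvm.lower with simple_mode, which this TVM version provides):

    import tvm
    import topi

    A = tvm.placeholder((32, 128), name="A", dtype="float32")
    B = topi.argmin(A, axis=1)
    s = topi.cuda.schedule_reduce(B)
    # Print the lowered IR to check the thread/block bindings and the store predicate.
    print(tvm.lower(s, [A, B], simple_mode=True))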
topi/python/topi/reduction.py

-# pylint: disable=redefined-builtin,consider-using-enumerate
+# pylint: disable=redefined-builtin,consider-using-enumerate,no-member
 """Reduce operators"""
 from __future__ import absolute_import as _abs
 import tvm
 from . import tag
+from .util import ravel_index
 
 def _get_real_axis(ndim, axis):
     if axis is None:
@@ -26,6 +27,20 @@ def _get_real_axis(ndim, axis):
 def get_reduce_out_shape(src_shape, axis=None, keepdims=False):
+    """Get the output shape for the reduction OPs
+
+    Parameters
+    ----------
+    src_shape : tuple of int or tvm.expr.IntImm
+
+    axis : None or int or tuple of int
+
+    keepdims : bool
+
+    Returns
+    -------
+    dst_shape : tuple of int or tvm.expr.IntImm
+    """
     real_axis = _get_real_axis(len(src_shape), axis)
     if keepdims:
         dst_shape = [src_shape[i] if i in real_axis else 1 for i in range(len(src_shape))]
@@ -37,8 +52,36 @@ def get_reduce_out_shape(src_shape, axis=None, keepdims=False):
     return dst_shape
 
 
-@tvm.tag_scope(tag=tag.COMM_REDUCE)
-def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum):
+def _argmax_comp(lhs, rhs):
+    """Compare function of argmax"""
+    idx = tvm.make.Select((lhs[1] >= rhs[1]), lhs[0], rhs[0])
+    val = tvm.make.Select((lhs[1] >= rhs[1]), lhs[1], rhs[1])
+    return idx, val
+
+
+def _argmax_init(idx_typ, val_typ):
+    """Initial ind and val of argmax"""
+    return tvm.const(-1, idx_typ), tvm.min_value(val_typ)
+
+
+def _argmin_comp(lhs, rhs):
+    """Compare function of argmin"""
+    idx = tvm.make.Select((lhs[1] <= rhs[1]), lhs[0], rhs[0])
+    val = tvm.make.Select((lhs[1] <= rhs[1]), lhs[1], rhs[1])
+    return idx, val
+
+
+def _argmin_init(idx_typ, val_typ):
+    """Initial ind and val of argmax"""
+    return tvm.const(-1, idx_typ), tvm.max_value(val_typ)
+
+
+def _choose_idx(idx, _, *indices):
+    """Chose the idx from idx and val"""
+    return idx(*indices)
+
+
+def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum, is_idx_reduce=False):
     """Reducing the data
 
     Parameters
@@ -63,9 +106,22 @@ def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum):
     -------
     ret : tvm.Tensor
     """
-    def _build_reduce_compute_func(data, real_axis, reduce_axes, keepdims,
-                                   func, *args):
+    ndim = len(data.shape)
+    real_axis = _get_real_axis(ndim, axis)
+    if real_axis == list(range(ndim)) and keepdims is False:
+        raise ValueError("Currently we do not support all reduce + keepdims = False!"
+                         " axis={}, keepdims={}".format(axis, keepdims))
+    reduce_axes = [tvm.reduce_axis((0, data.shape[i]), "k%d" % i) for i in real_axis]
+    if keepdims:
+        target_shape = [1 if i in real_axis else data.shape[i] for i in range(ndim)]
+    else:
+        target_shape = []
+        for i in range(ndim):
+            if i not in real_axis:
+                target_shape.append(tvm.convert(data.shape[i]))
+    def _compute(*indices):
         eval_range = []
+        eval_indices = []
         if not keepdims:
             arg_counter = 0
         else:
@@ -74,38 +130,29 @@ def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum):
         for i in range(len(data.shape)):
             if i in real_axis:
                 eval_range.append(reduce_axes[red_counter])
+                eval_indices.append(reduce_axes[red_counter].var)
                 red_counter += 1
             else:
                 if not keepdims:
-                    eval_range.append(args[arg_counter])
+                    eval_range.append(indices[arg_counter])
                     arg_counter += 1
                 else:
-                    eval_range.append(args[i])
-        return func(data[tuple(eval_range)], axis=reduce_axes)
-
-    ndim = len(data.shape)
-    real_axis = _get_real_axis(ndim, axis)
-    if real_axis == list(range(ndim)) and keepdims is False:
-        raise ValueError("Currently we do not support all reduce + keepdims = False!"
-                         " axis={}, keepdims={}".format(axis, keepdims))
-    reduce_axes = [tvm.reduce_axis((0, data.shape[i]), "k%d" % i) for i in real_axis]
-    if keepdims:
-        target_shape = [tvm.convert(1) if i in real_axis else tvm.convert(data.shape[i])
-                        for i in range(ndim)]
-    else:
-        target_shape = []
-        for i in range(ndim):
-            if i not in real_axis:
-                target_shape.append(tvm.convert(data.shape[i]))
-    out = tvm.compute(target_shape,
-                      lambda *args: _build_reduce_compute_func(data,
-                                                               real_axis, reduce_axes, keepdims, func, *args),
-                      name=data.name + "_red")
+                    eval_range.append(indices[i])
+        if not is_idx_reduce:
+            return func(data[tuple(eval_range)], axis=reduce_axes)
+        idx = ravel_index(eval_indices, [data.shape[i] for i in real_axis])
+        return func((idx, data[tuple(eval_range)]), axis=reduce_axes)
+
+    if is_idx_reduce:
+        temp_idx, temp_val = tvm.compute(target_shape, _compute, name=data.name + "_red_temp")
+        out = tvm.compute(target_shape,
+                          lambda *indices: _choose_idx(temp_idx, temp_val, *indices),
+                          name=data.name + "_red")
+    else:
+        out = tvm.compute(target_shape, _compute, name=data.name + "_red")
     return out
 
 
+@tvm.tag_scope(tag=tag.COMM_REDUCE)
 def sum(data, axis=None, keepdims=False):
     """Sum of array elements over a given axis or a list of axes
@@ -131,6 +178,7 @@ def sum(data, axis=None, keepdims=False):
     return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.sum)
 
 
+@tvm.tag_scope(tag=tag.COMM_REDUCE)
 def max(data, axis=None, keepdims=False):
     """Maximum of array elements over a given axis or a list of axes
@@ -156,6 +204,7 @@ def max(data, axis=None, keepdims=False):
     return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.max)
 
 
+@tvm.tag_scope(tag=tag.COMM_REDUCE)
 def min(data, axis=None, keepdims=False):
     """Minimum of array elements over a given axis or a list of axes
@@ -179,3 +228,57 @@ def min(data, axis=None, keepdims=False):
     ret : tvm.Tensor
     """
     return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.min)
+
+
+@tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
+def argmax(data, axis=None, keepdims=False):
+    """Returns the indices of the maximum values along an axis.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        The input tvm tensor
+
+    axis : None or int or tuple of int
+        Axis or axes along which a sum is performed.
+        The default, axis=None, will sum all of the elements of the input array.
+        If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input array.
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    _argmax = tvm.comm_reducer(fcombine=_argmax_comp, fidentity=_argmax_init, name='argmax')
+    return comm_reduce(data, axis=axis, keepdims=keepdims, func=_argmax, is_idx_reduce=True)
+
+
+@tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
+def argmin(data, axis=None, keepdims=False):
+    """Returns the indices of the minimum values along an axis.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        The input tvm tensor
+
+    axis : None or int or tuple of int
+        Axis or axes along which a sum is performed.
+        The default, axis=None, will sum all of the elements of the input array.
+        If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input array.
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    _argmin = tvm.comm_reducer(fcombine=_argmin_comp, fidentity=_argmin_init, name='argmin')
+    return comm_reduce(data, axis=axis, keepdims=keepdims, func=_argmin, is_idx_reduce=True)
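argmax and argmin above are thin wrappers around tvm.comm_reducer with a two-element state: slot 0 carries the candidate index and slot 1 the candidate value, which is what _argmax_comp/_argmax_init operate on, and _choose_idx finally keeps only slot 0. A standalone sketch of that reducer pattern (my illustration, closely following the comm_reducer documentation example for this TVM version):

    import tvm

    def fcombine(lhs, rhs):
        idx = tvm.make.Select(lhs[1] >= rhs[1], lhs[0], rhs[0])
        val = tvm.make.Select(lhs[1] >= rhs[1], lhs[1], rhs[1])
        return idx, val

    def fidentity(idx_typ, val_typ):
        return tvm.const(-1, idx_typ), tvm.min_value(val_typ)

    argmax_reducer = tvm.comm_reducer(fcombine=fcombine, fidentity=fidentity, name="argmax")

    m, n = tvm.var("m"), tvm.var("n")
    idx = tvm.placeholder((m, n), name="idx", dtype="int32")
    val = tvm.placeholder((m, n), name="val", dtype="float32")
    k = tvm.reduce_axis((0, n), "k")
    # A multi-value reducer yields one tensor per state slot:
    # T0 holds the argmax index, T1 the corresponding maximum value.
    T0, T1 = tvm.compute((m,), lambda i: argmax_reducer((idx[i, k], val[i, k]), axis=k), name="T")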
topi/python/topi/tag.py

@@ -31,6 +31,7 @@ ELEMWISE = "elemwise"
 BROADCAST = "broadcast"
 INJECTIVE = "injective"
 COMM_REDUCE = "comm_reduce"
+COMM_REDUCE_IDX = "comm_reduce_idx"
 
 def is_broadcast(tag):
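The new COMM_REDUCE_IDX tag is the hook the CUDA schedule_reduce dispatcher (first file above) uses to recognize argmax/argmin outputs. A quick sanity check of the tag on a freshly built op (a sketch; assumes topi is importable as in the tests):

    import tvm
    import topi

    A = tvm.placeholder((4, 8), name="A", dtype="float32")
    B = topi.argmax(A, axis=1)
    print(B.op.tag)   # expected: "comm_reduce_idx"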
topi/tests/python/test_topi_reduce.py

@@ -4,20 +4,48 @@ import numpy as np
 import tvm
 import topi
 
+def _my_npy_argmax(arr, axis, keepdims):
+    if not keepdims:
+        return arr.argmax(axis=axis)
+    else:
+        if axis is not None:
+            out_shape = list(arr.shape)
+            out_shape[axis] = 1
+        else:
+            out_shape = [1 for _ in range(len(arr.shape))]
+        return arr.argmax(axis=axis).reshape(out_shape)
+
+
+def _my_npy_argmin(arr, axis, keepdims):
+    if not keepdims:
+        return arr.argmin(axis=axis)
+    else:
+        out_shape = list(arr.shape)
+        out_shape[axis] = 1
+        return arr.argmin(axis=axis).reshape(out_shape)
+
 def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
     # Build the logic and compile the function
-    A = tvm.placeholder(shape=in_shape, name="A")
+    dat_dtype = "float32"
+    A = tvm.placeholder(shape=in_shape, name="A", dtype=dat_dtype)
     A1 = topi.sqrt(topi.exp(A))
+    out_dtype = "float32"
     if type == "sum":
         B = topi.sum(A1, axis=axis, keepdims=keepdims)
     elif type == "max":
         B = topi.max(A1, axis=axis, keepdims=keepdims)
     elif type == "min":
         B = topi.min(A1, axis=axis, keepdims=keepdims)
+    elif type == "argmax":
+        B = topi.argmax(A1, axis=axis, keepdims=keepdims)
+        out_dtype = "int32"
+    elif type == "argmin":
+        B = topi.argmin(A1, axis=axis, keepdims=keepdims)
+        out_dtype = "int32"
     else:
         raise NotImplementedError
     s = topi.cuda.schedule_reduce(B)
 
     def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
@@ -26,18 +54,21 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         foo = tvm.build(s, [A, B], device, name="sum")
         # Test
         in_npy = np.random.uniform(size=in_shape).astype(np.float32)
-        in_npy_map = np.sqrt(np.exp(in_npy))
+        in_npy_map = np.sqrt(np.exp(in_npy)).astype(np.float32)
         if type == "sum":
             out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
         elif type == "max":
             out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
         elif type == "min":
             out_npy = in_npy_map.min(axis=axis, keepdims=keepdims)
+        elif type == "argmax":
+            out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims)
+        elif type == "argmin":
+            out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
         else:
             raise NotImplementedError
         data_tvm = tvm.nd.array(in_npy, ctx=ctx)
-        out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx)
+        out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
         for _ in range(1):
             foo(data_tvm, out_tvm)
         np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
@@ -64,6 +95,19 @@ def test_reduce_map():
                           axis=(0, 2),
                           keepdims=False,
                           type="min")
+    verify_reduce_map_ele(in_shape=(32, 128),
+                          axis=1,
+                          keepdims=True,
+                          type="argmax")
+    verify_reduce_map_ele(in_shape=(32, 24, 32, 24),
+                          axis=2,
+                          keepdims=False,
+                          type="argmin")
+    verify_reduce_map_ele(in_shape=(31, 21, 15),
+                          axis=None,
+                          keepdims=True,
+                          type="argmax")
 
 if __name__ == "__main__":
     test_reduce_map()
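The _my_npy_argmax/_my_npy_argmin helpers exist because NumPy's argmax/argmin had no keepdims argument at the time, so the reference result has to be reshaped by hand. For illustration (plain NumPy):

    import numpy as np

    arr = np.random.uniform(size=(32, 128)).astype(np.float32)
    ref = arr.argmax(axis=1).reshape((32, 1))   # what _my_npy_argmax(arr, 1, keepdims=True) returns
    print(ref.shape)                            # (32, 1)

With a GPU-enabled build, the new argmax/argmin cases run as part of test_reduce_map, e.g. via `python topi/tests/python/test_topi_reduce.py`.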