Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3708b311
Commit
3708b311
authored
Dec 30, 2018
by
masahi
Committed by
Tianqi Chen
Dec 29, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Update cuda softmax schedule for spatial inputs (#2338)
parent
f6c3f997
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
44 deletions
+47
-44
topi/python/topi/cuda/softmax.py
+19
-15
topi/tests/python/test_topi_softmax.py
+28
-29
No files found.
topi/python/topi/cuda/softmax.py
View file @
3708b311
...
...
@@ -2,6 +2,7 @@
"""Schedule for softmax operator"""
import
tvm
from
..
import
generic
from
.injective
import
_schedule_injective
@generic.schedule_softmax.register
([
"cuda"
,
"gpu"
])
def
schedule_softmax
(
outs
):
...
...
@@ -24,21 +25,24 @@ def schedule_softmax(outs):
max_elem
=
softmax
.
op
.
input_tensors
[
1
]
expsum
=
softmax
.
op
.
input_tensors
[
2
]
num_thread
=
64
block_x
=
tvm
.
thread_axis
(
"blockIdx.x"
)
thread_x
=
tvm
.
thread_axis
((
0
,
num_thread
),
"threadIdx.x"
)
if
len
(
softmax
.
shape
)
>
2
:
for
op
in
[
max_elem
.
op
,
expsum
.
op
,
softmax
.
op
]:
s
=
_schedule_injective
(
op
,
s
)
else
:
num_thread
=
64
block_x
=
tvm
.
thread_axis
(
"blockIdx.x"
)
thread_x
=
tvm
.
thread_axis
((
0
,
num_thread
),
"threadIdx.x"
)
s
[
max_elem
]
.
bind
(
max_elem
.
op
.
axis
[
0
],
block_x
)
k
=
expsum
.
op
.
reduce_axis
[
0
]
ko
,
ki
=
s
[
expsum
]
.
split
(
k
,
factor
=
num_thread
)
EF
=
s
.
rfactor
(
expsum
,
ki
)
s
[
expsum
]
.
bind
(
s
[
expsum
]
.
op
.
axis
[
0
],
block_x
)
s
[
expsum
]
.
bind
(
s
[
expsum
]
.
op
.
reduce_axis
[
0
],
thread_x
)
s
[
EF
]
.
compute_at
(
s
[
expsum
],
s
[
expsum
]
.
op
.
reduce_axis
[
0
])
s
[
expsum
]
.
set_store_predicate
(
thread_x
.
var
.
equal
(
0
))
tx
,
xi
=
s
[
softmax
]
.
split
(
softmax
.
op
.
axis
[
1
],
nparts
=
num_thread
)
s
[
softmax
]
.
bind
(
softmax
.
op
.
axis
[
0
],
block_x
)
s
[
softmax
]
.
bind
(
tx
,
thread_x
)
s
[
max_elem
]
.
bind
(
max_elem
.
op
.
axis
[
0
],
block_x
)
k
=
expsum
.
op
.
reduce_axis
[
0
]
ko
,
ki
=
s
[
expsum
]
.
split
(
k
,
factor
=
num_thread
)
EF
=
s
.
rfactor
(
expsum
,
ki
)
s
[
expsum
]
.
bind
(
s
[
expsum
]
.
op
.
axis
[
0
],
block_x
)
s
[
expsum
]
.
bind
(
s
[
expsum
]
.
op
.
reduce_axis
[
0
],
thread_x
)
s
[
EF
]
.
compute_at
(
s
[
expsum
],
s
[
expsum
]
.
op
.
reduce_axis
[
0
])
s
[
expsum
]
.
set_store_predicate
(
thread_x
.
var
.
equal
(
0
))
tx
,
xi
=
s
[
softmax
]
.
split
(
softmax
.
op
.
axis
[
1
],
nparts
=
num_thread
)
s
[
softmax
]
.
bind
(
softmax
.
op
.
axis
[
0
],
block_x
)
s
[
softmax
]
.
bind
(
tx
,
thread_x
)
return
s
topi/tests/python/test_topi_softmax.py
View file @
3708b311
...
...
@@ -9,6 +9,21 @@ from topi.util import get_const_tuple
from
common
import
get_all_backend
def
check_device
(
A
,
B
,
a_np
,
b_np
,
device
,
name
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_softmax
(
B
)
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros
(
get_const_tuple
(
B
.
shape
),
dtype
=
B
.
dtype
),
ctx
)
f
=
tvm
.
build
(
s
,
[
A
,
B
],
device
,
name
=
"softmax"
)
f
(
a
,
b
)
tvm
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
b_np
,
rtol
=
1e-5
)
def
verify_softmax
(
m
,
n
,
dtype
=
"float32"
):
A
=
tvm
.
placeholder
((
m
,
n
),
dtype
=
dtype
,
name
=
'A'
)
B
=
topi
.
nn
.
softmax
(
A
)
...
...
@@ -19,28 +34,26 @@ def verify_softmax(m, n, dtype="float32"):
a_np
=
np
.
random
.
uniform
(
size
=
get_const_tuple
(
A
.
shape
))
.
astype
(
A
.
dtype
)
b_np
=
topi
.
testing
.
softmax_python
(
a_np
)
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_softmax
(
B
)
for
device
in
[
'cuda'
,
'opencl'
,
'metal'
,
'rocm'
,
'vulkan'
,
'nvptx'
]:
check_device
(
A
,
B
,
a_np
,
b_np
,
device
,
"softmax"
)
def
verify_softmax_4d
(
shape
,
dtype
=
"float32"
):
A
=
tvm
.
placeholder
(
shape
,
dtype
=
dtype
,
name
=
'A'
)
B
=
topi
.
nn
.
softmax
(
A
,
axis
=
1
)
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros
(
get_const_tuple
(
B
.
shape
),
dtype
=
B
.
dtype
),
ctx
)
foo
=
tvm
.
build
(
s
,
[
A
,
B
],
device
,
name
=
"softmax"
)
foo
(
a
,
b
)
tvm
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
b_np
,
rtol
=
1e-5
)
_
,
c
,
h
,
w
=
shape
a_np
=
np
.
random
.
uniform
(
size
=
get_const_tuple
(
A
.
shape
))
.
astype
(
A
.
dtype
)
b_np
=
topi
.
testing
.
softmax_python
(
a_np
.
transpose
(
0
,
2
,
3
,
1
)
.
reshape
(
h
*
w
,
c
))
b_np
=
b_np
.
reshape
(
1
,
h
,
w
,
c
)
.
transpose
(
0
,
3
,
1
,
2
)
for
device
in
[
'cuda'
,
'opencl'
,
'metal'
,
'rocm'
,
'vulkan'
,
'nvptx'
]:
check_device
(
device
)
check_device
(
A
,
B
,
a_np
,
b_np
,
device
,
"softmax"
)
def
test_softmax
():
verify_softmax
(
32
,
10
)
verify_softmax
(
3
,
4
)
verify_softmax
(
32
,
10
,
"float64"
)
verify_softmax_4d
((
1
,
16
,
256
,
256
))
def
verify_log_softmax
(
m
,
n
,
dtype
=
"float32"
):
A
=
tvm
.
placeholder
((
m
,
n
),
dtype
=
dtype
,
name
=
'A'
)
...
...
@@ -51,22 +64,8 @@ def verify_log_softmax(m, n, dtype="float32"):
a_np
=
np
.
random
.
uniform
(
size
=
get_const_tuple
(
A
.
shape
))
.
astype
(
A
.
dtype
)
b_np
=
topi
.
testing
.
log_softmax_python
(
a_np
)
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_softmax
(
B
)
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros
(
get_const_tuple
(
B
.
shape
),
dtype
=
B
.
dtype
),
ctx
)
foo
=
tvm
.
build
(
s
,
[
A
,
B
],
device
,
name
=
"log_softmax"
)
foo
(
a
,
b
)
tvm
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
b_np
,
rtol
=
1e-5
)
for
device
in
get_all_backend
():
check_device
(
device
)
check_device
(
A
,
B
,
a_np
,
b_np
,
device
,
"log_softmax"
)
def
test_log_softmax
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment