Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
65f87264
Commit
65f87264
authored
Sep 24, 2017
by
Leyuan Wang
Committed by
Tianqi Chen
Sep 24, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
log_softmax added to topi (#483)
parent
489ec872
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
79 additions
and
3 deletions
+79
-3
topi/python/topi/nn/softmax.py
+26
-1
topi/python/topi/testing/__init__.py
+1
-1
topi/python/topi/testing/softmax_python.py
+21
-1
topi/tests/python/test_topi_softmax.py
+31
-0
No files found.
topi/python/topi/nn/softmax.py
View file @
65f87264
# pylint: disable=invalid-name
"""TVM operator softmax compute."""
"""TVM operator
for softmax and log_
softmax compute."""
from
__future__
import
absolute_import
import
tvm
...
...
@@ -26,3 +26,28 @@ def softmax(x):
(
m
,
),
lambda
i
:
tvm
.
sum
(
tvm
.
exp
(
x
[
i
,
k
]
-
max_elem
[
i
]),
axis
=
k
))
return
tvm
.
compute
(
x
.
shape
,
lambda
i
,
j
:
tvm
.
exp
(
x
[
i
,
j
]
-
max_elem
[
i
])
/
expsum
[
i
])
@tvm.tag_scope
(
tag
=
'log_softmax_output'
)
def
log_softmax
(
x
):
"""Perform log softmax activation on the data
Parameters
----------
data : tvm.Tensor
2-D input data
Returns
-------
output : tvm.Tensor
2-D output with same shape
"""
assert
len
(
x
.
shape
)
==
2
,
"only support 2-dim log softmax"
m
,
n
=
x
.
shape
k
=
tvm
.
reduce_axis
((
0
,
n
),
name
=
'k'
)
max_elem
=
tvm
.
compute
((
m
,
),
lambda
i
:
tvm
.
max
(
x
[
i
,
k
],
axis
=
k
))
k
=
tvm
.
reduce_axis
((
0
,
n
),
name
=
'k'
)
expsum
=
tvm
.
compute
(
(
m
,
),
lambda
i
:
tvm
.
sum
(
tvm
.
exp
(
x
[
i
,
k
]
-
max_elem
[
i
]),
axis
=
k
))
return
tvm
.
compute
(
x
.
shape
,
lambda
i
,
j
:
x
[
i
,
j
]
-
max_elem
[
i
]
-
tvm
.
log
(
expsum
[
i
]))
topi/python/topi/testing/__init__.py
View file @
65f87264
...
...
@@ -8,4 +8,4 @@ from .conv2d_hwcn_python import conv2d_hwcn_python
from
.conv2d_nchw_python
import
conv2d_nchw_python
from
.depthwise_conv2d_python
import
depthwise_conv2d_python_nchw
,
depthwise_conv2d_python_nhwc
from
.dilate_python
import
dilate_python
from
.softmax_python
import
softmax_python
from
.softmax_python
import
softmax_python
,
log_softmax_python
topi/python/topi/testing/softmax_python.py
View file @
65f87264
# pylint: disable=invalid-name, trailing-whitespace
"""Softmax operation in python"""
"""Softmax
and log_softmax
operation in python"""
import
numpy
as
np
def
softmax_python
(
a_np
):
...
...
@@ -21,3 +21,23 @@ def softmax_python(a_np):
expsum
=
np
.
sum
(
e
,
axis
=
1
)
out_np
=
e
/
expsum
[:,
None
]
return
out_np
def
log_softmax_python
(
a_np
):
"""Log_softmax operator.
Parameters
----------
a_np : numpy.ndarray
2-D input data
Returns
-------
output_np : numpy.ndarray
2-D output with same shape
"""
assert
len
(
a_np
.
shape
)
==
2
,
"only support 2-dim log_softmax"
max_elem
=
np
.
amax
(
a_np
,
axis
=
1
)
max_elem
=
max_elem
.
reshape
(
max_elem
.
shape
[
0
],
1
)
e
=
np
.
exp
(
a_np
-
max_elem
)
expsum
=
np
.
sum
(
e
,
axis
=
1
)
out_np
=
a_np
-
max_elem
-
np
.
log
(
expsum
[:,
None
])
return
out_np
topi/tests/python/test_topi_softmax.py
View file @
65f87264
...
...
@@ -36,5 +36,36 @@ def test_softmax():
verify_softmax
(
3
,
4
)
def
verify_log_softmax
(
m
,
n
):
A
=
tvm
.
placeholder
((
m
,
n
),
name
=
'A'
)
B
=
topi
.
nn
.
log_softmax
(
A
)
# confirm lower works
s
=
tvm
.
create_schedule
([
B
.
op
])
tvm
.
lower
(
s
,
[
A
,
B
],
simple_mode
=
True
)
s
=
topi
.
cuda
.
schedule_softmax
(
B
)
a_np
=
np
.
random
.
uniform
(
size
=
get_const_tuple
(
A
.
shape
))
.
astype
(
A
.
dtype
)
b_np
=
topi
.
testing
.
log_softmax_python
(
a_np
)
def
check_device
(
device
):
if
not
tvm
.
module
.
enabled
(
device
):
print
(
"Skip because
%
s is not enabled"
%
device
)
return
ctx
=
tvm
.
gpu
(
0
)
if
device
==
"cuda"
else
tvm
.
cl
(
0
)
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros
(
get_const_tuple
(
B
.
shape
),
dtype
=
B
.
dtype
),
ctx
)
foo
=
tvm
.
build
(
s
,
[
A
,
B
],
device
,
name
=
"log_softmax"
)
foo
(
a
,
b
)
np
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
b_np
,
rtol
=
1e-5
)
for
device
in
[
'cuda'
,
'opencl'
,
'metal'
]:
check_device
(
device
)
def
test_log_softmax
():
verify_log_softmax
(
32
,
10
)
verify_log_softmax
(
3
,
4
)
if
__name__
==
"__main__"
:
test_softmax
()
test_log_softmax
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment