Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
165aa0db
Commit
165aa0db
authored
Jun 05, 2019
by
hlu1
Committed by
Tianqi Chen
Jun 05, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fast tanh (#3255)
parent
29b0b4c1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
90 additions
and
14 deletions
+90
-14
topi/include/topi/elemwise.h
+71
-3
topi/tests/python/test_topi_math.py
+19
-11
No files found.
topi/include/topi/elemwise.h
View file @
165aa0db
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -31,6 +31,7 @@
...
@@ -31,6 +31,7 @@
#include "tvm/tvm.h"
#include "tvm/tvm.h"
#include "tvm/ir.h"
#include "tvm/ir.h"
#include "tvm/ir_pass.h"
#include "tvm/ir_pass.h"
#include "broadcast.h"
namespace
topi
{
namespace
topi
{
using
namespace
tvm
;
using
namespace
tvm
;
...
@@ -46,7 +47,6 @@ using namespace tvm;
...
@@ -46,7 +47,6 @@ using namespace tvm;
}
}
TOPI_DECLARE_UNARY_OP
(
exp
);
TOPI_DECLARE_UNARY_OP
(
exp
);
TOPI_DECLARE_UNARY_OP
(
tanh
);
TOPI_DECLARE_UNARY_OP
(
sigmoid
);
TOPI_DECLARE_UNARY_OP
(
sigmoid
);
TOPI_DECLARE_UNARY_OP
(
sqrt
);
TOPI_DECLARE_UNARY_OP
(
sqrt
);
TOPI_DECLARE_UNARY_OP
(
log
);
TOPI_DECLARE_UNARY_OP
(
log
);
...
@@ -56,6 +56,74 @@ TOPI_DECLARE_UNARY_OP(round);
...
@@ -56,6 +56,74 @@ TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP
(
trunc
);
TOPI_DECLARE_UNARY_OP
(
trunc
);
TOPI_DECLARE_UNARY_OP
(
abs
);
TOPI_DECLARE_UNARY_OP
(
abs
);
/*
* \brief Fast_tanh_float implementation from Eigen
* https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26
*/
inline
Tensor
fast_tanh_float
(
const
Tensor
&
in
,
std
::
string
name
,
std
::
string
tag
)
{
// Clamp the inputs to the range [-9, 9] since anything outside
// this range is +/-1.0f in single-precision.
auto
x
=
maximum
(
minimum
(
in
,
make_const
(
in
->
dtype
,
9
.
0
)),
make_const
(
in
->
dtype
,
-
9
.
0
));
// The monomial coefficients of the numerator polynomial (odd).
auto
alpha_1
=
make_const
(
in
->
dtype
,
4.89352455891786e-03
);
auto
alpha_3
=
make_const
(
in
->
dtype
,
6.37261928875436e-04
);
auto
alpha_5
=
make_const
(
in
->
dtype
,
1.48572235717979e-05
);
auto
alpha_7
=
make_const
(
in
->
dtype
,
5.12229709037114e-08
);
auto
alpha_9
=
make_const
(
in
->
dtype
,
-
8.60467152213735e-11
);
auto
alpha_11
=
make_const
(
in
->
dtype
,
2.00018790482477e-13
);
auto
alpha_13
=
make_const
(
in
->
dtype
,
-
2.76076847742355e-16
);
// The monomial coefficients of the denominator polynomial (even).
auto
beta_0
=
make_const
(
in
->
dtype
,
4.89352518554385e-03
);
auto
beta_2
=
make_const
(
in
->
dtype
,
2.26843463243900e-03
);
auto
beta_4
=
make_const
(
in
->
dtype
,
1.18534705686654e-04
);
auto
beta_6
=
make_const
(
in
->
dtype
,
1.19825839466702e-06
);
return
compute
(
x
->
shape
,
[
&
](
const
Array
<
Var
>&
i
)
{
auto
x2
=
x
(
i
)
*
x
(
i
);
auto
p
=
x2
*
alpha_13
+
alpha_11
;
p
=
x2
*
p
+
alpha_9
;
p
=
x2
*
p
+
alpha_7
;
p
=
x2
*
p
+
alpha_5
;
p
=
x2
*
p
+
alpha_3
;
p
=
x2
*
p
+
alpha_1
;
p
=
x
(
i
)
*
p
;
auto
q
=
x2
*
beta_6
+
beta_4
;
q
=
x2
*
q
+
beta_2
;
q
=
x2
*
q
+
beta_0
;
return
p
/
q
;
},
name
,
tag
);
}
/*!
* \brief Creates an operation that returns hyperbolic tanh of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is tanh
*/
inline
Tensor
tanh
(
const
Tensor
&
x
,
std
::
string
name
=
"T_tanh"
,
std
::
string
tag
=
kElementWise
)
{
if
(
x
->
dtype
==
Float
(
32
))
{
// invoke fast_tanh_float implementation
return
fast_tanh_float
(
x
,
name
,
tag
);
}
else
{
// fallback to default implementation
return
compute
(
x
->
shape
,
[
&
](
const
Array
<
Var
>&
i
)
{
return
::
tvm
::
tanh
(
x
(
i
));
},
name
,
tag
);
}
}
/*!
/*!
* \brief Creates an operation that returns identity of a given tensor
* \brief Creates an operation that returns identity of a given tensor
*
*
...
...
topi/tests/python/test_topi_math.py
View file @
165aa0db
...
@@ -29,13 +29,21 @@ def test_util():
...
@@ -29,13 +29,21 @@ def test_util():
def
test_ewise
():
def
test_ewise
():
m
=
tvm
.
var
(
'm'
)
def
test_apply
(
l
=
tvm
.
var
(
'l'
)
func
,
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
name
,
f_numpy
,
low
,
high
,
shape
=
(
20
,
3
),
dtype
=
tvm
.
float32
,
check_round
=
False
,
skip_name_check
=
False
,
):
m
=
tvm
.
var
(
"m"
)
l
=
tvm
.
var
(
"l"
)
A
=
tvm
.
placeholder
((
m
,
l
),
dtype
=
dtype
,
name
=
"A"
)
shape
=
(
20
,
3
)
def
test_apply
(
func
,
name
,
f_numpy
,
low
,
high
,
check_round
=
False
,
skip_name_check
=
False
):
B
=
func
(
A
)
B
=
func
(
A
)
assert
tuple
(
B
.
shape
)
==
tuple
(
A
.
shape
)
assert
tuple
(
B
.
shape
)
==
tuple
(
A
.
shape
)
if
not
skip_name_check
:
if
not
skip_name_check
:
...
@@ -63,7 +71,6 @@ def test_ewise():
...
@@ -63,7 +71,6 @@ def test_ewise():
for
device
in
get_all_backend
():
for
device
in
get_all_backend
():
check_device
(
device
)
check_device
(
device
)
test_apply
(
topi
.
floor
,
"floor"
,
np
.
floor
,
-
100
,
100
)
test_apply
(
topi
.
floor
,
"floor"
,
np
.
floor
,
-
100
,
100
)
test_apply
(
topi
.
ceil
,
"ceil"
,
np
.
ceil
,
-
100
,
100
)
test_apply
(
topi
.
ceil
,
"ceil"
,
np
.
ceil
,
-
100
,
100
)
test_apply
(
topi
.
sign
,
"sign"
,
np
.
sign
,
-
100
,
100
,
skip_name_check
=
True
)
test_apply
(
topi
.
sign
,
"sign"
,
np
.
sign
,
-
100
,
100
,
skip_name_check
=
True
)
...
@@ -71,11 +78,12 @@ def test_ewise():
...
@@ -71,11 +78,12 @@ def test_ewise():
test_apply
(
topi
.
abs
,
"fabs"
,
np
.
abs
,
-
100
,
100
)
test_apply
(
topi
.
abs
,
"fabs"
,
np
.
abs
,
-
100
,
100
)
test_apply
(
topi
.
round
,
"round"
,
np
.
round
,
-
100
,
100
,
check_round
=
True
)
test_apply
(
topi
.
round
,
"round"
,
np
.
round
,
-
100
,
100
,
check_round
=
True
)
test_apply
(
topi
.
exp
,
"exp"
,
np
.
exp
,
-
1
,
1
)
test_apply
(
topi
.
exp
,
"exp"
,
np
.
exp
,
-
1
,
1
)
test_apply
(
topi
.
tanh
,
"tanh"
,
np
.
tanh
,
-
10
,
10
)
test_apply
(
topi
.
tanh
,
"tanh"
,
np
.
tanh
,
-
10
,
10
,
shape
=
(
128
,
128
))
test_apply
(
topi
.
sigmoid
,
"sigmoid"
,
lambda
x
:
1
/
(
1
+
np
.
exp
(
-
x
)),
-
1
,
1
)
test_apply
(
topi
.
tanh
,
"tanh"
,
np
.
tanh
,
-
10
,
10
,
shape
=
(
128
,
128
),
dtype
=
"float64"
)
test_apply
(
topi
.
sigmoid
,
"sigmoid"
,
lambda
x
:
1
/
(
1
+
np
.
exp
(
-
x
)),
-
1
,
1
)
test_apply
(
topi
.
log
,
"log"
,
np
.
log
,
0
,
100
)
test_apply
(
topi
.
log
,
"log"
,
np
.
log
,
0
,
100
)
test_apply
(
topi
.
sqrt
,
"sqrt"
,
np
.
sqrt
,
0
,
100
)
test_apply
(
topi
.
sqrt
,
"sqrt"
,
np
.
sqrt
,
0
,
100
)
test_apply
(
topi
.
rsqrt
,
"rsqrt"
,
lambda
x
:
np
.
ones_like
(
x
)
/
np
.
sqrt
(
x
),
0
,
100
,
skip_name_check
=
True
)
test_apply
(
topi
.
rsqrt
,
"rsqrt"
,
lambda
x
:
np
.
ones_like
(
x
)
/
np
.
sqrt
(
x
),
0
,
100
,
skip_name_check
=
True
)
def
test_cast
():
def
test_cast
():
...
@@ -93,7 +101,7 @@ def test_cast():
...
@@ -93,7 +101,7 @@ def test_cast():
b_np
=
a_np
.
astype
(
to_dtype
)
b_np
=
a_np
.
astype
(
to_dtype
)
for
device
in
get_all_backend
():
for
device
in
get_all_backend
():
ctx
=
tvm
.
context
(
device
,
0
)
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
print
(
"Skip because
%
s is not enabled"
%
device
)
continue
continue
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment