Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f5b02fdb
Unverified
Commit
f5b02fdb
authored
Apr 07, 2020
by
Haichen Shen
Committed by
GitHub
Apr 07, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][OP] Add fast_erf implementation (#5241)
* add fast erf * doc * lint * fix * fix indent
parent
869b718a
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
124 additions
and
6 deletions
+124
-6
include/tvm/target/generic_func.h
+1
-1
python/tvm/relay/op/_tensor.py
+2
-0
src/relay/op/tensor/unary.cc
+11
-0
src/relay/transforms/fast_math.cc
+4
-0
src/relay/transforms/pattern_util.h
+5
-0
tests/python/relay/test_op_fast_math.py
+3
-0
topi/include/topi/elemwise.h
+72
-1
topi/python/topi/math.py
+16
-0
topi/src/elemwise.cc
+5
-0
topi/tests/python/test_topi_math.py
+5
-4
No files found.
include/tvm/target/generic_func.h
View file @
f5b02fdb
...
@@ -72,7 +72,7 @@ class GenericFunc : public ObjectRef {
...
@@ -72,7 +72,7 @@ class GenericFunc : public ObjectRef {
*
*
* \code
* \code
* // Example code on how to call generic function
* // Example code on how to call generic function
* void CallGene
ir
c(GenericFunc f) {
* void CallGene
ri
c(GenericFunc f) {
* // call like normal functions by pass in arguments
* // call like normal functions by pass in arguments
* // return value is automatically converted back
* // return value is automatically converted back
* int rvalue = f(1, 2.0);
* int rvalue = f(1, 2.0);
...
...
python/tvm/relay/op/_tensor.py
View file @
f5b02fdb
...
@@ -76,6 +76,7 @@ register_injective_schedule("shape_of")
...
@@ -76,6 +76,7 @@ register_injective_schedule("shape_of")
register_injective_schedule
(
"ndarray_size"
)
register_injective_schedule
(
"ndarray_size"
)
register_broadcast_schedule
(
"fast_exp"
)
register_broadcast_schedule
(
"fast_exp"
)
register_broadcast_schedule
(
"fast_tanh"
)
register_broadcast_schedule
(
"fast_tanh"
)
register_broadcast_schedule
(
"fast_erf"
)
# zeros
# zeros
...
@@ -222,3 +223,4 @@ register_shape_func("exp", False, elemwise_shape_func)
...
@@ -222,3 +223,4 @@ register_shape_func("exp", False, elemwise_shape_func)
register_shape_func
(
"tan"
,
False
,
elemwise_shape_func
)
register_shape_func
(
"tan"
,
False
,
elemwise_shape_func
)
register_shape_func
(
"fast_exp"
,
False
,
elemwise_shape_func
)
register_shape_func
(
"fast_exp"
,
False
,
elemwise_shape_func
)
register_shape_func
(
"fast_tanh"
,
False
,
elemwise_shape_func
)
register_shape_func
(
"fast_tanh"
,
False
,
elemwise_shape_func
)
register_shape_func
(
"fast_erf"
,
False
,
elemwise_shape_func
)
src/relay/op/tensor/unary.cc
View file @
f5b02fdb
...
@@ -128,6 +128,17 @@ RELAY_REGISTER_UNARY_OP("erf")
...
@@ -128,6 +128,17 @@ RELAY_REGISTER_UNARY_OP("erf")
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
erf
));
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
erf
));
RELAY_REGISTER_UNARY_OP
(
"fast_erf"
)
.
describe
(
R"code(Returns the error function value for input array, computed element-wise.
.. math::
\fast_erf(x)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
1
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
fast_erf
));
RELAY_REGISTER_UNARY_OP
(
"sqrt"
)
RELAY_REGISTER_UNARY_OP
(
"sqrt"
)
.
describe
(
R"code(Returns the sqrt input array, computed element-wise.
.
describe
(
R"code(Returns the sqrt input array, computed element-wise.
...
...
src/relay/transforms/fast_math.cc
View file @
f5b02fdb
...
@@ -35,11 +35,14 @@ class FastMathMutator : public ExprRewriter {
...
@@ -35,11 +35,14 @@ class FastMathMutator : public ExprRewriter {
public
:
public
:
FastMathMutator
()
FastMathMutator
()
:
exp_op_
(
Op
::
Get
(
"exp"
)),
:
exp_op_
(
Op
::
Get
(
"exp"
)),
erf_op_
(
Op
::
Get
(
"erf"
)),
tanh_op_
(
Op
::
Get
(
"tanh"
))
{}
tanh_op_
(
Op
::
Get
(
"tanh"
))
{}
Expr
Rewrite_
(
const
CallNode
*
pre
,
const
Expr
&
post
)
override
{
Expr
Rewrite_
(
const
CallNode
*
pre
,
const
Expr
&
post
)
override
{
if
(
pre
->
op
==
exp_op_
)
{
if
(
pre
->
op
==
exp_op_
)
{
return
FastExp
(
post
.
as
<
CallNode
>
()
->
args
[
0
]);
return
FastExp
(
post
.
as
<
CallNode
>
()
->
args
[
0
]);
}
else
if
(
pre
->
op
==
erf_op_
)
{
return
FastErf
(
post
.
as
<
CallNode
>
()
->
args
[
0
]);
}
else
if
(
pre
->
op
==
tanh_op_
)
{
}
else
if
(
pre
->
op
==
tanh_op_
)
{
return
FastTanh
(
post
.
as
<
CallNode
>
()
->
args
[
0
]);
return
FastTanh
(
post
.
as
<
CallNode
>
()
->
args
[
0
]);
}
}
...
@@ -51,6 +54,7 @@ class FastMathMutator : public ExprRewriter {
...
@@ -51,6 +54,7 @@ class FastMathMutator : public ExprRewriter {
// operator equivalence checking so that the registry lookup overhead can be
// operator equivalence checking so that the registry lookup overhead can be
// reduced.
// reduced.
const
Op
&
exp_op_
;
const
Op
&
exp_op_
;
const
Op
&
erf_op_
;
const
Op
&
tanh_op_
;
const
Op
&
tanh_op_
;
};
};
...
...
src/relay/transforms/pattern_util.h
View file @
f5b02fdb
...
@@ -322,6 +322,11 @@ inline Expr FastExp(Expr e) {
...
@@ -322,6 +322,11 @@ inline Expr FastExp(Expr e) {
return
Call
(
op
,
{
e
});
return
Call
(
op
,
{
e
});
}
}
/*!
 * \brief Build a call to the "fast_erf" operator on a single argument.
 * \param e The expression passed as the only argument of the call.
 * \return A Call expression that applies "fast_erf" to e.
 */
inline Expr FastErf(Expr e) {
  // Function-local static: the operator is fetched by name on first use only.
  static const Op& fast_erf_op = Op::Get("fast_erf");
  return Call(fast_erf_op, {e});
}
inline
Expr
FastTanh
(
Expr
e
)
{
inline
Expr
FastTanh
(
Expr
e
)
{
static
const
Op
&
op
=
Op
::
Get
(
"fast_tanh"
);
static
const
Op
&
op
=
Op
::
Get
(
"fast_tanh"
);
return
Call
(
op
,
{
e
});
return
Call
(
op
,
{
e
});
...
...
tests/python/relay/test_op_fast_math.py
View file @
f5b02fdb
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# specific language governing permissions and limitations
# under the License.
# under the License.
import
numpy
as
np
import
numpy
as
np
import
scipy
from
scipy
import
special
import
tvm
import
tvm
import
tvm.relay
as
relay
import
tvm.relay
as
relay
import
topi
import
topi
...
@@ -52,6 +54,7 @@ def test_fastmath():
...
@@ -52,6 +54,7 @@ def test_fastmath():
rtol
=
1e-5
,
atol
=
1e-5
)
rtol
=
1e-5
,
atol
=
1e-5
)
test_apply
(
relay
.
exp
,
"fast_exp"
,
np
.
exp
,
low
=-
88
,
high
=
88
,
step
=
0.01
)
test_apply
(
relay
.
exp
,
"fast_exp"
,
np
.
exp
,
low
=-
88
,
high
=
88
,
step
=
0.01
)
test_apply
(
relay
.
erf
,
"fast_erf"
,
scipy
.
special
.
erf
,
low
=-
10
,
high
=
10
,
step
=
0.01
)
test_apply
(
relay
.
tanh
,
"fast_tanh"
,
np
.
tanh
,
low
=-
10
,
high
=
10
,
step
=
0.01
)
test_apply
(
relay
.
tanh
,
"fast_tanh"
,
np
.
tanh
,
low
=-
10
,
high
=
10
,
step
=
0.01
)
...
...
topi/include/topi/elemwise.h
View file @
f5b02fdb
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/ir_pass.h>
#include <topi/tags.h>
#include <topi/tags.h>
#include <algorithm>
#include <string>
#include <string>
#include "broadcast.h"
#include "broadcast.h"
...
@@ -63,7 +64,7 @@ TOPI_DECLARE_UNARY_OP(tanh);
...
@@ -63,7 +64,7 @@ TOPI_DECLARE_UNARY_OP(tanh);
TOPI_DECLARE_UNARY_OP
(
isfinite
);
TOPI_DECLARE_UNARY_OP
(
isfinite
);
TOPI_DECLARE_UNARY_OP
(
isinf
);
TOPI_DECLARE_UNARY_OP
(
isinf
);
/*
/*
!
* \brief Fast_tanh_float implementation from Eigen
* \brief Fast_tanh_float implementation from Eigen
* https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26
* https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26
*/
*/
...
@@ -461,5 +462,75 @@ inline Tensor fast_exp(const Tensor& x,
...
@@ -461,5 +462,75 @@ inline Tensor fast_exp(const Tensor& x,
}
}
}
}
/*!
 * \brief Fast erf implementation for float32 tensors, adapted from Eigen
 * https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290
 *
 * Each element is clamped to [-4, 4] and then evaluated as the ratio of an odd
 * numerator polynomial and an even denominator polynomial.
 *
 * \param data The input tensor
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor holding the element-wise rational approximation
 */
inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) {
  // Wraps a float literal as a float32 constant expression.
  const auto f32 = [](float value) { return make_const(DataType::Float(32), value); };

  // Bounds used to clamp the input.
  auto clamp_hi = f32(4.f);
  auto clamp_lo = f32(-4.f);

  // The monomial coefficients of the numerator polynomial (odd).
  auto alpha_1 = f32(-1.60960333262415e-02f);
  auto alpha_3 = f32(-2.95459980854025e-03f);
  auto alpha_5 = f32(-7.34990630326855e-04f);
  auto alpha_7 = f32(-5.69250639462346e-05f);
  auto alpha_9 = f32(-2.10102402082508e-06f);
  auto alpha_11 = f32(2.77068142495902e-08f);
  auto alpha_13 = f32(-2.72614225801306e-10f);

  // The monomial coefficients of the denominator polynomial (even).
  auto beta_0 = f32(-1.42647390514189e-02f);
  auto beta_2 = f32(-7.37332916720468e-03f);
  auto beta_4 = f32(-1.68282697438203e-03f);
  auto beta_6 = f32(-2.13374055278905e-04f);
  auto beta_8 = f32(-1.45660718464996e-05f);

  return compute(
      data->shape,
      [&](const Array<Var>& indices) {
        // Clamp the input element to [-4, 4].
        auto clamped = tvm::max(tvm::min(data(indices), clamp_hi), clamp_lo);
        auto squared = clamped * clamped;

        // Numerator: Horner evaluation in x^2, then one multiplication by x
        // to obtain the odd polynomial.
        auto numer = squared * alpha_13 + alpha_11;
        numer = squared * numer + alpha_9;
        numer = squared * numer + alpha_7;
        numer = squared * numer + alpha_5;
        numer = squared * numer + alpha_3;
        numer = squared * numer + alpha_1;
        numer = clamped * numer;

        // Denominator: Horner evaluation of the even polynomial in x^2.
        auto denom = squared * beta_8 + beta_6;
        denom = squared * denom + beta_4;
        denom = squared * denom + beta_2;
        denom = squared * denom + beta_0;

        return numer / denom;
      },
      name, tag);
}
/*!
 * \brief Fast erf implementation
 *
 * float32 inputs are dispatched to fast_erf_float32; any other dtype falls
 * back to topi::erf.
 *
 * \param x The input tensor
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor computing erf element-wise (fast approximation for float32)
 */
inline Tensor fast_erf(const Tensor& x,
                       std::string name = "T_fast_erf",
                       std::string tag = kElementWise) {
  if (!(x->dtype == DataType::Float(32))) {
    // Fallback to the default implementation.
    // NOTE(review): name and tag are not forwarded here — confirm this is intended.
    return topi::erf(x);
  }
  return fast_erf_float32(x, name, tag);
}
}
// namespace topi
}
// namespace topi
#endif // TOPI_ELEMWISE_H_
#endif // TOPI_ELEMWISE_H_
topi/python/topi/math.py
View file @
f5b02fdb
...
@@ -534,3 +534,19 @@ def fast_tanh(x):
...
@@ -534,3 +534,19 @@ def fast_tanh(x):
The result.
The result.
"""
"""
return
cpp
.
fast_tanh
(
x
,
x
.
dtype
,
tag
.
ELEMWISE
)
return
cpp
.
fast_tanh
(
x
,
x
.
dtype
,
tag
.
ELEMWISE
)
def fast_erf(x):
    """Compute the Gauss error function of x using the fast_erf implementation.

    Parameters
    ----------
    x : tvm.te.Tensor
        Input argument.

    Returns
    -------
    y : tvm.te.Tensor
        The result.
    """
    return cpp.fast_erf(
        x,
        x.dtype,
        tag.ELEMWISE,
    )
topi/src/elemwise.cc
View file @
f5b02fdb
...
@@ -46,6 +46,11 @@ TVM_REGISTER_GLOBAL("topi.erf")
...
@@ -46,6 +46,11 @@ TVM_REGISTER_GLOBAL("topi.erf")
*
rv
=
erf
(
args
[
0
]);
*
rv
=
erf
(
args
[
0
]);
});
});
TVM_REGISTER_GLOBAL
(
"topi.fast_erf"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
fast_erf
(
args
[
0
]);
});
TVM_REGISTER_GLOBAL
(
"topi.tan"
)
TVM_REGISTER_GLOBAL
(
"topi.tan"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
tan
(
args
[
0
]);
*
rv
=
tan
(
args
[
0
]);
...
...
topi/tests/python/test_topi_math.py
View file @
f5b02fdb
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
# under the License.
# under the License.
import
numpy
as
np
import
numpy
as
np
import
scipy
import
scipy
from
scipy
import
special
import
tvm
import
tvm
from
tvm
import
te
from
tvm
import
te
import
topi
import
topi
...
@@ -238,11 +239,11 @@ def test_fastmath():
...
@@ -238,11 +239,11 @@ def test_fastmath():
test_apply
(
topi
.
fast_exp
,
"fast_exp"
,
np
.
exp
,
test_apply
(
topi
.
fast_exp
,
"fast_exp"
,
np
.
exp
,
low
=-
88
,
high
=
88
,
low
=-
88
,
high
=
88
,
step
=
0.01
)
step
=
0.01
)
test_apply
(
topi
.
fast_erf
,
"fast_erf"
,
scipy
.
special
.
erf
,
low
=-
10
,
high
=
10
,
step
=
0.01
)
test_apply
(
topi
.
fast_tanh
,
"fast_tanh"
,
np
.
tanh
,
test_apply
(
topi
.
fast_tanh
,
"fast_tanh"
,
np
.
tanh
,
low
=-
10
,
high
=
10
,
low
=-
10
,
high
=
10
,
step
=
0.01
)
step
=
0.01
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_util
()
test_util
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment