Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
61370e4b
Commit
61370e4b
authored
Jun 21, 2018
by
Tianqi Chen
Committed by
GitHub
Jun 21, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[MATH][TOPI][NNVM] introduce trunc, round (#1310)
parent
71b235fc
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
238 additions
and
1 deletions
+238
-1
docs/api/python/intrin.rst
+8
-0
docs/api/python/topi.rst
+8
-0
docs/nnvm_top.rst
+8
-0
include/tvm/ir_operator.h
+2
-0
nnvm/python/nnvm/top/tensor.py
+16
-0
nnvm/src/top/tensor/elemwise.cc
+48
-0
nnvm/tests/python/compiler/test_top_level3.py
+36
-0
python/tvm/intrin.py
+36
-0
src/codegen/intrin_rule_cuda.cc
+6
-0
src/codegen/intrin_rule_metal.cc
+6
-0
src/codegen/intrin_rule_opencl.cc
+6
-0
src/codegen/llvm/intrin_rule_llvm.cc
+6
-0
src/codegen/llvm/intrin_rule_rocm.cc
+6
-0
src/codegen/spirv/intrin_rule_spirv.cc
+6
-0
topi/include/topi/elemwise.h
+2
-0
topi/python/topi/math.py
+35
-0
topi/tests/python/test_topi_math.py
+3
-1
No files found.
docs/api/python/intrin.rst
View file @
61370e4b
...
@@ -10,6 +10,10 @@ tvm.intrin
...
@@ -10,6 +10,10 @@ tvm.intrin
tvm.register_intrin_rule
tvm.register_intrin_rule
tvm.exp
tvm.exp
tvm.log
tvm.log
tvm.floor
tvm.ceil
tvm.trunc
tvm.round
.. autofunction:: tvm.call_packed
.. autofunction:: tvm.call_packed
...
@@ -18,3 +22,7 @@ tvm.intrin
...
@@ -18,3 +22,7 @@ tvm.intrin
.. autofunction:: tvm.register_intrin_rule
.. autofunction:: tvm.register_intrin_rule
.. autofunction:: tvm.exp
.. autofunction:: tvm.exp
.. autofunction:: tvm.log
.. autofunction:: tvm.log
.. autofunction:: tvm.floor
.. autofunction:: tvm.ceil
.. autofunction:: tvm.trunc
.. autofunction:: tvm.round
docs/api/python/topi.rst
View file @
61370e4b
...
@@ -9,6 +9,10 @@ List of operators
...
@@ -9,6 +9,10 @@ List of operators
topi.identity
topi.identity
topi.negative
topi.negative
topi.floor
topi.ceil
topi.trunc
topi.round
topi.exp
topi.exp
topi.tanh
topi.tanh
topi.log
topi.log
...
@@ -68,6 +72,10 @@ topi
...
@@ -68,6 +72,10 @@ topi
~~~~
~~~~
.. autofunction:: topi.negative
.. autofunction:: topi.negative
.. autofunction:: topi.identity
.. autofunction:: topi.identity
.. autofunction:: topi.floor
.. autofunction:: topi.ceil
.. autofunction:: topi.trunc
.. autofunction:: topi.round
.. autofunction:: topi.exp
.. autofunction:: topi.exp
.. autofunction:: topi.tanh
.. autofunction:: topi.tanh
.. autofunction:: topi.log
.. autofunction:: topi.log
...
...
docs/nnvm_top.rst
View file @
61370e4b
...
@@ -75,6 +75,10 @@ This level enables typical convnet models.
...
@@ -75,6 +75,10 @@ This level enables typical convnet models.
nnvm.symbol.reshape
nnvm.symbol.reshape
nnvm.symbol.copy
nnvm.symbol.copy
nnvm.symbol.negative
nnvm.symbol.negative
nnvm.symbol.floor
nnvm.symbol.ceil
nnvm.symbol.round
nnvm.symbol.trunc
nnvm.symbol.leaky_relu
nnvm.symbol.leaky_relu
nnvm.symbol.__add_scalar__
nnvm.symbol.__add_scalar__
nnvm.symbol.__sub_scalar__
nnvm.symbol.__sub_scalar__
...
@@ -147,6 +151,10 @@ Detailed Definitions
...
@@ -147,6 +151,10 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.reshape
.. autofunction:: nnvm.symbol.reshape
.. autofunction:: nnvm.symbol.copy
.. autofunction:: nnvm.symbol.copy
.. autofunction:: nnvm.symbol.negative
.. autofunction:: nnvm.symbol.negative
.. autofunction:: nnvm.symbol.floor
.. autofunction:: nnvm.symbol.ceil
.. autofunction:: nnvm.symbol.round
.. autofunction:: nnvm.symbol.trunc
.. autofunction:: nnvm.symbol.leaky_relu
.. autofunction:: nnvm.symbol.leaky_relu
.. autofunction:: nnvm.symbol.__add_scalar__
.. autofunction:: nnvm.symbol.__add_scalar__
.. autofunction:: nnvm.symbol.__sub_scalar__
.. autofunction:: nnvm.symbol.__sub_scalar__
...
...
include/tvm/ir_operator.h
View file @
61370e4b
...
@@ -55,6 +55,8 @@ TVM_DECLARE_INTRIN_UNARY(sqrt);
...
@@ -55,6 +55,8 @@ TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY
(
log
);
TVM_DECLARE_INTRIN_UNARY
(
log
);
TVM_DECLARE_INTRIN_UNARY
(
floor
);
TVM_DECLARE_INTRIN_UNARY
(
floor
);
TVM_DECLARE_INTRIN_UNARY
(
ceil
);
TVM_DECLARE_INTRIN_UNARY
(
ceil
);
TVM_DECLARE_INTRIN_UNARY
(
round
);
TVM_DECLARE_INTRIN_UNARY
(
trunc
);
inline
Expr
pow
(
Expr
x
,
Expr
y
)
{
inline
Expr
pow
(
Expr
x
,
Expr
y
)
{
return
ir
::
Call
::
make
(
x
.
type
(),
"pow"
,
{
x
,
y
},
ir
::
Call
::
PureIntrinsic
);
return
ir
::
Call
::
make
(
x
.
type
(),
"pow"
,
{
x
,
y
},
ir
::
Call
::
PureIntrinsic
);
...
...
nnvm/python/nnvm/top/tensor.py
View file @
61370e4b
...
@@ -61,6 +61,22 @@ def compute_cast(attrs, inputs, _):
...
@@ -61,6 +61,22 @@ def compute_cast(attrs, inputs, _):
reg
.
register_pattern
(
"cast"
,
OpPattern
.
ELEMWISE
)
reg
.
register_pattern
(
"cast"
,
OpPattern
.
ELEMWISE
)
reg
.
register_schedule
(
"cast"
,
_fschedule_broadcast
)
reg
.
register_schedule
(
"cast"
,
_fschedule_broadcast
)
# floor
reg
.
register_pattern
(
"floor"
,
OpPattern
.
ELEMWISE
)
reg
.
register_schedule
(
"floor"
,
_fschedule_broadcast
)
# ceil
reg
.
register_pattern
(
"ceil"
,
OpPattern
.
ELEMWISE
)
reg
.
register_schedule
(
"ceil"
,
_fschedule_broadcast
)
# round
reg
.
register_pattern
(
"round"
,
OpPattern
.
ELEMWISE
)
reg
.
register_schedule
(
"round"
,
_fschedule_broadcast
)
# trunc
reg
.
register_pattern
(
"trunc"
,
OpPattern
.
ELEMWISE
)
reg
.
register_schedule
(
"trunc"
,
_fschedule_broadcast
)
# exp
# exp
reg
.
register_pattern
(
"exp"
,
OpPattern
.
ELEMWISE
)
reg
.
register_pattern
(
"exp"
,
OpPattern
.
ELEMWISE
)
reg
.
register_schedule
(
"exp"
,
_fschedule_broadcast
)
reg
.
register_schedule
(
"exp"
,
_fschedule_broadcast
)
...
...
nnvm/src/top/tensor/elemwise.cc
View file @
61370e4b
...
@@ -31,6 +31,54 @@ Used to produce invalide node during optimization.
...
@@ -31,6 +31,54 @@ Used to produce invalide node during optimization.
.
set_num_outputs
(
1
)
.
set_num_outputs
(
1
)
.
set_num_inputs
(
0
);
.
set_num_inputs
(
0
);
// floor
NNVM_REGISTER_ELEMWISE_UNARY_OP
(
floor
)
.
describe
(
R"code(Take floor input array, computed element-wise.
)code"
NNVM_ADD_FILELINE
)
.
set_support_level
(
3
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
[](
const
NodeAttrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
out_info
)
{
return
Array
<
Tensor
>
{
topi
::
floor
(
inputs
[
0
])
};
});
// ceil
NNVM_REGISTER_ELEMWISE_UNARY_OP
(
ceil
)
.
describe
(
R"code(Take ceil input array, computed element-wise.
)code"
NNVM_ADD_FILELINE
)
.
set_support_level
(
3
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
[](
const
NodeAttrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
out_info
)
{
return
Array
<
Tensor
>
{
topi
::
ceil
(
inputs
[
0
])
};
});
// trunc
NNVM_REGISTER_ELEMWISE_UNARY_OP
(
trunc
)
.
describe
(
R"code(Take truncated value of the input, element-wise.
)code"
NNVM_ADD_FILELINE
)
.
set_support_level
(
3
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
[](
const
NodeAttrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
out_info
)
{
return
Array
<
Tensor
>
{
topi
::
trunc
(
inputs
[
0
])
};
});
// round
NNVM_REGISTER_ELEMWISE_UNARY_OP
(
round
)
.
describe
(
R"code(Round elements of the input to nearest integer.
)code"
NNVM_ADD_FILELINE
)
.
set_support_level
(
3
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
[](
const
NodeAttrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
out_info
)
{
return
Array
<
Tensor
>
{
topi
::
round
(
inputs
[
0
])
};
});
// sigmoid
// sigmoid
NNVM_REGISTER_ELEMWISE_UNARY_OP
(
sigmoid
)
NNVM_REGISTER_ELEMWISE_UNARY_OP
(
sigmoid
)
.
describe
(
R"code(Computes sigmoid.
.
describe
(
R"code(Computes sigmoid.
...
...
nnvm/tests/python/compiler/test_top_level3.py
0 → 100644
View file @
61370e4b
import
numpy
as
np
import
tvm
from
tvm.contrib
import
graph_runtime
import
topi.testing
import
nnvm.symbol
as
sym
import
nnvm.compiler
from
nnvm.testing.config
import
ctx_list
from
test_top_level1
import
helper
def
check_map
(
symfunc
,
np_func
,
np_backward
=
None
):
x
=
sym
.
Variable
(
"x"
)
y
=
symfunc
(
x
)
dtype
=
"float32"
dshape
=
(
1
,
3
,
32
,
32
)
inputs
=
[(
'x'
,
dshape
,
x
)]
helper
(
y
,
inputs
,
dtype
,
lambda
x
:
np_func
(
x
),
np_backward
)
def
test_floor
():
check_map
(
sym
.
floor
,
np
.
floor
)
def
test_ceil
():
check_map
(
sym
.
ceil
,
np
.
ceil
)
def
test_trunc
():
check_map
(
sym
.
trunc
,
np
.
trunc
)
def
test_round
():
check_map
(
sym
.
round
,
np
.
round
)
if
__name__
==
"__main__"
:
test_floor
()
test_ceil
()
test_round
()
test_trunc
()
python/tvm/intrin.py
View file @
61370e4b
"""Expression Intrinsics and math functions in TVM."""
"""Expression Intrinsics and math functions in TVM."""
# pylint: disable=redefined-builtin
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
._ffi.function
import
register_func
as
_register_func
from
._ffi.function
import
register_func
as
_register_func
...
@@ -265,6 +266,41 @@ def ceil(x):
...
@@ -265,6 +266,41 @@ def ceil(x):
return
call_pure_intrin
(
x
.
dtype
,
"ceil"
,
x
)
return
call_pure_intrin
(
x
.
dtype
,
"ceil"
,
x
)
def
trunc
(
x
):
"""Get truncated value of the input.
The truncated value of the scalar x is the
nearest integer i which is closer to zero than x is.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return
call_pure_intrin
(
x
.
dtype
,
"trunc"
,
x
)
def
round
(
x
):
"""Round elements of the array to the nearest integer.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return
call_pure_intrin
(
x
.
dtype
,
"round"
,
x
)
def
power
(
x
,
y
):
def
power
(
x
,
y
):
"""x power y
"""x power y
...
...
src/codegen/intrin_rule_cuda.cc
View file @
61370e4b
...
@@ -61,6 +61,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
...
@@ -61,6 +61,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.cuda.ceil"
)
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.cuda.ceil"
)
.
set_body
(
DispatchExtern
<
CUDAMath
>
);
.
set_body
(
DispatchExtern
<
CUDAMath
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.cuda.trunc"
)
.
set_body
(
DispatchExtern
<
CUDAMath
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.cuda.round"
)
.
set_body
(
DispatchExtern
<
CUDAMath
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.cuda.exp"
)
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.cuda.exp"
)
.
set_body
(
DispatchExtern
<
CUDAFastMath
>
);
.
set_body
(
DispatchExtern
<
CUDAFastMath
>
);
...
...
src/codegen/intrin_rule_metal.cc
View file @
61370e4b
...
@@ -15,6 +15,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor")
...
@@ -15,6 +15,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor")
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.metal.ceil"
)
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.metal.ceil"
)
.
set_body
(
DispatchExtern
<
Direct
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.metal.trunc"
)
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.metal.round"
)
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.metal.exp"
)
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.metal.exp"
)
.
set_body
(
DispatchExtern
<
Direct
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
...
...
src/codegen/intrin_rule_opencl.cc
View file @
61370e4b
...
@@ -15,6 +15,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor")
...
@@ -15,6 +15,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor")
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.opencl.ceil"
)
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.opencl.ceil"
)
.
set_body
(
DispatchExtern
<
Direct
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.opencl.trunc"
)
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.opencl.round"
)
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.opencl.exp"
)
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.opencl.exp"
)
.
set_body
(
DispatchExtern
<
Direct
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
...
...
src/codegen/llvm/intrin_rule_llvm.cc
View file @
61370e4b
...
@@ -31,6 +31,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor")
...
@@ -31,6 +31,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor")
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.llvm.ceil"
)
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.llvm.ceil"
)
.
set_body
(
DispatchLLVMPureIntrin
<::
llvm
::
Intrinsic
::
ceil
,
1
>
);
.
set_body
(
DispatchLLVMPureIntrin
<::
llvm
::
Intrinsic
::
ceil
,
1
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.llvm.trunc"
)
.
set_body
(
DispatchLLVMPureIntrin
<::
llvm
::
Intrinsic
::
trunc
,
1
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.llvm.round"
)
.
set_body
(
DispatchLLVMPureIntrin
<::
llvm
::
Intrinsic
::
round
,
1
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.llvm.tanh"
)
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.llvm.tanh"
)
.
set_body
([](
const
TVMArgs
&
targs
,
TVMRetValue
*
rv
)
{
.
set_body
([](
const
TVMArgs
&
targs
,
TVMRetValue
*
rv
)
{
Expr
e
=
targs
[
0
];
Expr
e
=
targs
[
0
];
...
...
src/codegen/llvm/intrin_rule_rocm.cc
View file @
61370e4b
...
@@ -32,6 +32,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor")
...
@@ -32,6 +32,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor")
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.rocm.ceil"
)
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.rocm.ceil"
)
.
set_body
(
DispatchLLVMPureIntrin
<::
llvm
::
Intrinsic
::
ceil
,
1
>
);
.
set_body
(
DispatchLLVMPureIntrin
<::
llvm
::
Intrinsic
::
ceil
,
1
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.rocm.round"
)
.
set_body
(
DispatchLLVMPureIntrin
<::
llvm
::
Intrinsic
::
round
,
1
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.rocm.trunc"
)
.
set_body
(
DispatchLLVMPureIntrin
<::
llvm
::
Intrinsic
::
trunc
,
1
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.rocm.exp"
)
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.rocm.exp"
)
.
set_body
(
DispatchExternOCML
);
.
set_body
(
DispatchExternOCML
);
...
...
src/codegen/spirv/intrin_rule_spirv.cc
View file @
61370e4b
...
@@ -35,6 +35,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
...
@@ -35,6 +35,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.vulkan.ceil"
)
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.vulkan.ceil"
)
.
set_body
(
DispatchGLSLPureIntrin
<
GLSLstd450Ceil
>
);
.
set_body
(
DispatchGLSLPureIntrin
<
GLSLstd450Ceil
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.vulkan.round"
)
.
set_body
(
DispatchGLSLPureIntrin
<
GLSLstd450Round
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.vulkan.trunc"
)
.
set_body
(
DispatchGLSLPureIntrin
<
GLSLstd450Trunc
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.vulkan.exp"
)
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.vulkan.exp"
)
.
set_body
(
DispatchGLSLPureIntrin
<
GLSLstd450Exp
>
);
.
set_body
(
DispatchGLSLPureIntrin
<
GLSLstd450Exp
>
);
...
...
topi/include/topi/elemwise.h
View file @
61370e4b
...
@@ -31,6 +31,8 @@ TOPI_DECLARE_UNARY_OP(sqrt);
...
@@ -31,6 +31,8 @@ TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP
(
log
);
TOPI_DECLARE_UNARY_OP
(
log
);
TOPI_DECLARE_UNARY_OP
(
floor
);
TOPI_DECLARE_UNARY_OP
(
floor
);
TOPI_DECLARE_UNARY_OP
(
ceil
);
TOPI_DECLARE_UNARY_OP
(
ceil
);
TOPI_DECLARE_UNARY_OP
(
round
);
TOPI_DECLARE_UNARY_OP
(
trunc
);
/*!
/*!
* \brief Creates an operation that returns identity of a given tensor
* \brief Creates an operation that returns identity of a given tensor
...
...
topi/python/topi/math.py
View file @
61370e4b
"""Elementwise operators"""
"""Elementwise operators"""
# pylint: disable=redefined-builtin
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
import
tvm
import
tvm
from
.
import
tag
from
.
import
tag
...
@@ -108,6 +109,40 @@ def ceil(x):
...
@@ -108,6 +109,40 @@ def ceil(x):
@tvm.tag_scope
(
tag
=
tag
.
ELEMWISE
)
@tvm.tag_scope
(
tag
=
tag
.
ELEMWISE
)
def
trunc
(
x
):
"""Take truncated value of the input of x, element-wise.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return
tvm
.
compute
(
x
.
shape
,
lambda
*
i
:
tvm
.
trunc
(
x
(
*
i
)))
@tvm.tag_scope
(
tag
=
tag
.
ELEMWISE
)
def
round
(
x
):
"""Round elements of x to nearest integer.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return
tvm
.
compute
(
x
.
shape
,
lambda
*
i
:
tvm
.
round
(
x
(
*
i
)))
@tvm.tag_scope
(
tag
=
tag
.
ELEMWISE
)
def
log
(
x
):
def
log
(
x
):
"""Take logarithm of input x.
"""Take logarithm of input x.
...
...
topi/tests/python/test_topi_math.py
View file @
61370e4b
...
@@ -33,9 +33,9 @@ def test_ewise():
...
@@ -33,9 +33,9 @@ def test_ewise():
print
(
"Running on target:
%
s"
%
device
)
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_injective
(
B
)
s
=
topi
.
generic
.
schedule_injective
(
B
)
foo
=
tvm
.
build
(
s
,
[
A
,
B
],
device
,
name
=
name
)
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros_like
(
b_np
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros_like
(
b_np
),
ctx
)
foo
=
tvm
.
build
(
s
,
[
A
,
B
],
device
,
name
=
name
)
foo
(
a
,
b
)
foo
(
a
,
b
)
np
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
b_np
,
rtol
=
1e-5
,
atol
=
1e-5
)
np
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
b_np
,
rtol
=
1e-5
,
atol
=
1e-5
)
...
@@ -45,6 +45,8 @@ def test_ewise():
...
@@ -45,6 +45,8 @@ def test_ewise():
test_apply
(
topi
.
floor
,
"floor"
,
np
.
floor
,
-
100
,
100
)
test_apply
(
topi
.
floor
,
"floor"
,
np
.
floor
,
-
100
,
100
)
test_apply
(
topi
.
ceil
,
"ceil"
,
np
.
ceil
,
-
100
,
100
)
test_apply
(
topi
.
ceil
,
"ceil"
,
np
.
ceil
,
-
100
,
100
)
test_apply
(
topi
.
trunc
,
"trunc"
,
np
.
trunc
,
-
100
,
100
)
test_apply
(
topi
.
round
,
"round"
,
np
.
round
,
-
100
,
100
)
test_apply
(
topi
.
exp
,
"exp"
,
np
.
exp
,
-
1
,
1
)
test_apply
(
topi
.
exp
,
"exp"
,
np
.
exp
,
-
1
,
1
)
test_apply
(
topi
.
tanh
,
"tanh"
,
np
.
tanh
,
-
10
,
10
)
test_apply
(
topi
.
tanh
,
"tanh"
,
np
.
tanh
,
-
10
,
10
)
test_apply
(
topi
.
sigmoid
,
"sigmoid"
,
lambda
x
:
1
/
(
1
+
np
.
exp
(
-
x
)),
-
1
,
1
)
test_apply
(
topi
.
sigmoid
,
"sigmoid"
,
lambda
x
:
1
/
(
1
+
np
.
exp
(
-
x
)),
-
1
,
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment