Commit 61370e4b authored by Tianqi Chen, committed by GitHub

[MATH][TOPI][NNVM] introduce trunc, round (#1310)

parent 71b235fc
@@ -10,6 +10,10 @@ tvm.intrin
tvm.register_intrin_rule
tvm.exp
tvm.log
tvm.floor
tvm.ceil
tvm.trunc
tvm.round
.. autofunction:: tvm.call_packed
@@ -18,3 +22,7 @@ tvm.intrin
.. autofunction:: tvm.register_intrin_rule
.. autofunction:: tvm.exp
.. autofunction:: tvm.log
.. autofunction:: tvm.floor
.. autofunction:: tvm.ceil
.. autofunction:: tvm.trunc
.. autofunction:: tvm.round
@@ -9,6 +9,10 @@ List of operators
topi.identity
topi.negative
topi.floor
topi.ceil
topi.trunc
topi.round
topi.exp
topi.tanh
topi.log
@@ -68,6 +72,10 @@ topi
~~~~
.. autofunction:: topi.negative
.. autofunction:: topi.identity
.. autofunction:: topi.floor
.. autofunction:: topi.ceil
.. autofunction:: topi.trunc
.. autofunction:: topi.round
.. autofunction:: topi.exp
.. autofunction:: topi.tanh
.. autofunction:: topi.log
......
@@ -75,6 +75,10 @@ This level enables typical convnet models.
nnvm.symbol.reshape
nnvm.symbol.copy
nnvm.symbol.negative
nnvm.symbol.floor
nnvm.symbol.ceil
nnvm.symbol.round
nnvm.symbol.trunc
nnvm.symbol.leaky_relu
nnvm.symbol.__add_scalar__
nnvm.symbol.__sub_scalar__
@@ -147,6 +151,10 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.reshape
.. autofunction:: nnvm.symbol.copy
.. autofunction:: nnvm.symbol.negative
.. autofunction:: nnvm.symbol.floor
.. autofunction:: nnvm.symbol.ceil
.. autofunction:: nnvm.symbol.round
.. autofunction:: nnvm.symbol.trunc
.. autofunction:: nnvm.symbol.leaky_relu
.. autofunction:: nnvm.symbol.__add_scalar__
.. autofunction:: nnvm.symbol.__sub_scalar__
......
@@ -55,6 +55,8 @@ TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil);
TVM_DECLARE_INTRIN_UNARY(round);
TVM_DECLARE_INTRIN_UNARY(trunc);
inline Expr pow(Expr x, Expr y) {
  return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
......
@@ -61,6 +61,22 @@ def compute_cast(attrs, inputs, _):
reg.register_pattern("cast", OpPattern.ELEMWISE)
reg.register_schedule("cast", _fschedule_broadcast)
# floor
reg.register_pattern("floor", OpPattern.ELEMWISE)
reg.register_schedule("floor", _fschedule_broadcast)
# ceil
reg.register_pattern("ceil", OpPattern.ELEMWISE)
reg.register_schedule("ceil", _fschedule_broadcast)
# round
reg.register_pattern("round", OpPattern.ELEMWISE)
reg.register_schedule("round", _fschedule_broadcast)
# trunc
reg.register_pattern("trunc", OpPattern.ELEMWISE)
reg.register_schedule("trunc", _fschedule_broadcast)
# exp
reg.register_pattern("exp", OpPattern.ELEMWISE)
reg.register_schedule("exp", _fschedule_broadcast)
......
@@ -31,6 +31,54 @@ Used to produce invalid node during optimization.
.set_num_outputs(1)
.set_num_inputs(0);
// floor
NNVM_REGISTER_ELEMWISE_UNARY_OP(floor)
.describe(R"code(Take floor of input array, computed element-wise.
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::floor(inputs[0]) };
});
// ceil
NNVM_REGISTER_ELEMWISE_UNARY_OP(ceil)
.describe(R"code(Take ceil of input array, computed element-wise.
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::ceil(inputs[0]) };
});
// trunc
NNVM_REGISTER_ELEMWISE_UNARY_OP(trunc)
.describe(R"code(Take truncated value of the input, element-wise.
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::trunc(inputs[0]) };
});
// round
NNVM_REGISTER_ELEMWISE_UNARY_OP(round)
.describe(R"code(Round elements of the input to nearest integer.
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::round(inputs[0]) };
});
// sigmoid
NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
.describe(R"code(Computes sigmoid.
......
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import topi.testing
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list
from test_top_level1 import helper


def check_map(symfunc, np_func, np_backward=None):
    x = sym.Variable("x")
    y = symfunc(x)
    dtype = "float32"
    dshape = (1, 3, 32, 32)
    inputs = [('x', dshape, x)]
    helper(y, inputs, dtype, lambda x: np_func(x), np_backward)


def test_floor():
    check_map(sym.floor, np.floor)


def test_ceil():
    check_map(sym.ceil, np.ceil)


def test_trunc():
    check_map(sym.trunc, np.trunc)


def test_round():
    check_map(sym.round, np.round)


if __name__ == "__main__":
    test_floor()
    test_ceil()
    test_round()
    test_trunc()
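For readers who want to exercise the new symbols outside the shared `helper` utility, the following standalone sketch compiles and runs a tiny graph directly with `nnvm.compiler` and `graph_runtime`. It is not part of this commit; the shape, target, and variable names are illustrative assumptions.

```python
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import nnvm.symbol as sym
import nnvm.compiler

# Chain the new unary ops into a small graph.
x = sym.Variable("x")
y = sym.round(sym.trunc(x))

shape = {"x": (1, 3, 32, 32)}  # illustrative shape, matching the test above
graph, lib, params = nnvm.compiler.build(y, "llvm", shape=shape)

# Execute on CPU and compare with the NumPy reference.
m = graph_runtime.create(graph, lib, tvm.cpu(0))
data = np.random.uniform(-10, 10, size=shape["x"]).astype("float32")
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(shape["x"], "float32")).asnumpy()
np.testing.assert_allclose(out, np.round(np.trunc(data)), rtol=1e-5)
```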
"""Expression Intrinsics and math functions in TVM.""" """Expression Intrinsics and math functions in TVM."""
# pylint: disable=redefined-builtin
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.function import register_func as _register_func from ._ffi.function import register_func as _register_func
@@ -265,6 +266,41 @@ def ceil(x):
    return call_pure_intrin(x.dtype, "ceil", x)
def trunc(x):
    """Get truncated value of the input.

    The truncated value of the scalar x is the
    nearest integer i which is closer to zero than x is.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "trunc", x)


def round(x):
    """Round elements of the array to the nearest integer.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "round", x)
def power(x, y):
    """x power y
......
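As a quick illustration of how these scalar intrinsics are consumed (a minimal sketch, not part of the diff; the names `A`, `B`, and `n` are assumptions), `tvm.round` can be applied inside a `tvm.compute` and lowered on the LLVM backend, where the per-target rules registered below dispatch it (e.g., to `::llvm::Intrinsic::round`):

```python
import numpy as np
import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A", dtype="float32")
# The intrinsic stays a pure Call node until codegen dispatches it per target.
B = tvm.compute(A.shape, lambda i: tvm.round(A[i]), name="B")

s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(-5, 5, 64).astype("float32"), ctx)
b = tvm.nd.array(np.zeros(64, dtype="float32"), ctx)
f(a, b)
# Note: exact halfway values can differ (np.round is round-half-even, LLVM's
# round is round-half-away-from-zero); random inputs avoid that edge case.
np.testing.assert_allclose(b.asnumpy(), np.round(a.asnumpy()), rtol=1e-5)
```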
@@ -61,6 +61,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);
......
@@ -15,6 +15,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
.set_body(DispatchExtern<Direct>);
......
@@ -15,6 +15,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
.set_body(DispatchExtern<Direct>);
......
@@ -31,6 +31,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
    Expr e = targs[0];
......
@@ -32,6 +32,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML);
......
@@ -35,6 +35,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Ceil>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.round")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Round>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.trunc")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Trunc>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
......
@@ -31,6 +31,8 @@ TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log);
TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil);
TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc);
/*!
 * \brief Creates an operation that returns identity of a given tensor
......
"""Elementwise operators""" """Elementwise operators"""
# pylint: disable=redefined-builtin
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from . import tag from . import tag
@@ -108,6 +109,40 @@ def ceil(x):
@tvm.tag_scope(tag=tag.ELEMWISE)
def trunc(x):
    """Take truncated value of the input x, element-wise.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return tvm.compute(x.shape, lambda *i: tvm.trunc(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def round(x):
    """Round elements of x to the nearest integer.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return tvm.compute(x.shape, lambda *i: tvm.round(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def log(x):
    """Take logarithm of input x.
......
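The TOPI-level wrappers compose with the injective schedule in the same way as the other element-wise ops, as the updated test below exercises. A standalone sketch follows; the tensor shape and target are illustrative assumptions, not part of the commit:

```python
import numpy as np
import tvm
import topi

A = tvm.placeholder((1, 3, 32, 32), name="A", dtype="float32")
B = topi.trunc(A)  # tagged tag.ELEMWISE, so the injective schedule applies

with tvm.target.create("llvm"):
    s = topi.generic.schedule_injective(B)
f = tvm.build(s, [A, B], "llvm")

ctx = tvm.cpu(0)
a_np = np.random.uniform(-100, 100, size=(1, 3, 32, 32)).astype("float32")
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros_like(a_np), ctx)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), np.trunc(a_np), rtol=1e-5)
```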
@@ -33,9 +33,9 @@ def test_ewise():
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
foo = tvm.build(s, [A, B], device, name=name)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros_like(b_np), ctx)
foo = tvm.build(s, [A, B], device, name=name)
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
@@ -45,6 +45,8 @@ def test_ewise():
test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.trunc, "trunc", np.trunc, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100)
test_apply(topi.exp, "exp", np.exp, -1, 1)
test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
test_apply(topi.sigmoid, "sigmoid", lambda x: 1 / (1 + np.exp(-x)), -1, 1)
......