Commit dc6203c2 by Tianqi Chen Committed by GitHub

[INTRIN] Add support for floor and ceil (#1267)

parent 652b397f
...@@ -38,6 +38,9 @@ set(USE_METAL OFF) ...@@ -38,6 +38,9 @@ set(USE_METAL OFF)
# Whether enable Vulkan runtime # Whether enable Vulkan runtime
set(USE_VULKAN OFF) set(USE_VULKAN OFF)
# Whether to enable the OpenGL runtime (OFF by default in this config)
set(USE_OPENGL OFF)
# Whether enable RPC runtime # Whether enable RPC runtime
set(USE_RPC ON) set(USE_RPC ON)
......
...@@ -53,6 +53,8 @@ TVM_DECLARE_INTRIN_UNARY(tanh); ...@@ -53,6 +53,8 @@ TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid); TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt); TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(log); TVM_DECLARE_INTRIN_UNARY(log);
// Declare the unary intrinsics floor and ceil via TVM_DECLARE_INTRIN_UNARY.
// NOTE(review): the macro body is not visible here; presumably it expands to an
// inline Expr wrapper like pow() below -- confirm against the macro definition.
TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil);
inline Expr pow(Expr x, Expr y) { inline Expr pow(Expr x, Expr y) {
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic); return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
......
...@@ -233,6 +233,38 @@ def sqrt(x): ...@@ -233,6 +233,38 @@ def sqrt(x):
return call_pure_intrin(x.dtype, "sqrt", x) return call_pure_intrin(x.dtype, "sqrt", x)
def floor(x):
    """Create a pure intrinsic call that takes the floor of x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The expression produced by ``call_pure_intrin`` for the
        "floor" intrinsic, using ``x.dtype`` as the dtype argument.
    """
    # The call is built with the same dtype as the input expression.
    result_dtype = x.dtype
    return call_pure_intrin(result_dtype, "floor", x)
def ceil(x):
    """Create a pure intrinsic call that takes the ceiling of x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The expression produced by ``call_pure_intrin`` for the
        "ceil" intrinsic, using ``x.dtype`` as the dtype argument.
    """
    # The call is built with the same dtype as the input expression.
    result_dtype = x.dtype
    return call_pure_intrin(result_dtype, "ceil", x)
def power(x, y): def power(x, y):
"""x power y """x power y
......
...@@ -55,6 +55,12 @@ struct CUDAShuffle { ...@@ -55,6 +55,12 @@ struct CUDAShuffle {
} }
}; };
// Register the global functions "tvm.intrin.rule.cuda.floor" and
// "tvm.intrin.rule.cuda.ceil"; both use DispatchExtern<CUDAMath> as their body
// (unlike exp below, which is registered with CUDAFastMath).
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>); .set_body(DispatchExtern<CUDAFastMath>);
......
...@@ -9,6 +9,12 @@ namespace tvm { ...@@ -9,6 +9,12 @@ namespace tvm {
namespace codegen { namespace codegen {
namespace intrin { namespace intrin {
// Register the global functions "tvm.intrin.rule.metal.floor" and
// "tvm.intrin.rule.metal.ceil", each with DispatchExtern<Direct> as its body.
// NOTE(review): Direct presumably emits an extern call under the same name -- confirm.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
......
...@@ -9,6 +9,12 @@ namespace tvm { ...@@ -9,6 +9,12 @@ namespace tvm {
namespace codegen { namespace codegen {
namespace intrin { namespace intrin {
// Register the global functions "tvm.intrin.rule.opencl.floor" and
// "tvm.intrin.rule.opencl.ceil", each with DispatchExtern<Direct> as its body.
// NOTE(review): Direct presumably emits an extern call under the same name -- confirm.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
......
...@@ -9,6 +9,12 @@ namespace tvm { ...@@ -9,6 +9,12 @@ namespace tvm {
namespace codegen { namespace codegen {
namespace intrin { namespace intrin {
// Register the global functions "tvm.intrin.rule.opengl.floor" and
// "tvm.intrin.rule.opengl.ceil", each with DispatchExtern<Direct> as its body.
// NOTE(review): Direct presumably emits an extern call under the same name -- confirm.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.floor")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
......
...@@ -25,6 +25,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") ...@@ -25,6 +25,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
// Register the global functions "tvm.intrin.rule.llvm.floor" and
// "tvm.intrin.rule.llvm.ceil", bound to DispatchLLVMPureIntrin instantiated with
// ::llvm::Intrinsic::floor and ::llvm::Intrinsic::ceil respectively.
// NOTE(review): the second template argument (1) is presumably the number of
// operands, matching the sqrt registration above -- confirm.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) { .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0]; Expr e = targs[0];
......
...@@ -26,6 +26,12 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { ...@@ -26,6 +26,12 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
namespace llvm { namespace llvm {
// Register the global functions "tvm.intrin.rule.rocm.floor" and
// "tvm.intrin.rule.rocm.ceil". Unlike exp below (DispatchExternOCML), these are
// bound to DispatchLLVMPureIntrin with the LLVM floor/ceil intrinsic IDs.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML); .set_body(DispatchExternOCML);
......
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
* \file build_vulkan.cc * \file build_vulkan.cc
* \brief Build SPIRV block * \brief Build SPIRV block
*/ */
#if TVM_VULKAN_RUNTIME
// Use libspirv for parsing and validating code. // Use libspirv for parsing and validating code.
#include <vulkan/libspirv.h> #include <vulkan/libspirv.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
...@@ -92,4 +90,3 @@ TVM_REGISTER_API("codegen.build_vulkan") ...@@ -92,4 +90,3 @@ TVM_REGISTER_API("codegen.build_vulkan")
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_VULKAN_RUNTIME
...@@ -3,9 +3,6 @@ ...@@ -3,9 +3,6 @@
* \file codegen_spirv.cc * \file codegen_spirv.cc
* \brief Generate SPIRV block * \brief Generate SPIRV block
*/ */
#if TVM_VULKAN_RUNTIME
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "../codegen_common.h" #include "../codegen_common.h"
...@@ -634,5 +631,3 @@ void CodeGenSPIRV::VisitStmt_(const ProducerConsumer* op) { ...@@ -634,5 +631,3 @@ void CodeGenSPIRV::VisitStmt_(const ProducerConsumer* op) {
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_VULKAN_RUNTIME
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file intrin_rule_spirv.cc * \file intrin_rule_spirv.cc
*/ */
#if TVM_VULKAN_RUNTIME
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <vulkan/GLSL.std.450.h> #include <vulkan/GLSL.std.450.h>
...@@ -31,6 +29,12 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { ...@@ -31,6 +29,12 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
call->type, "spirv_glsl450", cargs, ir::Call::PureIntrinsic); call->type, "spirv_glsl450", cargs, ir::Call::PureIntrinsic);
} }
// Register the global functions "tvm.intrin.rule.vulkan.floor" and
// "tvm.intrin.rule.vulkan.ceil", bound to DispatchGLSLPureIntrin instantiated
// with GLSLstd450Floor and GLSLstd450Ceil respectively.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Ceil>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>); .set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
...@@ -43,8 +47,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt") ...@@ -43,8 +47,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow") TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>); .set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
// Register the global function "tvm.intrin.rule.vulkan.tanh", bound to
// DispatchGLSLPureIntrin instantiated with GLSLstd450Tanh.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);
} // namespace spirv } // namespace spirv
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_VULKAN_RUNTIME
...@@ -3,9 +3,6 @@ ...@@ -3,9 +3,6 @@
* \file ir_builder.cc * \file ir_builder.cc
* \brief IRBuilder for SPIRV block * \brief IRBuilder for SPIRV block
*/ */
#if TVM_VULKAN_RUNTIME
#include "./ir_builder.h" #include "./ir_builder.h"
namespace tvm { namespace tvm {
...@@ -555,5 +552,3 @@ Value IRBuilder::Select(Value cond, Value a, Value b) { ...@@ -555,5 +552,3 @@ Value IRBuilder::Select(Value cond, Value a, Value b) {
} // namespace spirv } // namespace spirv
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_VULKAN_RUNTIME
...@@ -29,6 +29,8 @@ TOPI_DECLARE_UNARY_OP(tanh); ...@@ -29,6 +29,8 @@ TOPI_DECLARE_UNARY_OP(tanh);
TOPI_DECLARE_UNARY_OP(sigmoid); TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt); TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log); TOPI_DECLARE_UNARY_OP(log);
// Declare the TOPI unary ops floor and ceil via TOPI_DECLARE_UNARY_OP.
// NOTE(review): the macro body is not visible here -- confirm what it expands to.
TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil);
/*! /*!
* \brief Creates an operation that returns identity of a given tensor * \brief Creates an operation that returns identity of a given tensor
......
...@@ -74,6 +74,40 @@ def tanh(x): ...@@ -74,6 +74,40 @@ def tanh(x):
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def floor(x):
    """Compute the element-wise floor of a tensor.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        Result of ``tvm.compute`` over ``x.shape``, where each element
        is ``tvm.floor`` applied to the matching element of x.
    """
    # The output is computed over the same shape as the input.
    out_shape = x.shape
    return tvm.compute(out_shape, lambda *i: tvm.floor(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def ceil(x):
    """Compute the element-wise ceiling of a tensor.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        Result of ``tvm.compute`` over ``x.shape``, where each element
        is ``tvm.ceil`` applied to the matching element of x.
    """
    # The output is computed over the same shape as the input.
    out_shape = x.shape
    return tvm.compute(out_shape, lambda *i: tvm.ceil(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def log(x): def log(x):
"""Take logarithm of input x. """Take logarithm of input x.
......
...@@ -18,12 +18,11 @@ def test_ewise(): ...@@ -18,12 +18,11 @@ def test_ewise():
shape = (20, 3) shape = (20, 3)
def test_apply(func, name, f_numpy): def test_apply(func, name, f_numpy, low, high):
B = func(A) B = func(A)
assert tuple(B.shape) == tuple(A.shape) assert tuple(B.shape) == tuple(A.shape)
assert B.op.body[0].name == name assert B.op.body[0].name == name
a_np = np.random.uniform(low=1e-5, size=shape).astype(A.dtype) a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
a_np = np.abs(a_np)
b_np = f_numpy(a_np) b_np = f_numpy(a_np)
def check_device(device): def check_device(device):
...@@ -43,11 +42,14 @@ def test_ewise(): ...@@ -43,11 +42,14 @@ def test_ewise():
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm']: for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm']:
check_device(device) check_device(device)
test_apply(topi.exp, "exp", np.exp)
test_apply(topi.tanh, "tanh", np.tanh) test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x))) test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.log, "log", np.log) test_apply(topi.exp, "exp", np.exp, -1, 1)
test_apply(topi.sqrt, "sqrt", np.sqrt) test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
test_apply(topi.log, "log", np.log, 0, 100)
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
if __name__ == "__main__": if __name__ == "__main__":
test_util() test_util()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment