Commit dc6203c2 by Tianqi Chen Committed by GitHub

[INTRIN] Add support for floor and ceil (#1267)

parent 652b397f
......@@ -38,6 +38,9 @@ set(USE_METAL OFF)
# Whether to enable the Vulkan runtime
set(USE_VULKAN OFF)
# Whether to enable the OpenGL runtime
set(USE_OPENGL OFF)
# Whether to enable the RPC runtime
set(USE_RPC ON)
......
......@@ -53,6 +53,8 @@ TVM_DECLARE_INTRIN_UNARY(tanh);
// One TVM_DECLARE_INTRIN_UNARY invocation per unary intrinsic name.
// The macro itself is defined outside this excerpt; presumably each
// invocation declares an Expr -> Expr helper of that name (confirm there).
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil);
inline Expr pow(Expr x, Expr y) {
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
......
......@@ -233,6 +233,38 @@ def sqrt(x):
return call_pure_intrin(x.dtype, "sqrt", x)
def floor(x):
    """Take the floor of a float input expression.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result, built as a pure intrinsic call that uses x's dtype.
    """
    op_name = "floor"
    return call_pure_intrin(x.dtype, op_name, x)
def ceil(x):
    """Take the ceiling of a float input expression.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result, built as a pure intrinsic call that uses x's dtype.
    """
    op_name = "ceil"
    return call_pure_intrin(x.dtype, op_name, x)
def power(x, y):
"""x power y
......
......@@ -55,6 +55,12 @@ struct CUDAShuffle {
}
};
// Register global functions under the tvm.intrin.rule.cuda.* names.
// floor and ceil are bound to DispatchExtern<CUDAMath>, whereas exp is
// bound to DispatchExtern<CUDAFastMath>.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);
......
......@@ -9,6 +9,12 @@ namespace tvm {
namespace codegen {
namespace intrin {
// Register global functions under the tvm.intrin.rule.metal.* names.
// floor, ceil and exp all share the same body: DispatchExtern<Direct>.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
.set_body(DispatchExtern<Direct>);
......
......@@ -9,6 +9,12 @@ namespace tvm {
namespace codegen {
namespace intrin {
// Register global functions under the tvm.intrin.rule.opencl.* names.
// floor, ceil and exp all share the same body: DispatchExtern<Direct>.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
.set_body(DispatchExtern<Direct>);
......
......@@ -9,6 +9,12 @@ namespace tvm {
namespace codegen {
namespace intrin {
// Register global functions under the tvm.intrin.rule.opengl.* names.
// floor, ceil and exp all share the same body: DispatchExtern<Direct>.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.floor")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp")
.set_body(DispatchExtern<Direct>);
......
......@@ -25,6 +25,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
// Register global functions under the tvm.intrin.rule.llvm.* names.
// sqrt, floor and ceil each bind DispatchLLVMPureIntrin to the
// ::llvm::Intrinsic id of the same name. The second template argument is 1
// in all three cases -- presumably the operand count; confirm against the
// DispatchLLVMPureIntrin definition.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
......
......@@ -26,6 +26,12 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
namespace llvm {
// Register global functions under the tvm.intrin.rule.rocm.* names.
// floor and ceil bind DispatchLLVMPureIntrin to the matching
// ::llvm::Intrinsic id, whereas exp is bound to DispatchExternOCML.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML);
......
......@@ -3,8 +3,6 @@
* \file build_vulkan.cc
* \brief Build SPIRV block
*/
#if TVM_VULKAN_RUNTIME
// Use libspirv for parsing and validating code.
#include <vulkan/libspirv.h>
#include <dmlc/memory_io.h>
......@@ -92,4 +90,3 @@ TVM_REGISTER_API("codegen.build_vulkan")
} // namespace codegen
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
......@@ -3,9 +3,6 @@
* \file codegen_spirv.cc
* \brief Generate SPIRV block
*/
#if TVM_VULKAN_RUNTIME
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "../codegen_common.h"
......@@ -634,5 +631,3 @@ void CodeGenSPIRV::VisitStmt_(const ProducerConsumer* op) {
} // namespace codegen
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
......@@ -2,8 +2,6 @@
* Copyright (c) 2017 by Contributors
* \file intrin_rule_spirv.cc
*/
#if TVM_VULKAN_RUNTIME
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
#include <vulkan/GLSL.std.450.h>
......@@ -31,6 +29,12 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
call->type, "spirv_glsl450", cargs, ir::Call::PureIntrinsic);
}
// Register global functions under the tvm.intrin.rule.vulkan.* names.
// Each one binds DispatchGLSLPureIntrin to the GLSLstd450 id of the same
// operation (Floor, Ceil, Exp).
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Ceil>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
......@@ -43,8 +47,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt")
// pow and tanh are registered the same way, binding DispatchGLSLPureIntrin
// to GLSLstd450Pow and GLSLstd450Tanh respectively.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);
} // namespace spirv
} // namespace codegen
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
......@@ -3,9 +3,6 @@
* \file ir_builder.cc
* \brief IRBuilder for SPIRV block
*/
#if TVM_VULKAN_RUNTIME
#include "./ir_builder.h"
namespace tvm {
......@@ -555,5 +552,3 @@ Value IRBuilder::Select(Value cond, Value a, Value b) {
} // namespace spirv
} // namespace codegen
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
......@@ -29,6 +29,8 @@ TOPI_DECLARE_UNARY_OP(tanh);
// One TOPI_DECLARE_UNARY_OP invocation per unary operation name.
// The macro itself is defined outside this excerpt; presumably each
// invocation declares an element-wise tensor operation of that name
// (confirm against the macro definition).
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log);
TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil);
/*!
* \brief Creates an operation that returns identity of a given tensor
......
......@@ -74,6 +74,40 @@ def tanh(x):
@tvm.tag_scope(tag=tag.ELEMWISE)
def floor(x):
    """Apply floor to every element of the input tensor.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result, computed over the same shape as x.
    """
    out_shape = x.shape
    # Every output index is forwarded unchanged to the input tensor.
    return tvm.compute(out_shape, lambda *i: tvm.floor(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def ceil(x):
    """Apply ceil to every element of the input tensor.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result, computed over the same shape as x.
    """
    out_shape = x.shape
    # Every output index is forwarded unchanged to the input tensor.
    return tvm.compute(out_shape, lambda *i: tvm.ceil(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def log(x):
"""Take logarithm of input x.
......
......@@ -18,12 +18,11 @@ def test_ewise():
shape = (20, 3)
def test_apply(func, name, f_numpy):
def test_apply(func, name, f_numpy, low, high):
B = func(A)
assert tuple(B.shape) == tuple(A.shape)
assert B.op.body[0].name == name
a_np = np.random.uniform(low=1e-5, size=shape).astype(A.dtype)
a_np = np.abs(a_np)
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
b_np = f_numpy(a_np)
def check_device(device):
......@@ -43,11 +42,14 @@ def test_ewise():
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm']:
check_device(device)
test_apply(topi.exp, "exp", np.exp)
test_apply(topi.tanh, "tanh", np.tanh)
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)))
test_apply(topi.log, "log", np.log)
test_apply(topi.sqrt, "sqrt", np.sqrt)
test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.exp, "exp", np.exp, -1, 1)
test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
test_apply(topi.log, "log", np.log, 0, 100)
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
if __name__ == "__main__":
test_util()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment