Commit 29d5ffbb by ziheng Committed by Tianqi Chen

[INTRINSIC] Add sqrt (#202)

* [INTRINSIC] Add sqrt

* [INTRINSIC] Expose on cpp
parent b8e02348
...@@ -51,6 +51,7 @@ Expr min(Expr source, Array<IterVar> axis); ...@@ -51,6 +51,7 @@ Expr min(Expr source, Array<IterVar> axis);
TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh); TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid); TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
} // namespace tvm } // namespace tvm
#endif // TVM_IR_OPERATOR_H_ #endif // TVM_IR_OPERATOR_H_
...@@ -166,6 +166,22 @@ def log(x): ...@@ -166,6 +166,22 @@ def log(x):
return call_pure_intrin(x.dtype, "log", x) return call_pure_intrin(x.dtype, "log", x)
def sqrt(x):
"""Take log of input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "sqrt", x)
# Intrinsic rule related code # Intrinsic rule related code
def register_intrin_rule(target, intrin, f=None, override=False): def register_intrin_rule(target, intrin, f=None, override=False):
"""Register an intrinsic function generation rule. """Register an intrinsic function generation rule.
......
...@@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log") ...@@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
.set_body(DispatchExtern<FloatSuffix>);
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -45,6 +45,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log") ...@@ -45,6 +45,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh")
.set_body(DispatchExtern<CUDAMath>); .set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt")
.set_body(DispatchExtern<CUDAMath>);
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log") ...@@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh")
.set_body(DispatchExtern<FloatDirect>); .set_body(DispatchExtern<FloatDirect>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt")
.set_body(DispatchExtern<FloatDirect>);
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -37,6 +37,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") ...@@ -37,6 +37,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::log>); .set_body(DispatchLLVMIntrin<::llvm::Intrinsic::log>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::sqrt>);
} // namespace llvm } // namespace llvm
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
...@@ -14,6 +14,8 @@ def test_ewise(): ...@@ -14,6 +14,8 @@ def test_ewise():
test_apply(topi.exp, "exp") test_apply(topi.exp, "exp")
test_apply(topi.tanh, "tanh") test_apply(topi.tanh, "tanh")
test_apply(topi.sigmoid, "sigmoid") test_apply(topi.sigmoid, "sigmoid")
test_apply(topi.log, "log")
test_apply(topi.sqrt, "sqrt")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -22,5 +22,6 @@ using namespace tvm; ...@@ -22,5 +22,6 @@ using namespace tvm;
TOPI_DECLARE_UNARY_OP(exp); TOPI_DECLARE_UNARY_OP(exp);
TOPI_DECLARE_UNARY_OP(tanh); TOPI_DECLARE_UNARY_OP(tanh);
TOPI_DECLARE_UNARY_OP(sigmoid); TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
} // namespace topi } // namespace topi
#endif // TOPI_EWISE_H_ #endif // TOPI_EWISE_H_
...@@ -34,6 +34,38 @@ def tanh(x): ...@@ -34,6 +34,38 @@ def tanh(x):
return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i))) return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i)))
def log(x):
"""Take logarithm of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.log(x(*i)))
def sqrt(x):
"""Take square root of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.sqrt(x(*i)))
def sigmoid(x): def sigmoid(x):
"""Take sigmoid tanh of input x. """Take sigmoid tanh of input x.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment