Commit 29d5ffbb by ziheng Committed by Tianqi Chen

[INTRINSIC] Add sqrt (#202)

* [INTRINSIC] Add sqrt

* [INTRINSIC] Expose on cpp
parent b8e02348
......@@ -51,6 +51,7 @@ Expr min(Expr source, Array<IterVar> axis);
TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
} // namespace tvm
#endif // TVM_IR_OPERATOR_H_
......@@ -166,6 +166,22 @@ def log(x):
return call_pure_intrin(x.dtype, "log", x)
def sqrt(x):
"""Take log of input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "sqrt", x)
# Intrinsic rule related code
def register_intrin_rule(target, intrin, f=None, override=False):
"""Register an intrinsic function generation rule.
......
......@@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
.set_body(DispatchExtern<FloatSuffix>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
......@@ -45,6 +45,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt")
.set_body(DispatchExtern<CUDAMath>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
......@@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh")
.set_body(DispatchExtern<FloatDirect>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt")
.set_body(DispatchExtern<FloatDirect>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
......@@ -37,6 +37,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::log>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::sqrt>);
} // namespace llvm
} // namespace codegen
} // namespace tvm
......
......@@ -14,6 +14,8 @@ def test_ewise():
test_apply(topi.exp, "exp")
test_apply(topi.tanh, "tanh")
test_apply(topi.sigmoid, "sigmoid")
test_apply(topi.log, "log")
test_apply(topi.sqrt, "sqrt")
if __name__ == "__main__":
......
......@@ -22,5 +22,6 @@ using namespace tvm;
TOPI_DECLARE_UNARY_OP(exp);
TOPI_DECLARE_UNARY_OP(tanh);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
} // namespace topi
#endif // TOPI_EWISE_H_
......@@ -34,6 +34,38 @@ def tanh(x):
return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i)))
def log(x):
"""Take logarithm of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.log(x(*i)))
def sqrt(x):
"""Take square root of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.sqrt(x(*i)))
def sigmoid(x):
"""Take sigmoid tanh of input x.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment