Commit 0a6c36ce by Tianqi Chen, committed by GitHub

[INTRIN] Enable pow (#471)

* [INTRIN] Enable pow

* rename pow->power

* fix
parent 87b95e2e
@@ -206,6 +206,25 @@ def sqrt(x):
     return call_pure_intrin(x.dtype, "sqrt", x)
 
 
+def power(x, y):
+    """x power y
+
+    Parameters
+    ----------
+    x : Expr
+        Input argument.
+
+    y : Expr
+        The exponent.
+
+    Returns
+    -------
+    z : Expr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "pow", x, y)
+
+
 # Intrinsic rule related code
 def register_intrin_rule(target, intrin, f=None, override=False):
     """Register an intrinsic function generation rule.
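For orientation, a minimal usage sketch of the new intrinsic, assuming the 0.x-era tvm Python API that also appears in the test change below (variable names are illustrative, not part of the commit):

import tvm

n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
# tvm.power(x, y) builds the pure "pow" intrinsic call; the per-target
# rules registered below decide how each backend lowers it.
B = tvm.compute(A.shape, lambda *i: tvm.power(A(*i), 2.0), name='B')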
@@ -21,6 +21,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
 .set_body(DispatchExtern<FloatSuffix>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
+.set_body(DispatchExtern<FloatSuffix>);
+
 } // namespace intrin
 } // namespace codegen
 } // namespace tvm
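FloatSuffix dispatches to the C math library by float width, so "pow" lowers to powf for float32 and pow for float64. A rule can also be registered or overridden from Python through the register_intrin_rule API shown above; a hedged sketch (my_pow_rule is an illustrative name):

def my_pow_rule(op):
    # op is the Call node of the "pow" intrinsic; redirect it to an
    # extern function of our choosing, here the single-precision powf.
    return tvm.call_pure_extern(op.dtype, "powf", op.args[0], op.args[1])

tvm.register_intrin_rule("opencl", "pow", my_pow_rule, override=True)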
@@ -48,6 +48,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt")
 .set_body(DispatchExtern<CUDAMath>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
+.set_body(DispatchExtern<CUDAMath>);
+
 } // namespace intrin
 } // namespace codegen
 } // namespace tvm
@@ -21,6 +21,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt")
 .set_body(DispatchExtern<FloatDirect>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
+.set_body(DispatchExtern<FloatDirect>);
+
 } // namespace intrin
 } // namespace codegen
 } // namespace tvm
@@ -76,6 +76,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
+.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow>);
+
 } // namespace llvm
 } // namespace codegen
 } // namespace tvm
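On the LLVM backend the rule maps the intrinsic directly onto ::llvm::Intrinsic::pow instead of an extern call. One way to confirm the rule fires, sketched against the same-era build API (fpow is an illustrative name, reusing A, B from the sketch above):

s = tvm.create_schedule(B.op)
fpow = tvm.build(s, [A, B], "llvm", name="fpow")
# The emitted module should reference the llvm.pow intrinsic
# rather than calling out to an external pow().
print(fpow.get_source())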
@@ -36,11 +36,11 @@ def test_exp():
     check_device("opencl")
 
 
-def test_log_llvm():
+def test_log_pow_llvm():
     # graph
     n = tvm.var('n')
     A = tvm.placeholder((n,), name='A')
-    B = tvm.compute(A.shape, lambda *i: tvm.log(A(*i)), name='B')
+    B = tvm.compute(A.shape, lambda *i: tvm.power(tvm.log(A(*i)), 2.0), name='B')
     s = tvm.create_schedule(B.op)
     # create iter var and assign them tags.
     bx, tx = s[B].split(B.op.axis[0], factor=32)
@@ -57,7 +57,7 @@ def test_log_llvm():
         b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
         flog(a, b)
         np.testing.assert_allclose(
-            b.asnumpy(), np.log(a.asnumpy()), rtol=1e-5)
+            b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5)
 
 
 def test_add():
@@ -106,6 +106,6 @@ def test_add():
 
 
 if __name__ == "__main__":
-    test_log_llvm()
+    test_log_pow_llvm()
     test_exp()
     test_add()