Commit 0a6c36ce by Tianqi Chen Committed by GitHub

[INTRIN] Enable pow (#471)

* [INTRIN] Enable pow

* rename pow->power

* fix
parent 87b95e2e
...@@ -206,6 +206,25 @@ def sqrt(x): ...@@ -206,6 +206,25 @@ def sqrt(x):
return call_pure_intrin(x.dtype, "sqrt", x) return call_pure_intrin(x.dtype, "sqrt", x)
def power(x, y):
    """Build a pure intrinsic call expression for x raised to the power y.

    Parameters
    ----------
    x : Expr
        The base. Its ``dtype`` is forwarded as the first argument of
        the intrinsic call.

    y : Expr
        The exponent.

    Returns
    -------
    z : Expr
        The expression produced by ``call_pure_intrin`` for the
        ``"pow"`` intrinsic.
    """
    base_dtype = x.dtype
    return call_pure_intrin(base_dtype, "pow", x, y)
# Intrinsic rule related code # Intrinsic rule related code
def register_intrin_rule(target, intrin, f=None, override=False): def register_intrin_rule(target, intrin, f=None, override=False):
"""Register an intrinsic function generation rule. """Register an intrinsic function generation rule.
......
...@@ -21,6 +21,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh") ...@@ -21,6 +21,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
// Register a global function under the key "tvm.intrin.rule.default.pow",
// using DispatchExtern<FloatSuffix> as its body (same dispatcher as the
// sqrt registration directly above).
// NOTE(review): presumably this is the default-target rule for the "pow"
// intrinsic emitted by the Python `power` helper -- confirm.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
.set_body(DispatchExtern<FloatSuffix>);
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -48,6 +48,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh") ...@@ -48,6 +48,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt")
.set_body(DispatchExtern<CUDAMath>); .set_body(DispatchExtern<CUDAMath>);
// Register a global function under the key "tvm.intrin.rule.cuda.pow",
// using DispatchExtern<CUDAMath> as its body (same dispatcher as the
// cuda sqrt registration directly above).
// NOTE(review): presumably this is the CUDA rule for the "pow" intrinsic --
// confirm how CUDAMath chooses the extern function name.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
.set_body(DispatchExtern<CUDAMath>);
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -21,6 +21,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh") ...@@ -21,6 +21,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt")
.set_body(DispatchExtern<FloatDirect>); .set_body(DispatchExtern<FloatDirect>);
// Register a global function under the key "tvm.intrin.rule.opencl.pow",
// using DispatchExtern<FloatDirect> as its body (same dispatcher as the
// opencl sqrt registration directly above).
// NOTE(review): presumably this is the OpenCL rule for the "pow" intrinsic --
// confirm.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
.set_body(DispatchExtern<FloatDirect>);
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -76,6 +76,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") ...@@ -76,6 +76,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt>);
// Register a global function under the key "tvm.intrin.rule.llvm.pow",
// using DispatchLLVMPureIntrin instantiated with ::llvm::Intrinsic::pow as
// its body (same pattern as the llvm sqrt registration directly above).
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow>);
} // namespace llvm } // namespace llvm
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
...@@ -36,11 +36,11 @@ def test_exp(): ...@@ -36,11 +36,11 @@ def test_exp():
check_device("opencl") check_device("opencl")
def test_log_llvm(): def test_log_pow_llvm():
# graph # graph
n = tvm.var('n') n = tvm.var('n')
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: tvm.log(A(*i)), name='B') B = tvm.compute(A.shape, lambda *i: tvm.power(tvm.log(A(*i)), 2.0), name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
# create iter var and assign them tags. # create iter var and assign them tags.
bx, tx = s[B].split(B.op.axis[0], factor=32) bx, tx = s[B].split(B.op.axis[0], factor=32)
...@@ -57,7 +57,7 @@ def test_log_llvm(): ...@@ -57,7 +57,7 @@ def test_log_llvm():
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx) b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
flog(a, b) flog(a, b)
np.testing.assert_allclose( np.testing.assert_allclose(
b.asnumpy(), np.log(a.asnumpy()), rtol=1e-5) b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5)
def test_add(): def test_add():
...@@ -106,6 +106,6 @@ def test_add(): ...@@ -106,6 +106,6 @@ def test_add():
if __name__ == "__main__": if __name__ == "__main__":
test_log_llvm() test_log_pow_llvm()
test_exp() test_exp()
test_add() test_add()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment