Unverified Commit 45ee7b5f by notoraptor Committed by GitHub

[topi][relay] new PR to re-add tan to TVM (#5025)

* Add relay operation relay.op.tan.

* Update tan implementation in TVM.

* Update tests.

* Add shape function for tan.

* Add missing main test to python/frontend/tensorflow/test_forward.

* Revert, back to sin/cos.

* Revert "Revert, back to sin/cos."

This reverts commit 4da5b503b921585ba9d80944b29136142b575c40.

* Fix implementation of tan in cuda. Do not support tan for float16.

Simplify topi/tests/python/test_topi_math. Add testing for tan with float32 and float64.

Finally implement tan as sin/cos in llvm.
parent 6026af50
...@@ -135,6 +135,7 @@ Supported Ops ...@@ -135,6 +135,7 @@ Supported Ops
- ConcatV2 - ConcatV2
- Conv2D - Conv2D
- Cos - Cos
- Tan
- CropAndResize - CropAndResize
- DecodeJpeg - DecodeJpeg
- DepthwiseConv2dNative - DepthwiseConv2dNative
......
...@@ -515,6 +515,7 @@ TVM_DECLARE_INTRIN_UNARY(sqrt); ...@@ -515,6 +515,7 @@ TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt); TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log); TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(popcount); TVM_DECLARE_INTRIN_UNARY(popcount);
TVM_DECLARE_INTRIN_UNARY(tan);
TVM_DECLARE_INTRIN_UNARY(cos); TVM_DECLARE_INTRIN_UNARY(cos);
TVM_DECLARE_INTRIN_UNARY(sin); TVM_DECLARE_INTRIN_UNARY(sin);
TVM_DECLARE_INTRIN_UNARY(atan); TVM_DECLARE_INTRIN_UNARY(atan);
......
...@@ -1696,6 +1696,7 @@ _identity_list = [ ...@@ -1696,6 +1696,7 @@ _identity_list = [
"ones_like", "ones_like",
"where", "where",
"gather_nd", "gather_nd",
"tan",
"cos", "cos",
"sin" "sin"
] ]
......
...@@ -1572,6 +1572,7 @@ _convert_map = { ...@@ -1572,6 +1572,7 @@ _convert_map = {
'LessEqual' : _broadcast('less_equal'), 'LessEqual' : _broadcast('less_equal'),
'Log' : AttrCvt('log'), 'Log' : AttrCvt('log'),
'Log1p' : _log1p(), 'Log1p' : _log1p(),
'Tan' : AttrCvt('tan'),
'Cos' : AttrCvt('cos'), 'Cos' : AttrCvt('cos'),
'Sin' : AttrCvt('sin'), 'Sin' : AttrCvt('sin'),
'LogicalAnd' : _logical('logical_and'), 'LogicalAnd' : _logical('logical_and'),
......
...@@ -68,6 +68,7 @@ class OperatorConverter(object): ...@@ -68,6 +68,7 @@ class OperatorConverter(object):
'LOG': self.convert_log, 'LOG': self.convert_log,
'SIN': self.convert_sin, 'SIN': self.convert_sin,
'COS': self.convert_cos, 'COS': self.convert_cos,
'TAN': self.convert_tan,
'SQRT': self.convert_sqrt, 'SQRT': self.convert_sqrt,
'RSQRT': self.convert_rsqrt, 'RSQRT': self.convert_rsqrt,
'NEG': self.convert_neg, 'NEG': self.convert_neg,
...@@ -657,6 +658,13 @@ class OperatorConverter(object): ...@@ -657,6 +658,13 @@ class OperatorConverter(object):
'TFlite quantized SIN operator is not supported yet.') 'TFlite quantized SIN operator is not supported yet.')
return self._convert_unary_elemwise(_op.sin, op) return self._convert_unary_elemwise(_op.sin, op)
def convert_tan(self, op):
"""Convert TFLite TAN"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized TAN operator is not supported yet.')
return self._convert_unary_elemwise(_op.tan, op)
def convert_cos(self, op): def convert_cos(self, op):
"""Convert TFLite COS""" """Convert TFLite COS"""
if self.is_quantized(op): if self.is_quantized(op):
......
...@@ -27,6 +27,7 @@ from ...hybrid import script ...@@ -27,6 +27,7 @@ from ...hybrid import script
register_broadcast_schedule("log") register_broadcast_schedule("log")
register_broadcast_schedule("tan")
register_broadcast_schedule("cos") register_broadcast_schedule("cos")
register_broadcast_schedule("sin") register_broadcast_schedule("sin")
register_broadcast_schedule("atan") register_broadcast_schedule("atan")
...@@ -214,3 +215,4 @@ register_shape_func("minimum", False, broadcast_shape_func) ...@@ -214,3 +215,4 @@ register_shape_func("minimum", False, broadcast_shape_func)
register_shape_func("sqrt", False, elemwise_shape_func) register_shape_func("sqrt", False, elemwise_shape_func)
register_shape_func("negative", False, elemwise_shape_func) register_shape_func("negative", False, elemwise_shape_func)
register_shape_func("exp", False, elemwise_shape_func) register_shape_func("exp", False, elemwise_shape_func)
register_shape_func("tan", False, elemwise_shape_func)
...@@ -61,6 +61,13 @@ def log_grad(orig, grad): ...@@ -61,6 +61,13 @@ def log_grad(orig, grad):
return [grad * ones_like(x) / x] return [grad * ones_like(x) / x]
@register_gradient("tan")
def tan_grad(orig, grad):
"""Returns [grad / (cos^2(x))]"""
x = orig.args[0]
return [grad / (cos(x) * cos(x))]
@register_gradient("cos") @register_gradient("cos")
def cos_grad(orig, grad): def cos_grad(orig, grad):
"""Returns [grad * (-sin(x))]""" """Returns [grad * (-sin(x))]"""
......
...@@ -47,6 +47,21 @@ def log(data): ...@@ -47,6 +47,21 @@ def log(data):
""" """
return _make.log(data) return _make.log(data)
def tan(data):
"""Compute elementwise tan of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.tan(data)
def cos(data): def cos(data):
"""Compute elementwise cos of data. """Compute elementwise cos of data.
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
""" """
# expose all operators in tvm tir.op # expose all operators in tvm tir.op
from tvm.tir import any, all, min_value, max_value, trace from tvm.tir import any, all, min_value, max_value, trace
from tvm.tir import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from tvm.tir import comm_reducer, min, max, sum from tvm.tir import comm_reducer, min, max, sum
......
...@@ -33,7 +33,7 @@ from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_li ...@@ -33,7 +33,7 @@ from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_li
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil from .op import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum from .op import comm_reducer, min, max, sum
......
...@@ -393,6 +393,22 @@ def log(x): ...@@ -393,6 +393,22 @@ def log(x):
""" """
return call_pure_intrin(x.dtype, "log", x) return call_pure_intrin(x.dtype, "log", x)
def tan(x):
"""Take tan of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "tan", x)
def cos(x): def cos(x):
"""Take cos of input x. """Take cos of input x.
......
...@@ -51,6 +51,17 @@ RELAY_REGISTER_UNARY_OP("log") ...@@ -51,6 +51,17 @@ RELAY_REGISTER_UNARY_OP("log")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));
RELAY_REGISTER_UNARY_OP("tan")
.describe(R"code(Returns the tan of input array, computed element-wise.
.. math::
Y = tan(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan));
RELAY_REGISTER_UNARY_OP("cos") RELAY_REGISTER_UNARY_OP("cos")
.describe(R"code(Returns the cos of input array, computed element-wise. .describe(R"code(Returns the cos of input array, computed element-wise.
......
...@@ -40,6 +40,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log") ...@@ -40,6 +40,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
......
...@@ -91,6 +91,20 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow") ...@@ -91,6 +91,20 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr sin_x = tir::CallNode::make(
x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic);
PrimExpr cos_x = tir::CallNode::make(
x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic);
PrimExpr tan_x = sin_x / cos_x;
*rv = tan_x;
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
......
...@@ -81,6 +81,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow") ...@@ -81,6 +81,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh")
.set_body(DispatchExternLibDevice); .set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos")
.set_body(DispatchExternLibDevice); .set_body(DispatchExternLibDevice);
......
...@@ -80,6 +80,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow") ...@@ -80,6 +80,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh")
.set_body(DispatchExternOCML); .set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos")
.set_body(DispatchExternOCML); .set_body(DispatchExternOCML);
......
...@@ -54,6 +54,22 @@ struct CUDAFastMath : public CUDAMath { ...@@ -54,6 +54,22 @@ struct CUDAFastMath : public CUDAMath {
} }
}; };
struct CUDAFastMathTan : public CUDAMath {
std::string operator()(DataType t, std::string name) const {
if (t.lanes() == 1 && t.is_float()) {
switch (t.bits()) {
case 64: return name;
// `__tanf` seems to produce some values too deviant from numpy tan version.
// So, let's use just `tanf` instead.
case 32: return name + 'f';
case 16: LOG(FATAL) << "cuda tan unsupported for float16";
default: return "";
}
}
return "";
}
};
struct CUDAPopcount { struct CUDAPopcount {
std::string operator()(DataType t, std::string name) const { std::string operator()(DataType t, std::string name) const {
if (t.lanes() == 1 && t.is_uint()) { if (t.lanes() == 1 && t.is_uint()) {
...@@ -97,6 +113,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf") ...@@ -97,6 +113,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.set_body(DispatchExtern<CUDAFastMath>); .set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan")
.set_body(DispatchExtern<CUDAFastMathTan>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
.set_body(DispatchExtern<CUDAFastMath>); .set_body(DispatchExtern<CUDAFastMath>);
......
...@@ -229,7 +229,7 @@ PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) { ...@@ -229,7 +229,7 @@ PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) {
const char* CallNode::vectorizable_intrinsics[] = { const char* CallNode::vectorizable_intrinsics[] = {
"floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt", "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt",
"log", "sin", "cos", "pow", tir::CallNode::shift_left, tir::CallNode::shift_right, "log", "sin", "cos", "pow", "tan", tir::CallNode::shift_left, tir::CallNode::shift_right,
tir::CallNode::likely, tir::CallNode::popcount tir::CallNode::likely, tir::CallNode::popcount
}; };
......
...@@ -2624,6 +2624,15 @@ def test_forward_cos(): ...@@ -2624,6 +2624,15 @@ def test_forward_cos():
compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0')
def test_forward_tan():
"""test operator tan """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
tf.tan(in_data, name="tan")
compare_tf_with_tvm([np_data], ['in_data:0'], 'tan:0')
def test_forward_sin(): def test_forward_sin():
"""test operator sin """ """test operator sin """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
...@@ -3027,6 +3036,7 @@ if __name__ == '__main__': ...@@ -3027,6 +3036,7 @@ if __name__ == '__main__':
test_forward_sign() test_forward_sign()
test_forward_log() test_forward_log()
test_forward_log1p() test_forward_log1p()
test_forward_tan()
test_forward_cos() test_forward_cos()
test_forward_sin() test_forward_sin()
test_forward_negative() test_forward_negative()
......
...@@ -723,6 +723,13 @@ def _test_cos(data): ...@@ -723,6 +723,13 @@ def _test_cos(data):
""" One iteration of cos """ """ One iteration of cos """
return _test_unary_elemwise(math_ops.cos, data) return _test_unary_elemwise(math_ops.cos, data)
####################################################################### #######################################################################
# Tan
# ---
def _test_tan(data):
""" One iteration of tan """
return _test_unary_elemwise(math_ops.tan, data)
#######################################################################
# Sqrt # Sqrt
# ---- # ----
...@@ -772,6 +779,7 @@ def test_all_unary_elemwise(): ...@@ -772,6 +779,7 @@ def test_all_unary_elemwise():
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
_test_forward_unary_elemwise(_test_ceil) _test_forward_unary_elemwise(_test_ceil)
_test_forward_unary_elemwise(_test_cos) _test_forward_unary_elemwise(_test_cos)
_test_forward_unary_elemwise(_test_tan)
####################################################################### #######################################################################
# Element-wise # Element-wise
......
...@@ -64,6 +64,7 @@ def test_unary_op(): ...@@ -64,6 +64,7 @@ def test_unary_op():
(relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))), (relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
(tvm.relay.cos, lambda x: -1.0 * np.sin(x)), (tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
(tvm.relay.sin, lambda x: np.cos(x)), (tvm.relay.sin, lambda x: np.cos(x)),
(tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
(tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0)))]: (tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0)))]:
check_single_op(opfunc, ref) check_single_op(opfunc, ref)
......
...@@ -76,6 +76,7 @@ def test_unary_op(): ...@@ -76,6 +76,7 @@ def test_unary_op():
(relay.nn.relu, relu), (relay.nn.relu, relu),
(tvm.relay.cos, np.cos), (tvm.relay.cos, np.cos),
(tvm.relay.sin, np.sin), (tvm.relay.sin, np.sin),
(tvm.relay.tan, np.tan),
(tvm.relay.atan, np.arctan)]: (tvm.relay.atan, np.arctan)]:
for dtype in ['float16', 'float32']: for dtype in ['float16', 'float32']:
check_single_op(opfunc, ref, dtype) check_single_op(opfunc, ref, dtype)
......
...@@ -31,6 +31,7 @@ def test_check_numerical_grads(): ...@@ -31,6 +31,7 @@ def test_check_numerical_grads():
lambda x: (np.sign(np.sin(1/x)), np.zeros_like(x)), lambda x: (np.sign(np.sin(1/x)), np.zeros_like(x)),
lambda x: (x*np.sin(1/x), np.sin(1/x) - np.cos(1/x)/x), lambda x: (x*np.sin(1/x), np.sin(1/x) - np.cos(1/x)/x),
lambda x: (np.sin(1/x), - np.cos(1/x)/(x*x)), lambda x: (np.sin(1/x), - np.cos(1/x)/(x*x)),
lambda x: (np.tan(x), 1.0 / (np.cos(x) * np.cos(x))),
] ]
# Avoid values too close to 0 since singularities of our functions are there # Avoid values too close to 0 since singularities of our functions are there
......
...@@ -55,6 +55,7 @@ TOPI_DECLARE_UNARY_OP(round); ...@@ -55,6 +55,7 @@ TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc); TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs); TOPI_DECLARE_UNARY_OP(abs);
TOPI_DECLARE_UNARY_OP(cos); TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(tan);
TOPI_DECLARE_UNARY_OP(sin); TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan); TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan); TOPI_DECLARE_UNARY_OP(isnan);
......
...@@ -110,6 +110,23 @@ def tanh(x): ...@@ -110,6 +110,23 @@ def tanh(x):
@tvm.te.tag_scope(tag=tag.ELEMWISE) @tvm.te.tag_scope(tag=tag.ELEMWISE)
def tan(x):
"""Take tan of input x.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.tan(x(*i)))
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def cos(x): def cos(x):
"""Take cos of input x. """Take cos of input x.
......
...@@ -175,6 +175,11 @@ TVM_REGISTER_GLOBAL("topi.erf") ...@@ -175,6 +175,11 @@ TVM_REGISTER_GLOBAL("topi.erf")
*rv = erf(args[0]); *rv = erf(args[0]);
}); });
TVM_REGISTER_GLOBAL("topi.tan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tan(args[0]);
});
TVM_REGISTER_GLOBAL("topi.cos") TVM_REGISTER_GLOBAL("topi.cos")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cos(args[0]); *rv = cos(args[0]);
......
...@@ -45,6 +45,7 @@ def test_ewise(): ...@@ -45,6 +45,7 @@ def test_ewise():
test_apply(topi.rsqrt, "rsqrt") test_apply(topi.rsqrt, "rsqrt")
test_apply(topi.sin, "sin") test_apply(topi.sin, "sin")
test_apply(topi.cos, "cos") test_apply(topi.cos, "cos")
test_apply(topi.tan, "tan")
test_apply(topi.atan, "atan") test_apply(topi.atan, "atan")
......
...@@ -127,6 +127,8 @@ def test_ewise(): ...@@ -127,6 +127,8 @@ def test_ewise():
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100) test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True) test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)
test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi) test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float32')
test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float64')
test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi) test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32") test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
test_isnan(-100, 100) test_isnan(-100, 100)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment