Commit 16d4da4d by Huang, Guangtai Committed by Tianqi Chen

Add operator `isnan` (#3979)

* add expr `isnan`

* move to intrinsic

* doc & add to topi

* fix error from ci
parent 88cd1b1c
...@@ -36,6 +36,7 @@ tvm.intrin ...@@ -36,6 +36,7 @@ tvm.intrin
tvm.trunc tvm.trunc
tvm.round tvm.round
tvm.abs tvm.abs
tvm.isnan
.. autofunction:: tvm.call_packed .. autofunction:: tvm.call_packed
.. autofunction:: tvm.call_pure_intrin .. autofunction:: tvm.call_pure_intrin
...@@ -52,3 +53,4 @@ tvm.intrin ...@@ -52,3 +53,4 @@ tvm.intrin
.. autofunction:: tvm.trunc .. autofunction:: tvm.trunc
.. autofunction:: tvm.round .. autofunction:: tvm.round
.. autofunction:: tvm.abs .. autofunction:: tvm.abs
.. autofunction:: tvm.isnan
...@@ -32,6 +32,7 @@ List of operators ...@@ -32,6 +32,7 @@ List of operators
topi.trunc topi.trunc
topi.round topi.round
topi.abs topi.abs
topi.isnan
topi.exp topi.exp
topi.tanh topi.tanh
topi.log topi.log
...@@ -127,6 +128,7 @@ topi ...@@ -127,6 +128,7 @@ topi
.. autofunction:: topi.trunc .. autofunction:: topi.trunc
.. autofunction:: topi.round .. autofunction:: topi.round
.. autofunction:: topi.abs .. autofunction:: topi.abs
.. autofunction:: topi.isnan
.. autofunction:: topi.exp .. autofunction:: topi.exp
.. autofunction:: topi.tanh .. autofunction:: topi.tanh
.. autofunction:: topi.log .. autofunction:: topi.log
......
...@@ -465,6 +465,12 @@ TVM_DLL Expr pow(Expr x, Expr y); ...@@ -465,6 +465,12 @@ TVM_DLL Expr pow(Expr x, Expr y);
* \return The aboslute value of input data x * \return The aboslute value of input data x
*/ */
TVM_DLL Expr abs(Expr x); TVM_DLL Expr abs(Expr x);
/*!
* \brief Check if x is NaN.
* \param x The input data
* \return The result expression.
*/
TVM_DLL Expr isnan(Expr x);
/*! /*!
* \brief sum of of source expression over axis * \brief sum of of source expression over axis
......
...@@ -574,6 +574,7 @@ class Call : public ExprNode { ...@@ -574,6 +574,7 @@ class Call : public ExprNode {
static constexpr const char* likely = "likely"; static constexpr const char* likely = "likely";
static constexpr const char* glsl_texture_store = "glsl_texture_store"; static constexpr const char* glsl_texture_store = "glsl_texture_store";
static constexpr const char* prefetch = "prefetch"; static constexpr const char* prefetch = "prefetch";
static constexpr const char* isnan = "isnan";
/*! \brief Vectorizable intrinsic list. */ /*! \brief Vectorizable intrinsic list. */
static const char* vectorizable_intrinsics[]; static const char* vectorizable_intrinsics[];
......
...@@ -434,6 +434,22 @@ def round(x): ...@@ -434,6 +434,22 @@ def round(x):
return _make.round(x) return _make.round(x)
def isnan(x):
"""Check if input value is Nan.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return _make.isnan(x)
def power(x, y): def power(x, y):
"""x power y """x power y
......
...@@ -38,6 +38,9 @@ TVM_REGISTER_API("_Var") ...@@ -38,6 +38,9 @@ TVM_REGISTER_API("_Var")
TVM_REGISTER_API("make.abs") TVM_REGISTER_API("make.abs")
.set_body_typed(tvm::abs); .set_body_typed(tvm::abs);
TVM_REGISTER_API("make.isnan")
.set_body_typed(tvm::isnan);
TVM_REGISTER_API("make.floor") TVM_REGISTER_API("make.floor")
.set_body_typed(tvm::floor); .set_body_typed(tvm::floor);
......
...@@ -576,6 +576,12 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) ...@@ -576,6 +576,12 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
os << " *)(&("; os << " *)(&(";
this->PrintExpr(op->args[0], os); this->PrintExpr(op->args[0], os);
os << ")))"; os << ")))";
} else if (op->is_intrinsic(Call::isnan)) {
os << "(";
this->PrintExpr(op->args[0], os);
os << " != ";
this->PrintExpr(op->args[0], os);
os << ")";
} else { } else {
if (op->call_type == Call::Intrinsic || if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) { op->call_type == Call::PureIntrinsic) {
......
...@@ -746,6 +746,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { ...@@ -746,6 +746,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
} else if (op->is_intrinsic(Call::reinterpret)) { } else if (op->is_intrinsic(Call::reinterpret)) {
llvm::Type * target = LLVMType(op->type); llvm::Type * target = LLVMType(op->type);
return builder_->CreateBitCast(MakeValue(op->args[0]), target); return builder_->CreateBitCast(MakeValue(op->args[0]), target);
} else if (op->is_intrinsic(Call::isnan)) {
// TODO(hgt312): set fast math flag
llvm::Value* a = MakeValue(op->args[0]);
return builder_->CreateFCmpUNO(a, a);
} else if (op->is_intrinsic("vectorlow")) { } else if (op->is_intrinsic("vectorlow")) {
llvm::Value *v = MakeValue(op->args[0]); llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements(); int l = v->getType()->getVectorNumElements();
......
...@@ -424,6 +424,30 @@ Expr abs(Expr x) { ...@@ -424,6 +424,30 @@ Expr abs(Expr x) {
} }
} }
Expr isnan(Expr x) {
Type t = Bool(x.type().lanes());
if (x.type().is_int() || x.type().is_uint()) {
return make_const(t, false);
} else if (x.type().is_float()) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) {
return make_const(t, std::isnan(fx->value));
}
if (x.type().bits() == 16) {
return ir::Call::make(t, ir::Call::isnan,
{cast(Float(32, t.lanes()), std::move(x))},
ir::Call::PureIntrinsic);
} else {
return ir::Call::make(t, ir::Call::isnan, {x}, ir::Call::PureIntrinsic);
}
} else {
LOG(FATAL) << "Data type " << x.type()
<<" not supported for isnan op. Skipping isnan op...";
return x;
}
}
Expr sum(Expr source, Array<IterVar> rdom) { Expr sum(Expr source, Array<IterVar> rdom) {
Var x("x", source.type()), y("y", source.type()); Var x("x", source.type()), y("y", source.type());
Expr result = ir::Add::make(x, y); Expr result = ir::Add::make(x, y);
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import tvm import tvm
import numpy as np import numpy as np
def test_const(): def test_const():
x = tvm.const(1, "int32") x = tvm.const(1, "int32")
print(x.dtype) print(x.dtype)
...@@ -39,6 +40,7 @@ def test_scalar_dtype_inference(): ...@@ -39,6 +40,7 @@ def test_scalar_dtype_inference():
assert tvm.convert(1).dtype == 'int32' assert tvm.convert(1).dtype == 'int32'
assert tvm.convert(1.0).dtype == 'float32' assert tvm.convert(1.0).dtype == 'float32'
def test_make(): def test_make():
x = tvm.const(1, "int32") x = tvm.const(1, "int32")
y = tvm.var("x") y = tvm.var("x")
...@@ -46,6 +48,7 @@ def test_make(): ...@@ -46,6 +48,7 @@ def test_make():
assert isinstance(tvm.max(x, y), tvm.expr.Max) assert isinstance(tvm.max(x, y), tvm.expr.Max)
assert isinstance(tvm.min(x, y), tvm.expr.Min) assert isinstance(tvm.min(x, y), tvm.expr.Min)
def test_ir(): def test_ir():
x = tvm.const(1, "int32") x = tvm.const(1, "int32")
y = tvm.make.IntImm('int32', 1) y = tvm.make.IntImm('int32', 1)
...@@ -53,6 +56,7 @@ def test_ir(): ...@@ -53,6 +56,7 @@ def test_ir():
stmt = tvm.make.Evaluate(z) stmt = tvm.make.Evaluate(z)
assert isinstance(stmt, tvm.stmt.Evaluate) assert isinstance(stmt, tvm.stmt.Evaluate)
def test_ir2(): def test_ir2():
x = tvm.var("n") x = tvm.var("n")
a = tvm.var("array", tvm.handle) a = tvm.var("array", tvm.handle)
...@@ -60,12 +64,14 @@ def test_ir2(): ...@@ -60,12 +64,14 @@ def test_ir2():
assert isinstance(st, tvm.stmt.Store) assert isinstance(st, tvm.stmt.Store)
assert(st.buffer_var == a) assert(st.buffer_var == a)
def test_let(): def test_let():
x = tvm.var('x') x = tvm.var('x')
y = tvm.var('y') y = tvm.var('y')
stmt = tvm.make.LetStmt( stmt = tvm.make.LetStmt(
x, 10, tvm.make.Evaluate(x + 1)); x, 10, tvm.make.Evaluate(x + 1));
def test_cast(): def test_cast():
x = tvm.var('x', dtype="float32") x = tvm.var('x', dtype="float32")
y = x.astype("int32") y = x.astype("int32")
...@@ -104,10 +110,12 @@ def test_stmt(): ...@@ -104,10 +110,12 @@ def test_stmt():
tvm.stmt.For.Serial, 0, tvm.stmt.For.Serial, 0,
x) x)
def test_dir(): def test_dir():
x = tvm.var('x') x = tvm.var('x')
dir(x) dir(x)
def test_dtype(): def test_dtype():
x = tvm.var('x') x = tvm.var('x')
assert x.dtype == 'int32' assert x.dtype == 'int32'
...@@ -158,6 +166,7 @@ def test_all(): ...@@ -158,6 +166,7 @@ def test_all():
'(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % ( '(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % (
x.name, y.name, y.name, z.name, x.name, z.name) x.name, y.name, y.name, z.name, x.name, z.name)
def test_bitwise(): def test_bitwise():
x = tvm.var('x') x = tvm.var('x')
y = tvm.var('y') y = tvm.var('y')
...@@ -172,6 +181,18 @@ def test_bitwise(): ...@@ -172,6 +181,18 @@ def test_bitwise():
assert(tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2" assert(tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2"
def test_isnan():
x = tvm.var('x', 'float32')
assert str(tvm.isnan(x)) == 'isnan(x)'
assert str(tvm.isnan(x).dtype) == 'bool'
y = tvm.var('y', 'float16')
assert str(tvm.isnan(y)) == 'isnan(float32(y))'
z = tvm.var('z', 'int32')
assert str(tvm.isnan(z)) == '(bool)0'
k = tvm.var('k', 'int8x2')
assert str(tvm.isnan(k).dtype) == 'uint1x2'
def test_equality(): def test_equality():
a = tvm.var('a') a = tvm.var('a')
b = tvm.var('b') b = tvm.var('b')
...@@ -203,5 +224,6 @@ if __name__ == "__main__": ...@@ -203,5 +224,6 @@ if __name__ == "__main__":
test_any() test_any()
test_all() test_all()
test_bitwise() test_bitwise()
test_isnan()
test_equality() test_equality()
test_equality_string_imm() test_equality_string_imm()
...@@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(abs); ...@@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(abs);
TOPI_DECLARE_UNARY_OP(cos); TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(sin); TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan); TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
/* /*
* \brief Fast_tanh_float implementation from Eigen * \brief Fast_tanh_float implementation from Eigen
......
...@@ -244,6 +244,23 @@ def abs(x): ...@@ -244,6 +244,23 @@ def abs(x):
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def isnan(x):
"""Check if value of x is NaN, element-wise.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.isnan(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def round(x): def round(x):
"""Round elements of x to nearest integer. """Round elements of x to nearest integer.
......
...@@ -72,6 +72,46 @@ def test_ewise(): ...@@ -72,6 +72,46 @@ def test_ewise():
for device in get_all_backend(): for device in get_all_backend():
check_device(device) check_device(device)
def test_isnan(
low,
high,
shape=(20, 3),
dtype=tvm.float32,
check_round=False,
skip_name_check=False,
):
m = tvm.var("m")
l = tvm.var("l")
A = tvm.placeholder((m, l), dtype=dtype, name="A")
B = topi.isnan(A)
assert tuple(B.shape) == tuple(A.shape)
if not skip_name_check:
assert B.op.body[0].name == "isnan"
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan
# avoid round check too close to boundary
if check_round:
a_np += ((np.fmod(a_np, 1) - 0.5) < 1e-6) * 1e-5
b_np = np.isnan(a_np)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
foo = tvm.build(s, [A, B], device, name="isnan")
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros_like(b_np), ctx)
foo(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
for device in get_all_backend():
check_device(device)
test_apply(topi.floor, "floor", np.floor, -100, 100) test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100) test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True) test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True)
...@@ -88,6 +128,7 @@ def test_ewise(): ...@@ -88,6 +128,7 @@ def test_ewise():
test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi) test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi) test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32") test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
test_isnan(-100, 100)
def test_cast(): def test_cast():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment