Commit 16d4da4d authored by Huang, Guangtai; committed by Tianqi Chen

Add operator `isnan` (#3979)

* add expr `isnan`

* move to intrinsic

* doc & add to topi

* fix error from ci
parent 88cd1b1c
@@ -36,6 +36,7 @@ tvm.intrin
tvm.trunc
tvm.round
tvm.abs
tvm.isnan
.. autofunction:: tvm.call_packed
.. autofunction:: tvm.call_pure_intrin
@@ -52,3 +53,4 @@ tvm.intrin
.. autofunction:: tvm.trunc
.. autofunction:: tvm.round
.. autofunction:: tvm.abs
.. autofunction:: tvm.isnan
@@ -32,6 +32,7 @@ List of operators
topi.trunc
topi.round
topi.abs
topi.isnan
topi.exp
topi.tanh
topi.log
@@ -127,6 +128,7 @@ topi
.. autofunction:: topi.trunc
.. autofunction:: topi.round
.. autofunction:: topi.abs
.. autofunction:: topi.isnan
.. autofunction:: topi.exp
.. autofunction:: topi.tanh
.. autofunction:: topi.log
@@ -465,6 +465,12 @@ TVM_DLL Expr pow(Expr x, Expr y);
* \return The absolute value of input data x
*/
TVM_DLL Expr abs(Expr x);
/*!
* \brief Check if x is NaN.
* \param x The input data
* \return The result expression.
*/
TVM_DLL Expr isnan(Expr x);
/*!
* \brief sum of source expression over axis
@@ -574,6 +574,7 @@ class Call : public ExprNode {
static constexpr const char* likely = "likely";
static constexpr const char* glsl_texture_store = "glsl_texture_store";
static constexpr const char* prefetch = "prefetch";
static constexpr const char* isnan = "isnan";
/*! \brief Vectorizable intrinsic list. */
static const char* vectorizable_intrinsics[];
@@ -434,6 +434,22 @@ def round(x):
return _make.round(x)

def isnan(x):
    """Check if the input value is NaN.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return _make.isnan(x)

def power(x, y):
"""x power y
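A minimal usage sketch of the new wrapper (assuming the pre-0.6 `tvm` namespace used throughout this diff, where `tvm.var` and `tvm.isnan` are exposed at top level):

```python
import tvm

x = tvm.var("x", dtype="float32")
e = tvm.isnan(x)   # builds the pure intrinsic call
print(e)           # isnan(x)
print(e.dtype)     # bool
```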
@@ -38,6 +38,9 @@ TVM_REGISTER_API("_Var")
TVM_REGISTER_API("make.abs")
.set_body_typed(tvm::abs);
TVM_REGISTER_API("make.isnan")
.set_body_typed(tvm::isnan);
TVM_REGISTER_API("make.floor")
.set_body_typed(tvm::floor);
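The `TVM_REGISTER_API("make.isnan")` line is what backs the `_make.isnan` call in `intrin.py`. A quick sketch of the FFI bridge (assuming `tvm.get_global_func`, available in TVM releases of this era):

```python
import tvm

f = tvm.get_global_func("make.isnan")  # the PackedFunc registered above
x = tvm.var("x", dtype="float32")
print(f(x))                            # same expression as tvm.isnan(x)
```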
@@ -576,6 +576,12 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
os << " *)(&(";
this->PrintExpr(op->args[0], os);
os << ")))";
} else if (op->is_intrinsic(Call::isnan)) {
os << "(";
this->PrintExpr(op->args[0], os);
os << " != ";
this->PrintExpr(op->args[0], os);
os << ")";
} else {
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
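The C backend lowers `isnan(x)` to the self-comparison `(x != x)`: under IEEE 754, NaN is the only value that compares unequal to itself. The rule is easy to confirm in plain Python:

```python
nan = float("nan")
assert nan != nan           # NaN is unequal to everything, itself included
assert not (1.0 != 1.0)     # ordinary floats are equal to themselves
```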
@@ -746,6 +746,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
} else if (op->is_intrinsic(Call::reinterpret)) {
llvm::Type * target = LLVMType(op->type);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
} else if (op->is_intrinsic(Call::isnan)) {
// TODO(hgt312): set fast math flag
llvm::Value* a = MakeValue(op->args[0]);
return builder_->CreateFCmpUNO(a, a);
} else if (op->is_intrinsic("vectorlow")) {
llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements();
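`CreateFCmpUNO(a, a)` emits LLVM's `fcmp uno`, the *unordered* comparison that is true iff at least one operand is NaN, so comparing `a` with itself is exactly an isnan test. The TODO matters because fast-math flags let the optimizer assume no NaNs and fold the check away. An illustrative Python sketch of the predicate's semantics (not the codegen itself):

```python
import math

def fcmp_uno(a, b):
    """True iff the pair is unordered, i.e. either operand is NaN."""
    return math.isnan(a) or math.isnan(b)

assert fcmp_uno(float("nan"), float("nan"))  # isnan via fcmp_uno(a, a)
assert not fcmp_uno(1.0, 1.0)
```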
@@ -424,6 +424,30 @@ Expr abs(Expr x) {
}
}
Expr isnan(Expr x) {
Type t = Bool(x.type().lanes());
if (x.type().is_int() || x.type().is_uint()) {
return make_const(t, false);
} else if (x.type().is_float()) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) {
return make_const(t, std::isnan(fx->value));
}
if (x.type().bits() == 16) {
return ir::Call::make(t, ir::Call::isnan,
{cast(Float(32, t.lanes()), std::move(x))},
ir::Call::PureIntrinsic);
} else {
return ir::Call::make(t, ir::Call::isnan, {x}, ir::Call::PureIntrinsic);
}
} else {
LOG(FATAL) << "Data type " << x.type()
           << " not supported for isnan op. Skipping isnan op...";
return x;
}
}
Expr sum(Expr source, Array<IterVar> rdom) {
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Add::make(x, y);
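The builder dispatches on dtype: integer inputs fold to a constant `false`, float immediates fold via `std::isnan`, and `float16` is widened to `float32` before the intrinsic is emitted (backends generally lack fp16 classification). Condensed, these branches behave exactly as the unit test added below asserts:

```python
import tvm

assert str(tvm.isnan(tvm.var("f", "float32"))) == "isnan(f)"           # plain intrinsic
assert str(tvm.isnan(tvm.var("h", "float16"))) == "isnan(float32(h))"  # fp16 widened first
assert str(tvm.isnan(tvm.var("n", "int32"))) == "(bool)0"              # ints fold to false
```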
@@ -17,6 +17,7 @@
import tvm
import numpy as np
def test_const():
x = tvm.const(1, "int32")
print(x.dtype)
@@ -39,6 +40,7 @@ def test_scalar_dtype_inference():
assert tvm.convert(1).dtype == 'int32'
assert tvm.convert(1.0).dtype == 'float32'
def test_make():
x = tvm.const(1, "int32")
y = tvm.var("x")
@@ -46,6 +48,7 @@ def test_make():
assert isinstance(tvm.max(x, y), tvm.expr.Max)
assert isinstance(tvm.min(x, y), tvm.expr.Min)
def test_ir():
x = tvm.const(1, "int32")
y = tvm.make.IntImm('int32', 1)
@@ -53,6 +56,7 @@ def test_ir():
stmt = tvm.make.Evaluate(z)
assert isinstance(stmt, tvm.stmt.Evaluate)
def test_ir2():
x = tvm.var("n")
a = tvm.var("array", tvm.handle)
@@ -60,12 +64,14 @@ def test_ir2():
assert isinstance(st, tvm.stmt.Store)
assert(st.buffer_var == a)
def test_let():
x = tvm.var('x')
y = tvm.var('y')
stmt = tvm.make.LetStmt(
x, 10, tvm.make.Evaluate(x + 1))
def test_cast():
x = tvm.var('x', dtype="float32")
y = x.astype("int32")
@@ -104,10 +110,12 @@ def test_stmt():
tvm.stmt.For.Serial, 0,
x)
def test_dir():
x = tvm.var('x')
dir(x)
def test_dtype():
x = tvm.var('x')
assert x.dtype == 'int32'
@@ -158,6 +166,7 @@ def test_all():
'(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % (
x.name, y.name, y.name, z.name, x.name, z.name)
def test_bitwise():
x = tvm.var('x')
y = tvm.var('y')
@@ -172,6 +181,18 @@ def test_bitwise():
assert(tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2"
def test_isnan():
x = tvm.var('x', 'float32')
assert str(tvm.isnan(x)) == 'isnan(x)'
assert str(tvm.isnan(x).dtype) == 'bool'
y = tvm.var('y', 'float16')
assert str(tvm.isnan(y)) == 'isnan(float32(y))'
z = tvm.var('z', 'int32')
assert str(tvm.isnan(z)) == '(bool)0'
k = tvm.var('k', 'int8x2')
assert str(tvm.isnan(k).dtype) == 'uint1x2'
def test_equality():
a = tvm.var('a')
b = tvm.var('b')
@@ -203,5 +224,6 @@ if __name__ == "__main__":
test_any()
test_all()
test_bitwise()
test_isnan()
test_equality()
test_equality_string_imm()
@@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(abs);
TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
/*
* \brief Fast_tanh_float implementation from Eigen
@@ -244,6 +244,23 @@ def abs(x):
@tvm.tag_scope(tag=tag.ELEMWISE)
def isnan(x):
    """Check if value of x is NaN, element-wise.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return tvm.compute(x.shape, lambda *i: tvm.isnan(x(*i)))

@tvm.tag_scope(tag=tag.ELEMWISE)
def round(x):
"""Round elements of x to nearest integer.
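End to end, the new TOPI op composes with a schedule like any other elementwise compute. A minimal sketch, assuming an LLVM-enabled build (array values are illustrative):

```python
import numpy as np
import tvm
import topi

A = tvm.placeholder((4,), name="A", dtype="float32")
B = topi.isnan(A)
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")

a = tvm.nd.array(np.array([0.0, np.nan, 1.5, np.nan], dtype="float32"))
b = tvm.nd.array(np.zeros(4, dtype="bool"))
f(a, b)
assert (b.asnumpy() == np.isnan(a.asnumpy())).all()
```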
@@ -72,6 +72,46 @@ def test_ewise():
for device in get_all_backend():
check_device(device)
def test_isnan(
        low,
        high,
        shape=(20, 3),
        dtype="float32",
        check_round=False,
        skip_name_check=False,
):
m = tvm.var("m")
l = tvm.var("l")
A = tvm.placeholder((m, l), dtype=dtype, name="A")
B = topi.isnan(A)
assert tuple(B.shape) == tuple(A.shape)
if not skip_name_check:
assert B.op.body[0].name == "isnan"
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan
# avoid round check too close to boundary
if check_round:
a_np += ((np.fmod(a_np, 1) - 0.5) < 1e-6) * 1e-5
b_np = np.isnan(a_np)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
foo = tvm.build(s, [A, B], device, name="isnan")
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros_like(b_np), ctx)
foo(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
for device in get_all_backend():
check_device(device)
test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True)
@@ -88,6 +128,7 @@ def test_ewise():
test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
test_isnan(-100, 100)
def test_cast():