Unverified Commit 9037f4ec by Mahesh Ambule Committed by GitHub

[Relay, Topi, TF Frontend] Isfinite operator (#4981)

* isfinite doc update

* isfinit expr

* isfinit expr

* isfinite schedule reg

* isfinite python binding

* isfinite python binding

* relay register isfinite

* isfinite type relation

* intrin isfinite

* topi isfinite

* testcase topi isfinite

* tf frontend isfinite

* tf frontend isfinite testcase

* test case relay isfinite

* small fixes

* test forward tf isfinite

* test cases injective for cuda

* remove float16 test case

* add support for isinf

* remove unwanted import

* fix conflict
parent fdc8b0dd
......@@ -33,6 +33,8 @@ List of operators
topi.round
topi.abs
topi.isnan
topi.isfinite
topi.isinf
topi.exp
topi.tanh
topi.log
......@@ -134,6 +136,8 @@ topi
.. autofunction:: topi.round
.. autofunction:: topi.abs
.. autofunction:: topi.isnan
.. autofunction:: topi.isfinite
.. autofunction:: topi.isinf
.. autofunction:: topi.exp
.. autofunction:: topi.tanh
.. autofunction:: topi.log
......
......@@ -160,6 +160,8 @@ Supported Ops
- Greater
- GreaterEqual
- Identity
- IsFinite
- IsInf
- LeakyRelu
- LeftShift
- Less
......
......@@ -832,6 +832,8 @@ class CallNode : public PrimExprNode {
static constexpr const char* glsl_texture_store = "glsl_texture_store";
static constexpr const char* prefetch = "prefetch";
static constexpr const char* isnan = "isnan";
static constexpr const char* isfinite = "isfinite";
static constexpr const char* isinf = "isinf";
/*! \brief Vectorizable intrinsic list. */
static const char* vectorizable_intrinsics[];
......
......@@ -84,6 +84,13 @@ TVM_DLL PrimExpr max_value(const DataType& dtype);
TVM_DLL PrimExpr min_value(const DataType& dtype);
/*!
* Get the value of infinity.
* \param dtype The data type.
* \return the infinity value in this format.
*/
TVM_DLL PrimExpr infinity(const DataType& dtype);
/*!
* \brief cast value to type.
*
* \param t the target type.
......@@ -440,6 +447,20 @@ TVM_DLL PrimExpr abs(PrimExpr x);
TVM_DLL PrimExpr isnan(PrimExpr x);
/*!
* \brief Check if x is finite.
* \param x The input data
* \return The result expression.
*/
TVM_DLL PrimExpr isfinite(PrimExpr x);
/*!
* \brief Check if x is infinite.
* \param x The input data
* \return The result expression.
*/
TVM_DLL PrimExpr isinf(PrimExpr x);
/*!
* \brief sum of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
......
......@@ -1667,6 +1667,8 @@ _convert_map = {
'Greater' : _broadcast('greater'),
'GreaterEqual' : _broadcast('greater_equal'),
'Identity' : _identity(),
'IsFinite' : AttrCvt('isfinite'),
'IsInf' : AttrCvt('isinf'),
'LeakyRelu' : AttrCvt('leaky_relu'),
'LeftShift' : AttrCvt('left_shift'),
'Less' : _broadcast('less'),
......
......@@ -66,6 +66,8 @@ register_broadcast_schedule("less")
register_broadcast_schedule("less_equal")
register_broadcast_schedule("greater")
register_broadcast_schedule("greater_equal")
register_broadcast_schedule("isfinite")
register_broadcast_schedule("isinf")
register_injective_schedule("maximum")
register_injective_schedule("minimum")
register_injective_schedule("right_shift")
......
......@@ -1008,3 +1008,35 @@ def ndarray_size(data, dtype="int32"):
The number of elements of input tensor.
"""
return _make.ndarray_size(data, dtype)
def isfinite(data):
"""Compute element-wise finiteness of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.isfinite(data)
def isinf(data):
"""Compute element-wise infiniteness of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.isinf(data)
......@@ -20,7 +20,8 @@
# expose all operators in tvm tir.op
from tvm.tir import any, all, min_value, max_value, trace
from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
from tvm.tir import isnan, isfinite, isinf
from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from tvm.tir import comm_reducer, min, max, sum
......
......@@ -38,7 +38,8 @@ from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10
from .op import cos, sin, cosh, sinh, tan, tanh, atan
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
......
......@@ -706,6 +706,38 @@ def isnan(x):
return _ffi_api.isnan(x)
def isfinite(x):
"""Check if input value is finite.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _ffi_api.isfinite(x)
def isinf(x):
"""Check if input value is infinite.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _ffi_api.isinf(x)
def power(x, y):
"""x power y
......
......@@ -415,5 +415,23 @@ ElemwiseArbitraryLayout)
.set_support_level(10)
.set_attr<FTVMCompute>("FTVMCompute", NdarraySizeCompute);
RELAY_REGISTER_UNARY_OP("isfinite")
.describe(R"code(Returns the finiteness of input, computed element-wise.
.. math::
isfinite(x)
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("IdentityCompRel", IdentityCompRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isfinite));
RELAY_REGISTER_UNARY_OP("isinf")
.describe(R"code(Returns the infiniteness of input, computed element-wise.
.. math::
isfinite(x)
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("IdentityCompRel", IdentityCompRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isinf));
} // namespace relay
} // namespace tvm
......@@ -136,6 +136,18 @@ bool BroadcastCompRel(const Array<Type>& types,
return false;
}
bool IdentityCompRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
if (auto* t0 = types[0].as<TensorTypeNode>()) {
Type out_type = TensorType(GetRef<TensorType>(t0)->shape, DataType::Bool());
reporter->Assign(types[1], out_type);
return true;
}
return false;
}
Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
if (shape.size() == 0) {
return {};
......
......@@ -79,6 +79,11 @@ bool BroadcastCompRel(const Array<Type>& types,
const Attrs& attrs,
const TypeReporter& reporter);
bool IdentityCompRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter);
Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);
} // namespace relay
......
......@@ -78,6 +78,22 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
*rv = one / (one + exp(-call->args[0]));
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isfinite")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
*rv = isfinite(call->args[0]);
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
*rv = isinf(call->args[0]);
});
} // namespace intrin
} // namespace codegen
} // namespace tvm
......@@ -180,6 +180,21 @@ PrimExpr min_value(const DataType& dtype) {
return PrimExpr();
}
// infinity
PrimExpr infinity(const DataType& dtype) {
using namespace tir;
CHECK_EQ(dtype.lanes(), 1);
if (dtype.is_float()) {
if (dtype.bits() == 64) {
return FloatImm(dtype, std::numeric_limits<double>::infinity());
} else if (dtype.bits() == 32 || dtype.bits() == 16) {
return FloatImm(dtype, std::numeric_limits<float>::infinity());
}
}
LOG(FATAL) << "Cannot decide infinity for type " << dtype;
return PrimExpr();
}
namespace tir {
template<typename ValueType>
inline bool ConstPowerHelper(ValueType val, int *shift) {
......@@ -575,6 +590,21 @@ PrimExpr isnan(PrimExpr x) {
}
}
PrimExpr isinf(PrimExpr x) {
DataType t = DataType::Bool(x.dtype().lanes());
if (x.dtype().is_int() || x.dtype().is_uint()) {
return make_const(t, false);
} else if (x.dtype().is_float()) {
PrimExpr infX = infinity(x.dtype());
return abs(x) == infX && !isnan(x);
} else {
LOG(FATAL) << "Data type " << x.dtype() << " not supported for finiteness ops. Skipping it...";
return x;
}
}
PrimExpr isfinite(PrimExpr x) { return !isinf(x) && !isnan(x); }
PrimExpr sum(PrimExpr source, Array<IterVar> rdom) {
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::AddNode::make(x, y);
......@@ -721,6 +751,12 @@ TVM_REGISTER_GLOBAL("tir.abs")
TVM_REGISTER_GLOBAL("tir.isnan")
.set_body_typed(tvm::isnan);
TVM_REGISTER_GLOBAL("tir.isfinite")
.set_body_typed(tvm::isfinite);
TVM_REGISTER_GLOBAL("tir.isinf")
.set_body_typed(tvm::isinf);
TVM_REGISTER_GLOBAL("tir.floor")
.set_body_typed(tvm::floor);
......
......@@ -3152,7 +3152,37 @@ def test_forward_dilation():
_test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 2, 2, 1], "SAME")
_test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 2, 1], "VALID")
# #######################################################################
#######################################################################
# infinity ops
# ------------
def _verify_infiniteness_ops(tf_op, name):
"""test operator infinity ops"""
# Only float types are allowed in Tensorflow for isfinite and isinf
# float16 is failing on cuda
tf_dtypes = ["float32", "float64"]
for tf_dtype in tf_dtypes:
shape = (8, 8)
data = np.random.uniform(size=shape).astype(tf_dtype)
data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.infty
data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.nan
tf.reset_default_graph()
in_data = tf.placeholder(tf_dtype, shape, name="in_data")
tf_op(in_data, name=name)
compare_tf_with_tvm([data], ['in_data:0'], '{}:0'.format(name))
def test_forward_isinf():
_verify_infiniteness_ops(tf.is_inf, "isinf")
def test_forward_isfinite():
_verify_infiniteness_ops(tf.is_finite, "isfinite")
#######################################################################
# Main
# ----
if __name__ == '__main__':
......@@ -3224,6 +3254,8 @@ if __name__ == '__main__':
test_forward_squared_difference()
test_forward_add_n()
test_forward_floormod()
test_forward_isfinite()
test_forward_isinf()
test_forward_unravel_index()
# Reductions
......
......@@ -684,6 +684,33 @@ def test_gather_nd():
verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
def _verify_infiniteness_ops(relay_op, ref_op):
for dtype in ['float32', 'float16', 'float16', 'int32', 'int16']:
shape = (2, 8, 8)
x = relay.var("x", relay.TensorType(shape, dtype))
y = relay_op(x)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(shape, "bool")
data = np.random.uniform(size=shape).astype(dtype)
if dtype.startswith('float'):
data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.infty
data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.nan
intrp = create_executor()
op_res = intrp.evaluate(y, {x: data})
ref_res = ref_op(data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
def test_isfinite():
_verify_infiniteness_ops(relay.isfinite, np.isfinite)
def test_isinf():
_verify_infiniteness_ops(relay.isinf, np.isinf)
def test_unravel_index():
def verify_unravel_index(indices, shape, dtype):
x_data = np.array(indices).astype(dtype)
......@@ -751,4 +778,6 @@ if __name__ == "__main__":
test_tile()
test_repeat()
test_gather_nd()
test_unravel_index()
test_isfinite()
test_isinf()
test_unravel_index()
\ No newline at end of file
......@@ -60,6 +60,8 @@ TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
TOPI_DECLARE_UNARY_OP(tanh);
TOPI_DECLARE_UNARY_OP(isfinite);
TOPI_DECLARE_UNARY_OP(isinf);
/*
* \brief Fast_tanh_float implementation from Eigen
......
......@@ -278,6 +278,40 @@ def isnan(x):
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def isfinite(x):
"""Check if value of x is finite, element-wise.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.isfinite(x(*i)))
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def isinf(x):
"""Check if value of x is infinite, element-wise.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.isinf(x(*i)))
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def round(x):
"""Round elements of x to nearest integer.
......
......@@ -113,6 +113,36 @@ def test_ewise():
for target in get_all_backend():
check_device(target)
def test_infiniteness_ops(topi_op, ref_op, name):
for dtype in ['float32', 'float64', 'int32', 'int16']:
m = te.var("m")
l = te.var("l")
A = te.placeholder((m, l), dtype=dtype, name="A")
B = topi_op(A)
assert tuple(B.shape) == tuple(A.shape)
a_np = np.random.uniform(size=(8, 8)).astype(A.dtype) * 10
if dtype.startswith('float'):
a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.infty
a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan
b_np = ref_op(a_np)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.testing.get_injective_schedule(device)(B)
foo = tvm.build(s, [A, B], device, name=name)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros_like(b_np), ctx)
foo(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
for target in get_all_backend():
check_device(target)
test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True)
......@@ -132,6 +162,8 @@ def test_ewise():
test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
test_isnan(-100, 100)
test_infiniteness_ops(topi.isfinite, np.isfinite, 'isifinite')
test_infiniteness_ops(topi.isinf, np.isinf, 'isinf')
def test_cast():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment