Commit 2ed31b24 by Andrew Tulloch, committed by Tianqi Chen

{relay,topi}.reinterpret support (#3599)

= Motivation

It's useful to expose the tvm::reinterpret functionality to Relay/TOPI users, as
this allows them to build (fused) operators that leverage the bitwise
reinterpretation of an operand. An example is approximate transcendental
functions, which can be implemented along the lines of:

```python
    def C(x):
        return relay.expr.const(x, "float32")

    def approx_exp(x):
        x = relay.minimum(relay.maximum(x, C(-88.0)), C(88.0))
        x = C(127.0) + x * C(1.44269504)
        xf = relay.floor(x)
        i = relay.cast(xf, "int32")
        x = x - xf
        Y = C(0.99992522) + x * (C(0.69583354) + x * (C(0.22606716) + x * C(0.078024523)))
        exponent = relay.left_shift(i, relay.expr.const(23, "int32"))
        exponent = relay.reinterpret(exponent, "float32")
        return exponent * Y

    def approx_sigmoid(x):
        # <2.0e-5 absolute error over [-5, 5]
        y = approx_exp(x)
        return y / (y + C(1.0))

    def approx_tanh(x):
        # <4.0e-5 absolute error over [-5, 5]
        x = x * C(2.0)
        y = approx_exp(x)
        return (y - C(1.0)) / (y + C(1.0))
```

See the unit tests for implementations of these approximate transcendentals.
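
The key step above is the standard IEEE-754 exponent trick: after splitting `x * log2(e)` into integer and fractional parts, the (bias-adjusted) integer part `i` is shifted into the float32 exponent field, and bitwise-reinterpreting those bits as float32 yields exactly `2 ** (i - 127)`; the polynomial in the fractional part then scales that power of two. A standalone NumPy sketch of that identity (an illustration, not part of the commit):

```python
# Illustration only: for 1 <= i <= 254, the int32 value (i << 23) places i in
# the float32 exponent field, so viewing those bits as float32 gives exactly
# 2 ** (i - 127). This is the reinterpret step used in approx_exp above.
import numpy as np

i = np.arange(1, 255, dtype="int32")
as_float = (i << 23).view("float32")   # bitwise reinterpretation, no numeric conversion
np.testing.assert_allclose(as_float, 2.0 ** (i.astype("float64") - 127.0))
```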
parent 66f3bf83
@@ -40,6 +40,7 @@ List of operators
   topi.sigmoid
   topi.clip
   topi.cast
   topi.reinterpret
   topi.transpose
   topi.flip
   topi.strided_slice
@@ -133,6 +134,7 @@ topi
.. autofunction:: topi.sigmoid
.. autofunction:: topi.clip
.. autofunction:: topi.cast
.. autofunction:: topi.reinterpret
.. autofunction:: topi.transpose
.. autofunction:: topi.flip
.. autofunction:: topi.strided_slice
......
@@ -114,6 +114,7 @@ This level enables additional math and transform operators.
   tvm.relay.full
   tvm.relay.full_like
   tvm.relay.cast
   tvm.relay.reinterpret
   tvm.relay.split
   tvm.relay.arange
   tvm.relay.stack
@@ -263,6 +264,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.full
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast
.. autofunction:: tvm.relay.reinterpret
.. autofunction:: tvm.relay.split
.. autofunction:: tvm.relay.arange
.. autofunction:: tvm.relay.stack
......
@@ -40,6 +40,7 @@ _reg.register_schedule("reverse", schedule_injective)
_reg.register_schedule("repeat", schedule_broadcast)
_reg.register_schedule("tile", schedule_broadcast)
_reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("reinterpret", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
......
@@ -40,6 +40,26 @@ def cast(data, dtype):
    return _relay_make.cast(data, dtype)


def reinterpret(data, dtype):
    """Reinterpret input tensor to data type.

    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.

    dtype: str
        The target data type

    Returns
    -------
    result : relay.Expr
        The reinterpreted result.
    """
    from .. import _make as _relay_make
    return _relay_make.reinterpret(data, dtype)


def expand_dims(data, axis, num_newaxis=1):
    """Insert `num_newaxis` axises at the position given by `axis`.
......
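A minimal usage sketch of the new Python API (it mirrors the unit test added later in this commit and assumes the `relay.create_executor` interpreter interface used there):

```python
# Sketch of relay.reinterpret: same bits, new dtype, no value conversion.
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", relay.TensorType((4,), "float32"))
y = relay.reinterpret(x, "int32")

data = np.array([1.0, -2.0, 0.5, 3.25], dtype="float32")
res = relay.create_executor().evaluate(y, {x: relay.const(data)})
np.testing.assert_equal(res.asnumpy(), data.view("int32"))
```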
@@ -569,6 +569,13 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
    os << "(";
    this->PrintExpr(op->args[0], os);
    os << " == NULL)";
  } else if (op->is_intrinsic(Call::reinterpret)) {
    // generate (*( TYPE *)(&(ARG)))
    os << "(*(";
    this->PrintType(op->type, os);
    os << " *)(&(";
    this->PrintExpr(op->args[0], os);
    os << ")))";
  } else {
    if (op->call_type == Call::Intrinsic ||
        op->call_type == Call::PureIntrinsic) {
......
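For a concrete sense of this branch's output: a pure-intrinsic `reinterpret` call with a float32 result and an int32 argument `x` is printed as roughly `(*(float *)(&(x)))`, so only the type changes and the bit pattern is preserved. A small Python sketch of the same semantics (illustration only, not generated code):

```python
# Illustration: the pointer-cast pattern emitted by the C backend has the same
# semantics as repacking the raw bytes under a new type.
import struct

bits = 0x3f800000                                        # int32 bit pattern of 1.0f
value = struct.unpack("<f", struct.pack("<i", bits))[0]  # reinterpret, not convert
assert value == 1.0
```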
@@ -97,6 +97,37 @@ RELAY_REGISTER_OP("cast")
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

Array<Tensor> ReinterpretCompute(const Attrs& attrs, const Array<Tensor>& inputs,
                                 const Type& out_type, const Target& target) {
  const CastAttrs* param = attrs.as<CastAttrs>();
  CHECK(param != nullptr);
  DataType dtype = param->dtype;
  return {topi::reinterpret(inputs[0], dtype)};
}

Expr MakeReinterpret(Expr data, DataType dtype) {
  auto attrs = make_node<CastAttrs>();
  attrs->dtype = dtype;
  static const Op& op = Op::Get("reinterpret");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay._make.reinterpret").set_body([](const TVMArgs& args, TVMRetValue* rv) {
  runtime::detail::unpack_call<Expr, 2>(MakeReinterpret, args, rv);
});

RELAY_REGISTER_OP("reinterpret")
    .describe(R"code(Reinterpret the data into a new data type.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .set_attrs_type_key("relay.attrs.CastAttrs")
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(3)
    .add_type_rel("Reinterpret", CastRel)
    .set_attr<FTVMCompute>("FTVMCompute", ReinterpretCompute)
    .set_attr<TOpPattern>("TOpPattern", kElemWise)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

// relay.expand_dims
TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
......
@@ -75,6 +75,7 @@ def test_cast():
    assert "dtype=" in yy.astext()
    assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")


def test_clip():
    a = relay.var("a", relay.TensorType((10, 4), "float32"))
    y = relay.clip(a, 1., 4.)
@@ -88,6 +89,69 @@ def test_clip():
    np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)


def test_reinterpret():
    a = relay.var("a", relay.TensorType((1000, 4), "float32"))
    y = relay.reinterpret(a, "int32")
    yy = run_infer_type(y)
    assert yy.checked_type == relay.TensorType((1000, 4), "int32")
    data = np.random.randn(1000, 4).astype('float32') * 1000
    intrp = create_executor()
    op_res = intrp.evaluate(y, {a: relay.const(data)})
    ref_res = data.view("int32")
    np.testing.assert_equal(op_res.asnumpy(), ref_res)


def test_approximate_transcendental():
    def C(x):
        return relay.expr.const(x, "float32")

    def approx_exp(x):
        # An approximation derived from Opus,
        # https://github.com/xiph/opus/blob/c1c247/celt/mathops.h#L147-L165
        x = relay.minimum(relay.maximum(x, C(-88.0)), C(88.0))
        x = C(127.0) + x * C(1.44269504)
        xf = relay.floor(x)
        i = relay.cast(xf, "int32")
        x = x - xf
        Y = C(0.99992522) + x * (C(0.69583354) + x * (C(0.22606716) + x * C(0.078024523)))
        exponent = relay.left_shift(i, relay.expr.const(23, "int32"))
        exponent = relay.reinterpret(exponent, "float32")
        return exponent * Y

    def approximate_sigmoid(x):
        y = approx_exp(x)
        return y / (y + C(1.0))

    def approximate_tanh(x):
        x = x * C(2.0)
        y = approx_exp(x)
        return (y - C(1.0)) / (y + C(1.0))

    a = relay.var("a", relay.TensorType((1000,), "float32"))
    y = approximate_sigmoid(a)
    yy = run_infer_type(y)
    assert yy.checked_type == relay.TensorType((1000,), "float32")
    data = np.linspace(-5, 5, 1000).astype("float32")
    intrp = create_executor()
    op_res = intrp.evaluate(y, {a: relay.const(data)})

    def reference_sigmoid(x):
        return np.exp(-np.logaddexp(0, -x))
    np.testing.assert_allclose(op_res.asnumpy(), reference_sigmoid(data), atol=2e-5, rtol=1e-9)

    y = approximate_tanh(a)
    yy = run_infer_type(y)
    assert yy.checked_type == relay.TensorType((1000,), "float32")
    data = np.linspace(-5, 5, 1000).astype("float32")
    intrp = create_executor()
    op_res = intrp.evaluate(y, {a: relay.const(data)})

    def reference_tanh(x):
        return np.tanh(x)
    np.testing.assert_allclose(op_res.asnumpy(), reference_tanh(data), atol=4e-5, rtol=1e-9)


def test_squeeze():
    def verify_squeeze(shape, dtype, axis):
        x = relay.var("x", relay.TensorType(shape, dtype))
......
@@ -95,6 +95,31 @@ def test_add_pipeline():
    with tvm.build_config(offset_factor=4):
        check_c()


def test_reinterpret():
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A', dtype="int32")
    B = tvm.compute(A.shape, lambda *i: tvm.call_pure_intrin("float32", "reinterpret", A(*i)), name='B')
    s = tvm.create_schedule(B.op)

    def check_c():
        mhost = tvm.build(s, [A, B], "c", name="reinterpret")
        temp = util.tempdir()
        path_dso = temp.relpath("temp.so")
        mhost.export_library(path_dso)
        m = tvm.module.load(path_dso)
        fadd = m['reinterpret']
        ctx = tvm.cpu(0)
        n = nn
        a = tvm.nd.array(np.random.randint(-2 ** 30, 2 ** 30, size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        fadd(a, b)
        tvm.testing.assert_allclose(
            b.asnumpy(), a.asnumpy().view('float32'))
    check_c()


if __name__ == "__main__":
    test_add()
    test_add_pipeline()
    test_reinterpret()
@@ -269,14 +269,34 @@ inline Tensor cast(const Tensor& x,
}

/*!
 * \brief Reinterpret each element of x to the given type.
 *
 * \param x The input tensor
 * \param type The type to cast to
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the reinterpret operation
 */
inline Tensor reinterpret(const Tensor& x, Type type, std::string name = "tensor",
                          std::string tag = kElementWise) {
  return compute(x->shape,
                 [&](const Array<Var>& i) {
                   return tvm::ir::Call::make(type, "reinterpret", {x(i)},
                                              tvm::ir::Call::PureIntrinsic);
                 },
                 name, tag);
}

/*!
 * \brief Creates an operation that sum each element of a tensor
 *
 * \param xs The input tensor array
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the sum operation
 */
inline Tensor elemwise_sum(const Array<Tensor>& xs,
                           std::string name = "T_elemwise_sum",
                           std::string tag = kElementWise) {
......
@@ -343,3 +343,21 @@ def cast(x, dtype):
        return tvm.compute(
            x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
    return tvm.make._cast(dtype, x)


def reinterpret(x, dtype):
    """Reinterpret input to specified data type.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    dtype : str
        Data type.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return cpp.reinterpret(x, dtype)
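A short end-to-end sketch of the TOPI entry point (using the same pre-0.7 tvm/topi APIs exercised by the tests in this commit; the "llvm" target is an assumption for a local CPU build):

```python
# Sketch: reinterpret an int32 tensor as float32 via topi and check against
# NumPy's .view(), which performs the same bitwise reinterpretation.
import numpy as np
import tvm
import topi

A = tvm.placeholder((16,), name="A", dtype="int32")
B = topi.reinterpret(A, "float32")
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm", name="reinterpret")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(-2 ** 30, 2 ** 30, size=16).astype("int32"), ctx)
b = tvm.nd.array(np.zeros(16, dtype="float32"), ctx)
f(a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy().view("float32"))
```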
@@ -193,6 +193,12 @@ TVM_REGISTER_GLOBAL("topi.cast")
  *rv = cast(args[0], args[1]);
  });

TVM_REGISTER_GLOBAL("topi.reinterpret")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = reinterpret(args[0], args[1]);
  });

TVM_REGISTER_GLOBAL("topi.elemwise_sum")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = elemwise_sum(args[0]);
......
@@ -45,6 +45,29 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
        check_device(device)


def verify_reinterpret(in_shape, in_dtype, out_dtype, generator):
    A = tvm.placeholder(shape=in_shape, name="A", dtype=in_dtype)
    B = topi.reinterpret(A, out_dtype)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_elemwise(B)
        foo = tvm.build(s, [A, B], device, name="reinterpret")
        data_npy = generator(in_shape).astype(in_dtype)
        out_npy = data_npy.view(B.dtype)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.array(np.empty(in_shape).astype(B.dtype), ctx)
        foo(data_nd, out_nd)
        np.testing.assert_equal(out_nd.asnumpy(), out_npy)

    for device in get_all_backend():
        check_device(device)


def verify_transpose(in_shape, axes):
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.transpose(A, axes)
@@ -434,6 +457,19 @@ def test_expand_dims():
    verify_expand_dims((3, 10), (1, 3, 10), -3, 1)


def test_reinterpret():
    verify_reinterpret((1000,), "float32", "int32",
                       lambda shape: np.random.randn(*shape) * 1000)
    verify_reinterpret((1000,), "float16", "int16",
                       lambda shape: np.random.randn(*shape) * 100)
    verify_reinterpret((1000,), "int16", "uint16",
                       lambda shape: np.random.randint(-1000, 1000, size=shape))
    verify_reinterpret((1000,), "uint32", "int32",
                       lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape))
    verify_reinterpret((1000,), "uint32", "int32",
                       lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape))


def test_transpose():
    verify_transpose((3, 10, 2), (1, 0, 2))
    verify_transpose((3, 10, 5), (2, 0, 1))
......