Unverified Commit 0e8d8fdd by Haichen Shen Committed by GitHub

[Relay][OP] Add cast op (#2319)

* Add cast op
* Rename dtype_cast to cast
* Add additional safety check for String2TVMType
* Add missing relay op docs
parent 67fe4db0
......@@ -133,6 +133,9 @@ This level enables additional math and transform operators.
:nosignatures:
tvm.relay.image.resize
tvm.relay.vision.multibox_prior
tvm.relay.vision.multibox_transform_loc
tvm.relay.vision.nms
**Level 10: Temporary Operators**
......@@ -160,6 +163,7 @@ Level 1 Definitions
.. autofunction:: tvm.relay.mod
.. autofunction:: tvm.relay.tanh
.. autofunction:: tvm.relay.concatenate
.. autofunction:: tvm.relay.expand_dims
.. autofunction:: tvm.relay.nn.softmax
.. autofunction:: tvm.relay.nn.log_softmax
.. autofunction:: tvm.relay.nn.relu
......@@ -236,6 +240,9 @@ Level 4 Definitions
Level 5 Definitions
-------------------
.. autofunction:: tvm.relay.image.resize
.. autofunction:: tvm.relay.vision.multibox_prior
.. autofunction:: tvm.relay.vision.multibox_transform_loc
.. autofunction:: tvm.relay.vision.nms
Level 10 Definitions
......
......@@ -946,9 +946,11 @@ inline TVMType String2TVMType(std::string s) {
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
char* endpt = xdelim;
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, nullptr, 10));
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
}
CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
return t;
}
......
......@@ -49,7 +49,7 @@ class Expr(RelayNode):
result : tvm.relay.Expr
The result expression.
"""
return _make.dtype_cast(self, dtype)
return _make.cast(self, dtype)
def __add__(self, other):
if isinstance(other, Expr):
......
......@@ -4,6 +4,26 @@ from . import _make
from ..expr import TupleWrapper
def cast(data, dtype):
"""Cast input tensor to data type.
Parameters
----------
data : relay.Expr
The input data to the operator.
dtype: str
The target data type
Returns
-------
result : relay.Expr
The casted result.
"""
from .. import _make as _relay_make
return _relay_make.cast(data, dtype)
def expand_dims(data, axis, num_newaxis=1):
"""Insert `num_newaxis` axises at the position given by `axis`.
......
......@@ -61,7 +61,7 @@ Expr MakeCast(Expr data,
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay._make.dtype_cast")
TVM_REGISTER_API("relay._make.cast")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeCast, args, rv);
});
......
......@@ -46,6 +46,11 @@ def test_cast():
assert "dtype=" in yy.astext()
assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")
x = relay.var("x", relay.TensorType((8, 9, 4), "float32"))
y = relay.cast(x, "int32")
yy = relay.ir_pass.infer_type(y)
assert "dtype=" in yy.astext()
assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")
def test_clip():
a = relay.var("a", relay.TensorType((10, 4), "float32"))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment