Unverified Commit 0e8d8fdd by Haichen Shen Committed by GitHub

[Relay][OP] Add cast op (#2319)

* Add cast op
* Rename dtype_cast to cast
* Add additional safety check for String2TVMType
* Add missing relay op docs
parent 67fe4db0
...@@ -133,6 +133,9 @@ This level enables additional math and transform operators. ...@@ -133,6 +133,9 @@ This level enables additional math and transform operators.
:nosignatures: :nosignatures:
tvm.relay.image.resize tvm.relay.image.resize
tvm.relay.vision.multibox_prior
tvm.relay.vision.multibox_transform_loc
tvm.relay.vision.nms
**Level 10: Temporary Operators** **Level 10: Temporary Operators**
...@@ -160,6 +163,7 @@ Level 1 Definitions ...@@ -160,6 +163,7 @@ Level 1 Definitions
.. autofunction:: tvm.relay.mod .. autofunction:: tvm.relay.mod
.. autofunction:: tvm.relay.tanh .. autofunction:: tvm.relay.tanh
.. autofunction:: tvm.relay.concatenate .. autofunction:: tvm.relay.concatenate
.. autofunction:: tvm.relay.expand_dims
.. autofunction:: tvm.relay.nn.softmax .. autofunction:: tvm.relay.nn.softmax
.. autofunction:: tvm.relay.nn.log_softmax .. autofunction:: tvm.relay.nn.log_softmax
.. autofunction:: tvm.relay.nn.relu .. autofunction:: tvm.relay.nn.relu
...@@ -236,6 +240,9 @@ Level 4 Definitions ...@@ -236,6 +240,9 @@ Level 4 Definitions
Level 5 Definitions Level 5 Definitions
------------------- -------------------
.. autofunction:: tvm.relay.image.resize .. autofunction:: tvm.relay.image.resize
.. autofunction:: tvm.relay.vision.multibox_prior
.. autofunction:: tvm.relay.vision.multibox_transform_loc
.. autofunction:: tvm.relay.vision.nms
Level 10 Definitions Level 10 Definitions
......
...@@ -946,9 +946,11 @@ inline TVMType String2TVMType(std::string s) { ...@@ -946,9 +946,11 @@ inline TVMType String2TVMType(std::string s) {
char* xdelim; // emulate sscanf("%ux%u", bits, lanes) char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10)); uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits; if (bits != 0) t.bits = bits;
char* endpt = xdelim;
if (*xdelim == 'x') { if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, nullptr, 10)); t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
} }
CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
return t; return t;
} }
......
...@@ -49,7 +49,7 @@ class Expr(RelayNode): ...@@ -49,7 +49,7 @@ class Expr(RelayNode):
result : tvm.relay.Expr result : tvm.relay.Expr
The result expression. The result expression.
""" """
return _make.dtype_cast(self, dtype) return _make.cast(self, dtype)
def __add__(self, other): def __add__(self, other):
if isinstance(other, Expr): if isinstance(other, Expr):
......
...@@ -4,6 +4,26 @@ from . import _make ...@@ -4,6 +4,26 @@ from . import _make
from ..expr import TupleWrapper from ..expr import TupleWrapper
def cast(data, dtype):
"""Cast input tensor to data type.
Parameters
----------
data : relay.Expr
The input data to the operator.
dtype: str
The target data type
Returns
-------
result : relay.Expr
The casted result.
"""
from .. import _make as _relay_make
return _relay_make.cast(data, dtype)
def expand_dims(data, axis, num_newaxis=1): def expand_dims(data, axis, num_newaxis=1):
"""Insert `num_newaxis` axises at the position given by `axis`. """Insert `num_newaxis` axises at the position given by `axis`.
......
...@@ -61,7 +61,7 @@ Expr MakeCast(Expr data, ...@@ -61,7 +61,7 @@ Expr MakeCast(Expr data,
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data}, Attrs(attrs), {});
} }
TVM_REGISTER_API("relay._make.dtype_cast") TVM_REGISTER_API("relay._make.cast")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeCast, args, rv); runtime::detail::unpack_call<Expr, 2>(MakeCast, args, rv);
}); });
......
...@@ -46,6 +46,11 @@ def test_cast(): ...@@ -46,6 +46,11 @@ def test_cast():
assert "dtype=" in yy.astext() assert "dtype=" in yy.astext()
assert yy.checked_type == relay.TensorType((8, 9, 4), "int32") assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")
x = relay.var("x", relay.TensorType((8, 9, 4), "float32"))
y = relay.cast(x, "int32")
yy = relay.ir_pass.infer_type(y)
assert "dtype=" in yy.astext()
assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")
def test_clip(): def test_clip():
a = relay.var("a", relay.TensorType((10, 4), "float32")) a = relay.var("a", relay.TensorType((10, 4), "float32"))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment