Commit 1bfda4d3 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] [Op] zeros_like and ones_like (#1835)

parent 3aaafc38
...@@ -41,6 +41,12 @@ This level enables typical convnet models. ...@@ -41,6 +41,12 @@ This level enables typical convnet models.
**Level 3: Additional Math And Transform Operators** **Level 3: Additional Math And Transform Operators**
.. autosummary::
:nosignatures:
tvm.relay.zeros_like
tvm.relay.ones_like
**Level 4: Broadcast and Reductions** **Level 4: Broadcast and Reductions**
.. autosummary:: .. autosummary::
......
...@@ -295,3 +295,35 @@ def concat(*args): ...@@ -295,3 +295,35 @@ def concat(*args):
""" """
tup = Tuple(list(args)) tup = Tuple(list(args))
return _make.concat(tup) return _make.concat(tup)
def zeros_like(data):
"""Returns an array of zeros, with same type and shape as the input.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.zeros_like(data)
def ones_like(data):
"""Returns an array of ones, with same type and shape as the input.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.ones_like(data)
...@@ -56,14 +56,21 @@ RELAY_REGISTER_UNARY_OP("exp") ...@@ -56,14 +56,21 @@ RELAY_REGISTER_UNARY_OP("exp")
RELAY_REGISTER_UNARY_OP("sqrt") RELAY_REGISTER_UNARY_OP("sqrt")
.describe(R"code(Returns the sqrt input array, computed element-wise. .describe(R"code(Returns the sqrt input array, computed element-wise.
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
.. math:: RELAY_REGISTER_UNARY_OP("zeros_like")
sqrt(x) .describe(R"code(Returns an array of zeros, with same type and shape as the input.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(1) .set_support_level(1)
.add_type_rel("Identity", IdentityRel); .add_type_rel("Identity", IdentityRel);
RELAY_REGISTER_UNARY_OP("ones_like")
.describe(R"code(Returns an array of ones, with same type and shape as the input.
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
RELAY_REGISTER_UNARY_OP("sigmoid") RELAY_REGISTER_UNARY_OP("sigmoid")
.describe(R"code(Returns the sigmoid input array, computed element-wise. .describe(R"code(Returns the sigmoid input array, computed element-wise.
...@@ -75,7 +82,6 @@ RELAY_REGISTER_UNARY_OP("sigmoid") ...@@ -75,7 +82,6 @@ RELAY_REGISTER_UNARY_OP("sigmoid")
.set_support_level(1) .set_support_level(1)
.add_type_rel("Identity", IdentityRel); .add_type_rel("Identity", IdentityRel);
// Concat // Concat
TVM_REGISTER_API("relay.op._make.concat") TVM_REGISTER_API("relay.op._make.concat")
.set_body_typed<Expr(Expr)>([](Expr tuple) { .set_body_typed<Expr(Expr)>([](Expr tuple) {
......
import tvm
from tvm import relay
def test_unary_identity():
for op in [relay.zeros_like, relay.ones_like]:
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((8, 9, 4), "int32"))
with ib.function(x) as func:
ib.ret(op(x.var))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
assert ftype.ret_type == relay.TensorType((8, 9, 4), "int32")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment