Commit fc341e1c by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] [Op] Zeros, Ones (#1885)

parent dfbe82b0
......@@ -67,7 +67,9 @@ This level enables additional math and transform operators.
.. autosummary::
:nosignatures:
tvm.relay.zeros
tvm.relay.zeros_like
tvm.relay.ones
tvm.relay.ones_like
tvm.relay.reshape
tvm.relay.copy
......@@ -155,10 +157,9 @@ Level 3 Definitions
.. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.take
Level 3 Definitions
-------------------
.. autofunction:: tvm.relay.zeros
.. autofunction:: tvm.relay.zeros_like
.. autofunction:: tvm.relay.ones
.. autofunction:: tvm.relay.ones_like
......@@ -177,6 +178,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.pow
.. autofunction:: tvm.relay.where
Level 5 Definitions
-------------------
.. autofunction:: tvm.relay.image.resize
......@@ -68,19 +68,19 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
}
};
/*! \brief Attributes used in full operator */
struct FullAttrs : public tvm::AttrsNode<FullAttrs> {
/*! \brief Attributes that specify a tensor */
struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
Array<IndexExpr> shape;
DataType dtype;
TVM_DECLARE_ATTRS(FullAttrs, "relay.attrs.FullAttrs") {
TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") {
TVM_ATTR_FIELD(shape)
.describe("Target shape.");
TVM_ATTR_FIELD(dtype)
.describe("Target data type.")
.set_default(Int(0));
}
}; // struct FullAttrs
}; // struct InitOpAttrs
} // namespace relay
} // namespace tvm
......
......@@ -484,6 +484,25 @@ def left_shift(lhs, rhs):
return _make.left_shift(lhs, rhs)
def zeros(shape, dtype):
"""Fill array with zeros.
Parameters
----------
shape : tuple of int
The shape of the target.
dtype : data type
The data type of the target.
Returns
-------
result : relay.Expr
The resulting tensor.
"""
return _make.zeros(shape, dtype)
def zeros_like(data):
"""Returns an array of zeros, with same type and shape as the input.
......@@ -500,6 +519,25 @@ def zeros_like(data):
return _make.zeros_like(data)
def ones(shape, dtype):
"""Fill array with ones.
Parameters
----------
shape : tuple of int
The shape of the target.
dtype : data type
The data type of the target.
Returns
-------
result : relay.Expr
The resulting tensor.
"""
return _make.ones(shape, dtype)
def ones_like(data):
"""Returns an array of ones, with same type and shape as the input.
......
......@@ -404,14 +404,14 @@ Examples::
.set_support_level(2)
.add_type_rel("Take", TakeRel);
TVM_REGISTER_NODE_TYPE(FullAttrs);
TVM_REGISTER_NODE_TYPE(InitOpAttrs);
bool FullRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const FullAttrs* param = attrs.as<FullAttrs>();
const InitOpAttrs* param = attrs.as<InitOpAttrs>();
const auto* fill_value = types[0].as<TensorTypeNode>();
if (fill_value == nullptr) {
return false;
......@@ -433,7 +433,7 @@ bool FullRel(const Array<Type>& types,
Expr MakeFull(Expr fill_value,
Array<IndexExpr> shape,
DataType dtype) {
auto attrs = make_node<FullAttrs>();
auto attrs = make_node<InitOpAttrs>();
attrs->shape = std::move(shape);
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("full");
......@@ -454,6 +454,61 @@ RELAY_REGISTER_OP("full")
.set_support_level(3)
.add_type_rel("Full", FullRel);
bool InitOpRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 1);
const InitOpAttrs* param = attrs.as<InitOpAttrs>();
reporter->Assign(types[0], TensorTypeNode::make(param->shape, param->dtype));
return true;
}
Expr MakeZeros(Array<IndexExpr> shape,
DataType dtype) {
auto attrs = make_node<InitOpAttrs>();
attrs->shape = std::move(shape);
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("zeros");
return CallNode::make(op, {}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.zeros")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeZeros, args, rv);
});
RELAY_REGISTER_OP("zeros")
.describe(R"code(Fill array with zeros.
)code" TVM_ADD_FILELINE)
.set_num_inputs(0)
.set_support_level(3)
.add_type_rel("InitOp", InitOpRel);
Expr MakeOnes(Array<IndexExpr> shape,
DataType dtype) {
auto attrs = make_node<InitOpAttrs>();
attrs->shape = std::move(shape);
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("ones");
return CallNode::make(op, {}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.ones")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeOnes, args, rv);
});
RELAY_REGISTER_OP("ones")
.describe(R"code(Fill array with ones.
)code" TVM_ADD_FILELINE)
.set_num_inputs(0)
.set_support_level(3)
.add_type_rel("InitOp", InitOpRel);
bool FullLikeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
......
......@@ -7,6 +7,15 @@ from tvm.relay.ir_pass import infer_type
from tvm.relay.ir_builder import IRBuilder, func_type
from tvm.relay.env import Environment
def test_zeros_ones():
for op in [relay.zeros, relay.ones]:
ib = relay.ir_builder.IRBuilder()
with ib.function() as func:
ib.ret(op((124, 50), "float64"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((124, 50), "float64")
def test_unary_identity():
for op in [relay.zeros_like, relay.ones_like]:
......@@ -162,6 +171,7 @@ def test_full_like():
if __name__ == "__main__":
test_single_op()
test_zeros_ones()
test_unary_identity()
test_clip_type()
test_copy_infer_type()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment