Commit c4ebe6bd by Steven S. Lyubomirsky, committed by Tianqi Chen

[Relay][Op] Add operators full and full_like (#1845)

parent 6f420e0f
...@@ -77,7 +77,8 @@ This level enables additional math and transform operators.
tvm.relay.abs
tvm.relay.negative
tvm.relay.take
tvm.relay.full
tvm.relay.full_like
**Level 4: Broadcast and Reductions**
...
...@@ -68,6 +68,20 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
  }
};
/*! \brief Attributes used in full operator */
struct FullAttrs : public tvm::AttrsNode<FullAttrs> {
  Array<IndexExpr> shape;
  DataType dtype;

  TVM_DECLARE_ATTRS(FullAttrs, "relay.attrs.FullAttrs") {
    TVM_ATTR_FIELD(shape)
        .describe("Target shape.");
    TVM_ATTR_FIELD(dtype)
        .describe("Target data type.")
        .set_default(Int(0));  // zero-bit sentinel: infer dtype from the fill value
  }
};  // struct FullAttrs
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
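A quick way to see these two attributes in action is through the Python wrapper added later in this diff. A minimal sketch (the scalar parameter `x` and the `ir_builder` scaffolding mirror the tests at the bottom of this commit):

```python
from tvm import relay

# relay.full packs `shape` and `dtype` into FullAttrs; the fill value
# stays an ordinary tensor input to the resulting call node.
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((), "float32"))
with ib.function(x) as func:
    ib.ret(relay.full(x.var, (2, 2), "int8"))
ib.ret(func)
```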
...@@ -18,7 +18,7 @@ def expand_dims(data, axis, num_newaxis=1):
        If `axis >= 0`, it is the last axis inserted in Python's negative indexing.
    num_newaxis : int
        Number of axes to be inserted. Should be >= 0.

    Returns
    -------
...@@ -139,3 +139,44 @@ def take(data, indices, axis=None):
        The computed result.
    """
    return _make.take(data, indices, axis)
def full(fill_value, shape=(), dtype=""):
    """Fill array with scalar value.

    Parameters
    ----------
    fill_value : relay.Expr
        The value to fill. Must be a scalar.

    shape : tuple of int
        The shape of the target.

    dtype : data type, optional (defaults to the data type of the fill value)
        The data type of the target.

    Returns
    -------
    result : relay.Expr
        The resulting tensor.
    """
    return _make.full(fill_value, shape, dtype)
def full_like(data, fill_value):
    """Return an array filled with a scalar value, with the same shape
    and type as the input array.

    Parameters
    ----------
    data : relay.Expr
        The input tensor.

    fill_value : relay.Expr
        The scalar value to fill.

    Returns
    -------
    result : relay.Expr
        The resulting tensor.
    """
    return _make.full_like(data, fill_value)
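Taken together, the two wrappers cover the explicit and the shape-inheriting case. For a tensor of known static shape, `full_like(data, fill)` should behave like `full(fill, ...)` with `data`'s shape and dtype spelled out. A hedged sketch in the tests' `ir_builder` style (the equivalence is illustrative, not an API guarantee):

```python
from tvm import relay

ib = relay.ir_builder.IRBuilder()
data = ib.param("data", relay.TensorType((1, 2, 3), "float32"))
fill = ib.param("fill", relay.TensorType((), "float32"))
with ib.function(data, fill) as func:
    # For this static shape the following two calls should infer the
    # same result type, Tensor[(1, 2, 3), float32]:
    #   relay.full(fill.var, (1, 2, 3), "float32")
    #   relay.full_like(data.var, fill.var)
    ib.ret(relay.full_like(data.var, fill.var))
ib.ret(func)
```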
...@@ -404,5 +404,99 @@ Examples::
.set_support_level(2)
.add_type_rel("Take", TakeRel);
TVM_REGISTER_NODE_TYPE(FullAttrs);
bool FullRel(const Array<Type>& types,
             int num_inputs,
             const Attrs& attrs,
             const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const FullAttrs* param = attrs.as<FullAttrs>();
  const auto* fill_value = types[0].as<TensorTypeNode>();
  if (fill_value == nullptr) {
    return false;
  }

  DataType out_dtype = param->dtype;
  if (out_dtype.bits() == 0) {
    // A zero-bit dtype is the "unspecified" default; inherit the fill value's dtype.
    out_dtype = fill_value->dtype;
  }

  CHECK_EQ(fill_value->shape.size(), 0)
      << "Fill value should be a scalar but has dimension "
      << fill_value->shape.size() << ".";

  reporter->Assign(types[1], TensorTypeNode::make(param->shape, out_dtype));
  return true;
}
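In plain terms, `FullRel` enforces one check and applies one default: the fill value must be a rank-0 tensor, and a zero-bit `dtype` attribute (the `Int(0)` default from `FullAttrs`) means "inherit the fill value's dtype". A rough Python rendering of the same rule, for illustration only (`full_out_type` is hypothetical, not part of Relay):

```python
from tvm import relay

def full_out_type(fill_value_type, shape, dtype=""):
    # Mirrors FullRel: the fill value must be a scalar (rank-0 tensor).
    assert len(fill_value_type.shape) == 0, "fill value must be a scalar"
    # "" plays the role of the zero-bit Int(0) sentinel on the C++ side.
    out_dtype = dtype if dtype else fill_value_type.dtype
    return relay.TensorType(shape, out_dtype)
```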
Expr MakeFull(Expr fill_value,
              Array<IndexExpr> shape,
              DataType dtype) {
  auto attrs = make_node<FullAttrs>();
  attrs->shape = std::move(shape);
  attrs->dtype = std::move(dtype);
  static const Op& op = Op::Get("full");
  return CallNode::make(op, {fill_value}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.full")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeFull, args, rv);
});
RELAY_REGISTER_OP("full")
.describe(R"code(Fill array with scalar value.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("fill_value", "double", "The value to fill.")
.set_support_level(3)
.add_type_rel("Full", FullRel);
bool FullLikeRel(const Array<Type>& types,
                 int num_inputs,
                 const Attrs& attrs,
                 const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    return false;
  }
  const auto* fill_value = types[1].as<TensorTypeNode>();
  if (fill_value == nullptr) {
    return false;
  }

  CHECK_EQ(fill_value->shape.size(), 0)
      << "The fill value should be a scalar but has dimension "
      << fill_value->shape.size() << ".";

  // The output copies both shape and dtype from `data`.
  reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype));
  return true;
}
Expr MakeFullLike(Expr data,
                  Expr fill_value) {
  static const Op& op = Op::Get("full_like");
  return CallNode::make(op, {data, fill_value}, Attrs(), {});
}
TVM_REGISTER_API("relay.op._make.full_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeFullLike, args, rv);
});
RELAY_REGISTER_OP("full_like")
.describe(R"code(Return an scalar value array with the same shape
and type as the input array.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("fill_value", "double", "Scalar value to fill.")
.set_support_level(3)
.add_type_rel("FullLike", FullLikeRel);
}  // namespace relay
}  // namespace tvm
...@@ -113,6 +113,53 @@ def test_take_infer_type():
    verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)
def test_full():
    # default settings: match input dtype
    ib = relay.ir_builder.IRBuilder()
    x = ib.param("x", relay.TensorType((), "int8"))
    with ib.function(x) as func:
        ib.ret(relay.full(x.var, ()))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.TensorType((), "int8")

    # change the shape and dtype
    ib = relay.ir_builder.IRBuilder()
    x = ib.param("x", relay.TensorType((), "float32"))
    with ib.function(x) as func:
        ib.ret(relay.full(x.var, (1, 2), "int8"))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.TensorType((1, 2), "int8")
def test_full_like():
    # concrete shape
    ib = relay.ir_builder.IRBuilder()
    base = ib.param("base", relay.TensorType((1, 2, 3), "float32"))
    fill = ib.param("fill", relay.TensorType((), "float32"))
    with ib.function(base, fill) as func:
        ib.ret(relay.full_like(base.var, fill.var))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.TensorType((1, 2, 3), "float32")

    # symbolic shape
    ib = relay.ir_builder.IRBuilder()
    n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
    base = ib.param("base", relay.TensorType((n, c, h, w), "float32"))
    fill = ib.param("fill", relay.TensorType((), "float32"))
    with ib.function(base, fill) as func:
        ib.ret(relay.full_like(base.var, fill.var))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.TensorType((n, c, h, w), "float32")
if __name__ == "__main__":
    test_single_op()
    test_unary_identity()
...@@ -121,3 +168,5 @@ if __name__ == "__main__":
    test_transpose_infer_type()
    test_reshape_infer_type()
    test_take_infer_type()
    test_full()
    test_full_like()