Commit 201cfdc5 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] [Op] Squeeze (#1858)

parent 47b8c36d
...@@ -82,6 +82,20 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> { ...@@ -82,6 +82,20 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
} }
}; // struct InitOpAttrs }; // struct InitOpAttrs
/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
Array<IndexExpr> axes;
TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
TVM_ATTR_FIELD(axes)
.describe("The axes to squeeze in the input tensor."
"If `axes = []`, all axis of dimension 1 get squeezed;"
"Else, the dimension in axes get squeezed."
"It is an error if an axes does not has dimension 1.")
.set_default(Array<IndexExpr>({}));
}
}; // struct SqueezeAttrs
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_ #endif // TVM_RELAY_ATTRS_TRANSFORM_H_
...@@ -42,12 +42,35 @@ def transpose(data, axes=None): ...@@ -42,12 +42,35 @@ def transpose(data, axes=None):
Returns Returns
------- -------
result : relay.Expr result : relay.Expr
The reshaped result. The transposed result.
""" """
axes = axes or [] axes = axes or []
return _make.transpose(data, list(axes)) return _make.transpose(data, list(axes))
def squeeze(data, axes=None):
"""Squeeze axes in the array.
Parameters
----------
data : relay.Expr
The input data to the operator.
axes : None or List[int]
Axes to remove.
If axes = [] or = None, remove all axis of dimensions 1.
Otherwise, remove all axis in axes.
If any axis in axes has dimension that does not equal 1, it is an error.
Returns
-------
result : relay.Expr
The squeezed result.
"""
axes = axes or []
return _make.squeeze(data, list(axes))
def reshape(data, newshape): def reshape(data, newshape):
"""Reshapes the input array. """Reshapes the input array.
......
...@@ -80,8 +80,6 @@ RELAY_REGISTER_OP("expand_dims") ...@@ -80,8 +80,6 @@ RELAY_REGISTER_OP("expand_dims")
.set_support_level(1) .set_support_level(1)
.add_type_rel("ExpandDims", ExpandDimsRel); .add_type_rel("ExpandDims", ExpandDimsRel);
/* relay.concatenate */
TVM_REGISTER_NODE_TYPE(ConcatenateAttrs); TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
bool ConcatenateRel(const Array<Type>& types, bool ConcatenateRel(const Array<Type>& types,
...@@ -633,5 +631,75 @@ Examples:: ...@@ -633,5 +631,75 @@ Examples::
.set_support_level(4) .set_support_level(4)
.add_type_rel("Where", WhereRel); .add_type_rel("Where", WhereRel);
Expr MakeSqueeze(Expr data,
Array<IndexExpr> axes) {
auto attrs = make_node<SqueezeAttrs>();
attrs->axes = std::move(axes);
static const Op& op = Op::Get("squeeze");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.squeeze")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeSqueeze, args, rv);
});
bool SqueezeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
const auto* param = attrs.as<SqueezeAttrs>();
CHECK(param != nullptr);
std::vector<IndexExpr> result_shape;
// if axes is empty, squeeze all axes of dimension 1
if (param->axes.size() == 0) {
for (const auto& e : data->shape) {
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
if (*axis_ptr != 1) {
result_shape.push_back(e);
}
}
} else {
// pair up original shape with a boolean which control whether it will be in the final shape.
std::vector<std::pair<IndexExpr, bool> > original_shape;
for (const auto& e : data->shape) {
original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
}
for (const auto& e : param->axes) {
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr);
original_shape.at(*axis_ptr).second = false;
}
for (const auto p : original_shape) {
if (p.second) {
result_shape.push_back(p.first);
} else {
const int64_t* axis_ptr = as_const_int(p.first);
CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input tensor";
CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1";
}
}
}
reporter->Assign(types[1], TensorTypeNode::make(result_shape, data->dtype));
return true;
}
RELAY_REGISTER_OP("squeeze")
.describe(R"code(Squeeze the input tensor at the dimensions given by axes
- **data**: The input data to the operator.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Squeeze", SqueezeRel);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -6,6 +6,7 @@ from tvm import relay ...@@ -6,6 +6,7 @@ from tvm import relay
from tvm.relay.ir_pass import infer_type from tvm.relay.ir_pass import infer_type
from tvm.relay.ir_builder import IRBuilder, func_type from tvm.relay.ir_builder import IRBuilder, func_type
from tvm.relay.env import Environment from tvm.relay.env import Environment
from nose.tools import raises
def test_zeros_ones(): def test_zeros_ones():
for op in [relay.zeros, relay.ones]: for op in [relay.zeros, relay.ones]:
...@@ -67,6 +68,44 @@ def test_transpose_infer_type(): ...@@ -67,6 +68,44 @@ def test_transpose_infer_type():
(t, n, 100), "float32") (t, n, 100), "float32")
def test_squeeze_default_axes_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d = 1, 4, 1
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
with ib.function(x) as func:
ib.ret(relay.squeeze(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(4,), "float32")
def test_squeeze_axes_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d = 1, 4, 1
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
with ib.function(x) as func:
ib.ret(relay.squeeze(x, axes=(2,)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(1, 4), "float32")
@raises(tvm._ffi.base.TVMError)
def test_squeeze_bad_axes_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d = 1, 4, 1
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
with ib.function(x) as func:
ib.ret(relay.squeeze(x, axes=(1,)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
def test_reshape_infer_type(): def test_reshape_infer_type():
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
n, t, d1, d2 = tvm.var("n"), tvm.var("t"), 100, 20 n, t, d1, d2 = tvm.var("n"), tvm.var("t"), 100, 20
...@@ -181,3 +220,5 @@ if __name__ == "__main__": ...@@ -181,3 +220,5 @@ if __name__ == "__main__":
test_take_infer_type() test_take_infer_type()
test_full() test_full()
test_full_like() test_full_like()
test_squeeze_axes_infer_type()
test_squeeze_default_axes_infer_type()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment