/*! * Copyright (c) 2018 by Contributors * \file resize.cc * \brief Image operators */ #include <tvm/relay/op.h> #include <tvm/relay/attrs/image.h> #include <topi/elemwise.h> #include <topi/image/resize.h> #include "../layout.h" #include "../op_common.h" namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(ResizeAttrs); bool ResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as<TensorTypeNode>(); if (data == nullptr) return false; static const Layout kNCHW("NCHW"); const ResizeAttrs* param = attrs.as<ResizeAttrs>(); CHECK(param != nullptr); const Layout in_layout(param->layout); CHECK(in_layout.Convertible(kNCHW)) << "Resize only support input layouts that are convertible from NCHW." << " But got " << in_layout; auto oshape = ConvertLayout(data->shape, in_layout, kNCHW); oshape[2] = param->size[0]; oshape[3] = param->size[1]; // assign output type reporter->Assign(types[1], TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout), data->dtype)); return true; } Array<Tensor> ResizeCompute(const Attrs& attrs, const Array<Tensor>& inputs, const Type& out_type, const Target& target) { const auto* param = attrs.as<ResizeAttrs>(); CHECK(param != nullptr); CHECK(param->layout == "NCHW" || param->layout == "NHWC"); const auto* out_ttype = out_type.as<TensorTypeNode>(); CHECK(out_ttype != nullptr); Array<IndexExpr> oshape; if (param->layout == "NCHW") { oshape.push_back(out_ttype->shape[2]); oshape.push_back(out_ttype->shape[3]); } else if (param->layout == "NHWC") { oshape.push_back(out_ttype->shape[1]); oshape.push_back(out_ttype->shape[2]); } return Array<Tensor>{ topi::image::resize(inputs[0], oshape, param->layout, param->align_corners, param->method) }; } // Positional relay function to create image operator // used by frontend FFI. Expr MakeResize(Expr data, Array<IndexExpr> size, std::string layout, std::string method, bool align_corners) { auto attrs = make_node<ResizeAttrs>(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->align_corners = align_corners; static const Op& op = Op::Get("image.resize"); return CallNode::make(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_API("relay.op.image._make.resize") .set_body([](const TVMArgs& args, TVMRetValue* rv) { runtime::detail::unpack_call<Expr, 5>(MakeResize, args, rv); }); RELAY_REGISTER_OP("image.resize") .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. - **data**: data is 4D array of shape (batch_size, channels, in_height, in_width) for NCHW (batch_size, in_height, in_width, channels) for NHWC - **out**: Output is 4D array of shape for layout NCHW (batch_size, channels, size[0], size[1]) for layout NHWC (batch_size, size[0], size[1], channels) )code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.ResizeAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) .add_type_rel("Resize", ResizeRel) .set_attr<FTVMCompute>("FTVMCompute", ResizeCompute) .set_attr<TOpPattern>("TOpPattern", kInjective); } // namespace relay } // namespace tvm