/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file upsampling.cc
 * \brief Relay upsampling operators (nn.upsampling and nn.upsampling3d).
 */
#include <tvm/tir/data_layout.h>
25 26
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
27
#include <tvm/relay/op_attr_types.h>
28 29
#include <vector>
#include "../op_common.h"
30 31 32 33 34

namespace tvm {
namespace relay {

// Register the attribute node types used by the 2D and 3D upsampling operators.
TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);
TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs);
36

37 38 39 40 41
/*!
 * \brief Infer the input/output layout of an upsampling operator.
 *
 * When a new input layout is supplied, it is adopted (by rewriting the `layout`
 * field of the attributes in place) if all of the following hold:
 *  - the 'W' and 'H' axes sit at the same index as in the currently stored layout;
 *  - the new layout contains neither a 'w' nor an 'h' axis;
 *  - the new layout has no 'D' axis, or its 'D' axis sits at the same index as in
 *    the stored layout and the new layout contains no 'd' axis.
 * Otherwise the stored layout is kept unchanged.
 *
 * \tparam T Attribute node type; must expose a `layout` field.
 * \param attrs Operator attributes. Mutated in place when the new layout is adopted.
 * \param new_in_layouts Proposed input layouts; may be undefined. If defined it must
 *        hold exactly one layout.
 * \param old_in_layouts Previous input layouts (unused).
 * \param old_in_types Previous input types (unused).
 * \return Two single-element layout arrays, both holding the inferred layout.
 */
template <typename T>
Array<Array<Layout> > UpsamplingInferCorrectLayout(
    const Attrs& attrs,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types) {
  // NOTE: Discard "const" qualifier here so that the layout can be rewritten in place.
  T* params = const_cast<T*>(attrs.as<T>());
  // as<T>() yields nullptr on a type mismatch; fail loudly instead of dereferencing it.
  CHECK(params != nullptr);

  if (new_in_layouts.defined()) {
    CHECK_EQ(new_in_layouts.size(), 1);

    Layout raw_layout(params->layout);
    Layout input = new_in_layouts[0];

    // True when `axis_name` has the same index in the new and the stored layout.
    auto same_position = [&](char axis_name) {
      const auto& axis = LayoutAxis::Get(axis_name);
      return input.IndexOf(axis) == raw_layout.IndexOf(axis);
    };
    // True when the new layout contains `axis_name`.
    auto has_axis = [&](char axis_name) {
      return input.Contains(LayoutAxis::Get(axis_name));
    };

    const bool hw_compatible = same_position('W') && same_position('H') &&
                               !has_axis('w') && !has_axis('h');
    // The depth axis only matters when the new layout actually has one.
    const bool d_compatible = input.IndexOf(LayoutAxis::Get('D')) == -1 ||
                              (same_position('D') && !has_axis('d'));

    if (hw_compatible && d_compatible) {
      params->layout = input.name();  // modify self to follow the input layout
    }
  }

  Layout inferred_layout(params->layout);
  return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
}

65 66 67 68 69 70 71 72 73 74 75 76 77
bool UpSamplingRel(const Array<Type>& types,
                   int num_inputs,
                   const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  static const Layout kNCHW("NCHW");

  const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
  CHECK(param != nullptr);
  const Layout in_layout(param->layout);
78

79
  auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
80
  CHECK(layout_converter.defined())
81 82 83
    << "UpSampling only support input layouts that are convertible from NCHW."
    << " But got " << in_layout;

84
  auto oshape = layout_converter.ForwardShape(data->shape);
85 86
  oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h)));
  oshape.Set(3, tir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w)));
87 88 89

  // assign output type
  reporter->Assign(types[1],
90
                   TensorType(layout_converter.BackwardShape(oshape),
91 92 93 94 95 96 97 98
                                        data->dtype));
  return true;
}


/*!
 * \brief Positional helper that builds a call to "nn.upsampling"; used by the
 *        frontend FFI.
 *
 * \param data The input expression.
 * \param scale_h Stored as the `scale_h` attribute.
 * \param scale_w Stored as the `scale_w` attribute.
 * \param layout Stored as the `layout` attribute.
 * \param method Stored as the `method` attribute.
 * \param align_corners Stored as the `align_corners` attribute.
 * \return A Call to the "nn.upsampling" op with `data` as its only argument.
 */
Expr MakeUpSampling(Expr data,
                    double scale_h,
                    double scale_w,
                    std::string layout,
                    std::string method,
                    bool align_corners) {
  auto upsampling_attrs = make_object<UpSamplingAttrs>();
  upsampling_attrs->scale_h = scale_h;
  upsampling_attrs->scale_w = scale_w;
  upsampling_attrs->align_corners = align_corners;
  upsampling_attrs->layout = std::move(layout);
  upsampling_attrs->method = std::move(method);

  static const Op& op = Op::Get("nn.upsampling");
  return Call(op, {data}, Attrs(upsampling_attrs), {});
}

114
// Register MakeUpSampling as the global function "relay.op.nn._make.upsampling".
TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling").set_body_typed(MakeUpSampling);
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132


// Operator registration for 2D upsampling: one tensor input, typed by UpSamplingRel.
RELAY_REGISTER_OP("nn.upsampling")
    .describe(R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation.

- **data**: data is 4D array of shape
            (batch_size, channels, in_height, in_width) for NCHW
            (batch_size, in_height, in_width, channels) for NHWC

- **out**: Output is 4D array of shape
           for layout NCHW
           (batch_size, channels, in_height*scale, in_width*scale)

           for layout NHWC
           (batch_size, in_height*scale, in_width*scale, channels)

)code" TVM_ADD_FILELINE)
    .set_attrs_type<UpSamplingAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(2)
    .add_type_rel("UpSampling", UpSamplingRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   UpsamplingInferCorrectLayout<UpSamplingAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kInjective);
141

142

143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
// UpSampling3D
bool UpSampling3DRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  static const Layout kNCDHW("NCDHW");

  const UpSampling3DAttrs* param = attrs.as<UpSampling3DAttrs>();
  CHECK(param != nullptr);
  const Layout in_layout(param->layout);

158
  auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW);
159 160 161 162 163
  CHECK(layout_converter.defined())
    << "UpSampling3D only support input layouts that are convertible from NCDHW."
    << " But got " << in_layout;

  auto oshape = layout_converter.ForwardShape(data->shape);
164 165 166
  oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d)));
  oshape.Set(3, tir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h)));
  oshape.Set(4, tir::CastNode::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w)));
167 168 169

  // assign output type
  reporter->Assign(types[1],
170
                   TensorType(layout_converter.BackwardShape(oshape),
171 172 173 174 175 176 177 178 179 180 181 182 183
                                        data->dtype));
  return true;
}

/*!
 * \brief Positional helper that builds a call to "nn.upsampling3d"; used by the
 *        frontend FFI.
 *
 * \param data The input expression.
 * \param scale_d Stored as the `scale_d` attribute.
 * \param scale_h Stored as the `scale_h` attribute.
 * \param scale_w Stored as the `scale_w` attribute.
 * \param layout Stored as the `layout` attribute.
 * \param method Stored as the `method` attribute.
 * \param coordinate_transformation_mode Stored as the attribute of the same name.
 * \return A Call to the "nn.upsampling3d" op with `data` as its only argument.
 */
Expr MakeUpSampling3D(Expr data,
                      double scale_d,
                      double scale_h,
                      double scale_w,
                      std::string layout,
                      std::string method,
                      std::string coordinate_transformation_mode) {
  auto attrs = make_object<UpSampling3DAttrs>();
  attrs->layout = std::move(layout);
  attrs->method = std::move(method);
  attrs->scale_d = scale_d;
  attrs->scale_h = scale_h;
  attrs->scale_w = scale_w;
  // Taken by value, so move it like the other string arguments instead of copying.
  attrs->coordinate_transformation_mode = std::move(coordinate_transformation_mode);
  static const Op& op = Op::Get("nn.upsampling3d");
  return Call(op, {data}, Attrs(attrs), {});
}

195
// Register MakeUpSampling3D as the global function "relay.op.nn._make.upsampling3d".
TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d").set_body_typed(MakeUpSampling3D);


// Operator registration for 3D upsampling: one tensor input, typed by UpSampling3DRel.
RELAY_REGISTER_OP("nn.upsampling3d")
    .describe(R"code(Perform upsampling on input array with nearest neighbour or
bilinear interpolation.

- **data**: data is 5D array of shape
            (batch_size, channels, in_depth, in_height, in_width) for NCDHW
            (batch_size, in_depth, in_height, in_width, channels) for NDHWC

- **out**: Output is 5D array of shape
           for layout NCDHW
           (batch_size, channels, in_depth*scale, in_height*scale, in_width*scale)

           for layout NDHWC
           (batch_size, in_depth*scale, in_height*scale, in_width*scale, channels)

)code" TVM_ADD_FILELINE)
    .set_attrs_type<UpSampling3DAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(2)
    .add_type_rel("UpSampling3D", UpSampling3DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   UpsamplingInferCorrectLayout<UpSampling3DAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kInjective);

224 225
}  // namespace relay
}  // namespace tvm