/*!
 *  Copyright (c) 2018 by Contributors
 * \file upsampling.cc
 * \brief upsampling operator
 */
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/build_module.h>
#include <topi/elemwise.h>
#include <topi/nn/upsampling.h>
#include <vector>
#include "../op_common.h"

namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);

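// Type relation for nn.upsampling: the output type is inferred by converting
// the input shape to NCHW, scaling the spatial axes by `scale`, and mapping
// the result back to the original layout. For example, NCHW data of shape
// (1, 3, 32, 32) with scale=2 yields an output of shape (1, 3, 64, 64).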
bool UpSamplingRel(const Array<Type>& types,
                   int num_inputs,
                   const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  static const Layout kNCHW("NCHW");

  const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
  CHECK(param != nullptr);
  const Layout in_layout(param->layout);

  auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
  CHECK(layout_converter.defined())
    << "UpSampling only support input layouts that are convertible from NCHW."
    << " But got " << in_layout;

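  // Express the input shape in NCHW so the spatial axes sit at fixed positions.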
  auto oshape = layout_converter.ForwardShape(data->shape);

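  // Scale the spatial dimensions (axes 2 and 3 in NCHW).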
  oshape.Set(2, oshape[2] * param->scale);
  oshape.Set(3, oshape[3] * param->scale);

  // assign output type
  reporter->Assign(types[1],
                   TensorTypeNode::make(layout_converter.BackwardShape(oshape),
                                        data->dtype));
  return true;
}


// Positional relay function to create upsampling operator
// used by frontend FFI.
Expr MakeUpSampling(Expr data,
                    int scale,
                    std::string layout,
                    std::string method) {
  auto attrs = make_node<UpSamplingAttrs>();
  attrs->layout = std::move(layout);
  attrs->method = std::move(method);
  attrs->scale = scale;
  static const Op& op = Op::Get("nn.upsampling");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}


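// FFI hook that exposes MakeUpSampling to the frontend as
// relay.op.nn._make.upsampling.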
TVM_REGISTER_API("relay.op.nn._make.upsampling")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 4>(MakeUpSampling, args, rv);
  });


RELAY_REGISTER_OP("nn.upsampling")
.describe(R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation.

- **data**: data is a 4D array of shape
            (batch_size, channels, in_height, in_width) for layout NCHW, or
            (batch_size, in_height, in_width, channels) for layout NHWC

- **out**: Output is a 4D array of shape
           (batch_size, channels, in_height*scale, in_width*scale) for layout NCHW, or
           (batch_size, in_height*scale, in_width*scale, channels) for layout NHWC

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.UpSamplingAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("UpSampling", UpSamplingRel)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const Attrs& attrs,
                    const Array<Tensor>& inputs,
                    const Type& out_type,
                    const Target& target) {
    const auto* uattrs = attrs.as<UpSamplingAttrs>();
    CHECK(uattrs != nullptr);
    auto out_tt = out_type.as<TensorTypeNode>();
    CHECK(out_tt) << "expected a tensor type: " << out_type;
    CHECK(uattrs->layout == "NCHW" || uattrs->layout == "NHWC")
      << "unknown layout: " << uattrs->layout;

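    // Collect the expected output spatial extent (height, width) from the
    // inferred output type; topi::nn::upsampling resizes to these sizes.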
    Array<HalideIR::Expr> oshape;
    if (uattrs->layout == "NCHW") {
      oshape.push_back(out_tt->shape[2]);
      oshape.push_back(out_tt->shape[3]);
    } else if (uattrs->layout == "NHWC") {
      oshape.push_back(out_tt->shape[1]);
      oshape.push_back(out_tt->shape[2]);
    }

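    // Dispatch to the TOPI upsampling kernel with the requested
    // interpolation method.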
    return Array<Tensor>{
      topi::nn::upsampling(
        inputs[0],
        oshape,
        uattrs->layout,
        uattrs->method)
    };
});

}  // namespace relay
}  // namespace tvm