/*!
 *  Copyright (c) 2018 by Contributors
 * \file convolution.cc
 * \brief Convolution operators
 */
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <vector>
9

10
#include "../../pass/alter_op_layout.h"
11
#include "../layout.h"
12 13 14 15

namespace tvm {
namespace relay {

16
// relay.nn.conv2d
17
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
18 19 20 21 22 23 24 25 26 27 28 29

bool Conv2DRel(const Array<Type>& types,
               int num_inputs,
               const Attrs& attrs,
               const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[1].as<TensorTypeNode>();
  if (data == nullptr) return false;
  static const Layout kNCHW("NCHW");
  static const Layout kOIHW("OIHW");

30
  const Conv2DAttrs* param = attrs.as<Conv2DAttrs>();
31 32
  CHECK(param != nullptr);
  const Layout in_layout(param->data_layout);
33
  const Layout kernel_layout(param->kernel_layout);
34
  CHECK(in_layout.Convertible(kNCHW))
35 36
    << "Conv only support input layouts that are convertible from NCHW."
    << " But got " << in_layout;
37
  CHECK(kernel_layout.Convertible(kOIHW))
38 39 40
    << "Conv only support kernel layouts that are convertible from OIHW."
    << " But got "<< kernel_layout;

41
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
42
  CHECK(out_layout.Convertible(kNCHW))
43 44 45
      << "Conv only support output layouts that are convertible from NCHW."
      << " But got " << out_layout;

46 47 48
  std::vector<IndexExpr> dshape_nchw = ConvertLayout(
      data->shape, in_layout, kNCHW);

49 50 51 52 53 54
  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
  // infer weight if the kernel_size and channels are defined
  if (param->kernel_size.defined() && param->channels.defined()) {
    CHECK_EQ(param->kernel_size.size(), 2);
    CHECK_EQ(param->dilation.size(), 2);
    std::vector<IndexExpr> wshape(
55
       {param->channels,
56
         dshape_nchw[1] / param->groups,
57 58
         param->kernel_size[0],
         param->kernel_size[1]});
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
    wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
    channels = param->channels;
    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
    // assign result to reporter
    reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
  } else {
    // use weight to infer the conv shape.
    if (weight == nullptr) return false;
    auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW);
    if (param->kernel_size.defined()) {
      CHECK_EQ(param->kernel_size.size(), 2);
      // check the size
      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
          << "Conv2D: shape of weight is inconsistent with kernel_size, "
          << " kernel_size=" << param->kernel_size
          << " wshape=" << Array<IndexExpr>(wshape);
    }
    if (param->channels.defined()) {
      CHECK(reporter->AssertEQ(param->channels, wshape[0]))
          << "Conv2D: shape of weight is inconsistent with channels, "
          << " channels=" << param->channels
          << " wshape=" << Array<IndexExpr>(wshape);
    }
84
    CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
85 86 87 88 89
    channels = wshape[0];
    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
  }
  // dilation
90
  std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
91

92 93
  oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
  oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
94 95 96 97 98 99 100 101 102 103
  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = data->dtype;
  }
  oshape = ConvertLayout(oshape, kNCHW, out_layout);
  // assign output type
  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
  return true;
}

104 105 106 107 108 109 110 111 112 113
/*!
 * \brief Layout inference shared by the conv2d-like operators.
 *
 * The result is taken purely from the attributes: the input layouts are
 * {data_layout, kernel_layout} and the output layout is out_layout, or
 * data_layout when out_layout is empty. The layout/shape arguments are ignored.
 *
 * \tparam T Attribute node type providing data_layout, kernel_layout, out_layout.
 */
template<typename T>
Array<Array<Layout> > Conv2DInferCorrectLayout(
    const Attrs& attrs,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
    const Array<Array<IndexExpr>> &old_in_shapes) {
  const T* params = attrs.as<T>();

  // Other operators are made to fit the layouts of convolution layers,
  // so nothing is derived from the incoming layouts here.
  const auto& result_layout =
      params->out_layout == "" ? params->data_layout : params->out_layout;
  Array<Layout> input_layouts{params->data_layout, params->kernel_layout};
  Array<Layout> output_layouts{result_layout};
  return Array<Array<Layout> >{input_layouts, output_layouts};
}
118 119 120 121 122 123 124 125 126 127 128 129

// Positional relay function to create conv2d operator
// used by frontend FFI.
Expr MakeConv2D(Expr data,
                Expr weight,
                Array<IndexExpr> strides,
                Array<IndexExpr> padding,
                Array<IndexExpr> dilation,
                int groups,
                IndexExpr channels,
                Array<IndexExpr> kernel_size,
                std::string data_layout,
130
                std::string kernel_layout,
131 132
                std::string out_layout,
                DataType out_dtype) {
133
  auto attrs = make_node<Conv2DAttrs>();
134 135 136 137
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->dilation = std::move(dilation);
  attrs->groups = groups;
138 139
  attrs->channels = std::move(channels);
  attrs->kernel_size = std::move(kernel_size);
140
  attrs->data_layout = std::move(data_layout);
141
  attrs->kernel_layout = std::move(kernel_layout);
142 143
  attrs->out_layout = std::move(out_layout);
  attrs->out_dtype = std::move(out_dtype);
144
  static const Op& op = Op::Get("nn.conv2d");
145 146 147 148
  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}


149
TVM_REGISTER_API("relay.op.nn._make.conv2d")
150 151 152 153 154
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 12>(MakeConv2D, args, rv);
  });


155
// Operator registration for nn.conv2d: two inputs (data, weight), typed by Conv2DRel.
RELAY_REGISTER_OP("nn.conv2d")
.describe(R"code(2D convolution layer (e.g. spatial convolution over images).

This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.

- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
- **out**:  This depends on the `layout` parameter. Output is 4D array of shape
            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DAttrs")
.set_support_level(2)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>)
.add_type_rel("Conv2D", Conv2DRel);
175

176

177
// relay.nn.conv2d_transpose
178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);

bool Conv2DTransposeRel(const Array<Type>& types,
                        int num_inputs,
                        const Attrs& attrs,
                        const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[1].as<TensorTypeNode>();
  if (data == nullptr) return false;

  static const Layout kNCHW("NCHW");
  static const Layout kOIHW("OIHW");

  const Conv2DTransposeAttrs* param = attrs.as<Conv2DTransposeAttrs>();
  CHECK(param != nullptr);
  const Layout in_layout(param->data_layout);
195
  const Layout kernel_layout(param->kernel_layout);
196
  CHECK(in_layout.Convertible(kNCHW))
197 198
    << "Conv only support input layouts that are convertible from NCHW."
    << " But got " << in_layout;
199
  CHECK(kernel_layout.Convertible(kOIHW))
200 201 202
    << "Conv only support kernel layouts that are convertible from OIHW."
    << " But got "<< kernel_layout;

203
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
204 205 206 207
  CHECK(out_layout.Convertible(kNCHW))
    << "Conv only support output layouts that are convertible from NCHW."
    << " But got " << out_layout;

208
  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
209 210 211

  auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);

212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263
  // infer weight if the kernel_size and channels are defined
  if (param->kernel_size.defined() && param->channels.defined()) {
    CHECK_EQ(param->kernel_size.size(), 2);
    CHECK_EQ(param->dilation.size(), 2);

    std::vector<IndexExpr> wshape({dshape_nchw[1],
                                   param->channels / param->groups,
                                   param->kernel_size[0],
                                   param->kernel_size[1]});

    wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
    channels = param->channels;

    // assign result to reporter
    reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
  } else {
    // use weight to infer the conv shape.
    if (weight == nullptr) return false;
    auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW);
    if (param->kernel_size.defined()) {
      CHECK_EQ(param->kernel_size.size(), 2);
      // check the size
      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
          << "Conv2D: shape of weight is inconsistent with kernel_size, "
          << " kernel_size=" << param->kernel_size
          << " wshape=" << Array<IndexExpr>(wshape);
    }
    if (param->channels.defined()) {
      CHECK(reporter->AssertEQ(param->channels, wshape[1]))
          << "Conv2D: shape of weight is inconsistent with channels, "
          << " channels=" << param->channels
          << " wshape=" << Array<IndexExpr>(wshape);
    }
    CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[0]));
    channels = wshape[1];
    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
  }
  // dilation
  std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
  oshape[2] = (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
               2 * param->padding[0] + param->output_padding[0]);
  oshape[3] = (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
               2 * param->padding[1] + param->output_padding[1]);

  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = data->dtype;
  }
264
  oshape = ConvertLayout(oshape, kNCHW, out_layout);
265 266 267 268 269 270 271 272 273 274 275 276 277 278
  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
  return true;
}


// Builds a call to nn.conv2d_transpose from positional arguments.
// Exposed to the frontend through the FFI registration below.
Expr MakeConv2DTranspose(Expr data,
                         Expr weight,
                         Array<IndexExpr> strides,
                         Array<IndexExpr> padding,
                         Array<IndexExpr> dilation,
                         int groups,
                         IndexExpr channels,
                         Array<IndexExpr> kernel_size,
                         std::string data_layout,
                         std::string kernel_layout,
                         Array<IndexExpr> output_padding,
                         DataType out_dtype) {
  static const Op& op = Op::Get("nn.conv2d_transpose");
  auto attrs = make_node<Conv2DTransposeAttrs>();
  // sliding-window parameters
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->output_padding = std::move(output_padding);
  attrs->dilation = std::move(dilation);
  // kernel description
  attrs->groups = groups;
  attrs->channels = std::move(channels);
  attrs->kernel_size = std::move(kernel_size);
  // layouts and result dtype
  attrs->data_layout = std::move(data_layout);
  attrs->kernel_layout = std::move(kernel_layout);
  attrs->out_dtype = std::move(out_dtype);
  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}


TVM_REGISTER_API("relay.op.nn._make.conv2d_transpose")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    // 12 is the number of positional parameters of MakeConv2DTranspose.
    runtime::detail::unpack_call<Expr, 12>(MakeConv2DTranspose, args, rv);
  });

// Operator registration for nn.conv2d_transpose: two inputs (data, weight),
// typed by Conv2DTransposeRel.
RELAY_REGISTER_OP("nn.conv2d_transpose")
.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).

The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
of a normal convolution, i.e., from something that has the shape of the
output of some convolution to something that has the shape of its input
while maintaining a connectivity pattern that is compatible with
said convolution.

- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (in_channels, channels, kernel_size[0], kernel_size[1])
- **bias**: (channels,)
- **out**:  This depends on the `layout` parameter. Output is 4D array of shape
            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.

            out_height and out_width are calculated as::
                out_height = (height-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
                out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DTransposeAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                               Conv2DInferCorrectLayout<Conv2DTransposeAttrs>)
.add_type_rel("Conv2DTranspose", Conv2DTransposeRel);

334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455

// relay.nn.contrib_conv2d_winograd_without_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs);

bool Conv2DWinogradRel(const Array<Type>& types,
                       int num_inputs,
                       const Attrs& attrs,
                       const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  static const Layout kNCHW("NCHW");
  static const Layout kOIHW("OIHW");

  const Conv2DWinogradAttrs* param = attrs.as<Conv2DWinogradAttrs>();
  CHECK(param != nullptr);
  const Layout in_layout(param->data_layout);
  const Layout kernel_layout(param->kernel_layout);
  CHECK(in_layout.Convertible(kNCHW))
    << "Conv only support input layouts that are convertible from NCHW."
    << " But got " << in_layout;
  CHECK(kernel_layout.Convertible(kOIHW))
    << "Conv only support kernel layouts that are convertible from OIHW."
    << " But got "<< kernel_layout;

  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  CHECK(out_layout.Convertible(kNCHW))
      << "Conv only support output layouts that are convertible from NCHW."
      << " But got " << out_layout;

  std::vector<IndexExpr> dshape_nchw = ConvertLayout(
      data->shape, in_layout, kNCHW);

  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;

  CHECK(param->kernel_size.defined() && param->channels.defined())
      << "The kernel size and channels of a Conv must be set or infered by previous pass";

  CHECK_EQ(param->kernel_size.size(), 2);
  CHECK_EQ(param->dilation.size(), 2);

  channels = param->channels;
  dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
  dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];

  // NOTE: Do not check weight shape here!
  // Different backend requires different layout to compute
  // the batch gemm stage in winograd efficiently, but we want to
  // make this op work for all backends.
  // So we accept all weight shapes, and assume the TOPI developers
  // can handle this correctly in alter_op_layout.

  // dilation
  std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});

  oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
  oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = data->dtype;
  }
  oshape = ConvertLayout(oshape, kNCHW, out_layout);
  // assign output type
  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
  return true;
}


// Builds a call to nn.contrib_conv2d_winograd_without_weight_transform from
// positional arguments. Exposed to the frontend through the FFI registration below.
Expr MakeConv2DWinograd(Expr data,
                        Expr weight,
                        int tile_size,
                        Array<IndexExpr> strides,
                        Array<IndexExpr> padding,
                        Array<IndexExpr> dilation,
                        int groups,
                        IndexExpr channels,
                        Array<IndexExpr> kernel_size,
                        std::string data_layout,
                        std::string kernel_layout,
                        std::string out_layout,
                        DataType out_dtype) {
  static const Op& op = Op::Get("nn.contrib_conv2d_winograd_without_weight_transform");
  auto attrs = make_node<Conv2DWinogradAttrs>();
  // plain integer attributes
  attrs->tile_size = tile_size;
  attrs->groups = groups;
  // kernel description
  attrs->channels = channels;
  attrs->kernel_size = std::move(kernel_size);
  // sliding-window parameters
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->dilation = std::move(dilation);
  // layouts and result dtype
  attrs->data_layout = std::move(data_layout);
  attrs->kernel_layout = std::move(kernel_layout);
  attrs->out_layout = std::move(out_layout);
  attrs->out_dtype = std::move(out_dtype);
  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}


TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    // 13 is the number of positional parameters of MakeConv2DWinograd.
    runtime::detail::unpack_call<Expr, 13>(MakeConv2DWinograd, args, rv);
  });


// Operator registration for the winograd convolution that consumes an
// already-transformed weight: two inputs (data, weight), typed by Conv2DWinogradRel.
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
.describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout.
                 This operator assumes the weight tensor is already pre-transformed by
                 nn.contrib_conv2d_winograd_weight_transform.

- **data**: Input is 4D array of shape  (batch_size, in_channels, height, width)
- **weight**: Any shape
            We do not check the shape for this input tensor. Since different backend
            has different layout strategy.

- **out**:  Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" TVM_ADD_FILELINE)
// The key follows the "relay.attrs.<AttrsType>" convention used by the other
// registrations in this file.
.set_attrs_type_key("relay.attrs.Conv2DWinogradAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DWinograd", Conv2DWinogradRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
        Conv2DInferCorrectLayout<Conv2DWinogradAttrs>);

// relay.nn.contrib_conv2d_winograd_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);

bool Conv2DWinogradWeightTransformRel(const Array<Type>& types,
                                      int num_inputs,
                                      const Attrs& attrs,
                                      const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const Conv2DWinogradWeightTransformAttrs* param = attrs.as<Conv2DWinogradWeightTransformAttrs>();
  CHECK(param != nullptr);

  CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";

  // each pad width element should be a pair of positive integers
  std::vector<IndexExpr> oshape {
      param->tile_size + data->shape[2] - 1,
      param->tile_size + data->shape[3] - 1,
      data->shape[0],
      data->shape[1],
  };

  reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
                                                  data->dtype));
  return true;
}

// Builds a call to nn.contrib_conv2d_winograd_weight_transform.
// Exposed to the frontend through the FFI registration below.
Expr MakeConv2DWinogradWeightTransform(Expr weight,
                                       int tile_size) {
  static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform");
  auto attrs = make_node<Conv2DWinogradWeightTransformAttrs>();
  attrs->tile_size = tile_size;
  return CallNode::make(op, {weight}, Attrs(attrs), {});
}


TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    // 2 is the number of positional parameters of MakeConv2DWinogradWeightTransform.
    runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradWeightTransform, args, rv);
  });


// Operator registration for the winograd weight transform: a single input
// (weight), typed by Conv2DWinogradWeightTransformRel.
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
.describe(R"code(Weight transformation of winograd fast convolution algorithm.

Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
weight transformation in advance.

- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DWinogradWeightTransformAttrs")
.set_support_level(10)
.set_num_inputs(1)
.add_argument("weight", "Tensor", "The weight tensor.")
.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);

519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571

// Builds a call to nn.contrib_conv2d_NCHWc from positional arguments.
// Exposed to the frontend through the FFI registration below.
Expr MakeConv2DNCHWc(Expr data,
                     Expr kernel,
                     Array<IndexExpr> strides,
                     Array<IndexExpr> padding,
                     Array<IndexExpr> dilation,
                     int groups,
                     IndexExpr channels,
                     Array<IndexExpr> kernel_size,
                     std::string data_layout,
                     std::string kernel_layout,
                     std::string out_layout,
                     DataType out_dtype) {
  static const Op& op = Op::Get("nn.contrib_conv2d_NCHWc");
  // The op reuses the plain conv2d attribute node.
  auto attrs = make_node<Conv2DAttrs>();
  // kernel description
  attrs->groups = groups;
  attrs->channels = channels;
  attrs->kernel_size = std::move(kernel_size);
  // sliding-window parameters
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->dilation = std::move(dilation);
  // layouts and result dtype
  attrs->data_layout = std::move(data_layout);
  attrs->kernel_layout = std::move(kernel_layout);
  attrs->out_layout = std::move(out_layout);
  attrs->out_dtype = std::move(out_dtype);
  return CallNode::make(op, {data, kernel}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_NCHWc")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    // 12 is the number of positional parameters of MakeConv2DNCHWc.
    runtime::detail::unpack_call<Expr, 12>(MakeConv2DNCHWc, args, rv);
  });


// Operator registration for nn.contrib_conv2d_NCHWc: two inputs (data, weight).
// It shares Conv2DAttrs and Conv2DRel with nn.conv2d.
RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
- **data**: Input is 5D packed tensor.
- **weight**: 6D packed tensor.

- **out**:  Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
// Same attribute node as nn.conv2d, so the same type key is used.
.set_attrs_type_key("relay.attrs.Conv2DAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2D", Conv2DRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
        Conv2DInferCorrectLayout<Conv2DAttrs>);


572 573
}  // namespace relay
}  // namespace tvm