/*!
 *  Copyright (c) 2017 by Contributors
 * \file Use external nnpack library call.
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
#include <nnpack.h>
#include "./nnpack_utils.h"

namespace tvm {
namespace contrib {
using namespace runtime;

TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    // Unbatched 2-D convolution through NNPACK's inference-optimized path.
    // args: [0] input  CHW float32 tensor
    //       [1] kernel (out_ch, in_ch, kh, kw) float32 tensor
    //       [2] bias   (out_ch) float32 tensor
    //       [3] output CHW float32 tensor (written in place)
    //       [4..7]  pad_top, pad_right, pad_bottom, pad_left
    //       [8..9]  stride_width, stride_height
    //       [10]    thread-count config forwarded to NNPackConfig
    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
    // nnp_initialize is idempotent; fail loudly here if this CPU is
    // unsupported rather than getting garbage from the convolution call.
    CHECK_EQ(nnp_initialize(), nnp_status_success);
    DLTensor* input  = args[0];
    DLTensor* kernel = args[1];
    DLTensor* bias   = args[2];
    DLTensor* output = args[3];
    uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7];
    nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
    uint64_t stride_width = args[8], stride_height = args[9];
    nnp_size stride_size{stride_width, stride_height};
    NNPackConfig(args[10]);

    // Rank checks: single image (CHW), 4-D kernel, 1-D bias.
    CHECK_EQ(input->ndim, 3);
    CHECK_EQ(kernel->ndim, 4);
    CHECK_EQ(bias->ndim, 1);
    CHECK_EQ(output->ndim, 3);

    // Channel agreement: input channels match kernel's second axis,
    // output channels match kernel's first axis and the bias length.
    CHECK_EQ(input->shape[0], kernel->shape[1]);
    size_t input_channels = input->shape[0];
    CHECK_EQ(output->shape[0], kernel->shape[0]);
    CHECK_EQ(output->shape[0], bias->shape[0]);
    size_t output_channels = output->shape[0];
    nnp_size input_size{static_cast<size_t>(input->shape[1]),
                        static_cast<size_t>(input->shape[2])};
    nnp_size kernel_size{static_cast<size_t>(kernel->shape[2]),
                         static_cast<size_t>(kernel->shape[3])};

    // NNPACK consumes raw dense buffers, so every tensor must use the
    // default compact layout (DLTensor strides == nullptr).  The output
    // tensor was previously left unchecked even though its raw pointer is
    // passed below — check it too.
    CHECK(input->strides == nullptr);
    CHECK(kernel->strides == nullptr);
    CHECK(bias->strides == nullptr);
    CHECK(output->strides == nullptr);

    CHECK(TypeMatch(input->dtype, kDLFloat, 32));
    CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
    CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
    CHECK(TypeMatch(output->dtype, kDLFloat, 32));

    // Propagate NNPACK failures (bad sizes, unsupported strides, OOM)
    // instead of silently returning an untouched output buffer.
    nnp_status status =
        nnp_convolution_inference(nnp_convolution_algorithm_auto,
                                  nnp_convolution_transform_strategy_block_based,
                                  input_channels,
                                  output_channels,
                                  input_size,
                                  input_padding,
                                  kernel_size,
                                  stride_size,
                                  static_cast<float*>(input->data),
                                  static_cast<float*>(kernel->data),
                                  static_cast<float*>(bias->data),
                                  static_cast<float*>(output->data),
                                  NULL,
                                  NULL,
                                  nnp_activation_identity,
                                  NULL,
                                  entry->threadpool,
                                  NULL);
    CHECK_EQ(status, nnp_status_success);
  });


TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    // Batched 2-D convolution through NNPACK's training/batched path.
    // args: [0] input  NCHW float32 tensor
    //       [1] kernel (out_ch, in_ch, kh, kw) float32 tensor
    //       [2] bias   (out_ch) float32 tensor
    //       [3] output NCHW float32 tensor (written in place)
    //       [4..7] pad_top, pad_right, pad_bottom, pad_left
    //       [8]    thread-count config forwarded to NNPackConfig
    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
    // nnp_initialize is idempotent; fail loudly here if this CPU is
    // unsupported rather than getting garbage from the convolution call.
    CHECK_EQ(nnp_initialize(), nnp_status_success);
    DLTensor* input  = args[0];
    DLTensor* kernel = args[1];
    DLTensor* bias   = args[2];
    DLTensor* output = args[3];
    uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7];
    nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
    NNPackConfig(args[8]);

    // Rank checks: batched images (NCHW), 4-D kernel, 1-D bias.
    CHECK_EQ(input->ndim, 4);
    CHECK_EQ(kernel->ndim, 4);
    CHECK_EQ(bias->ndim, 1);
    CHECK_EQ(output->ndim, 4);

    // Batch and channel agreement between input, kernel, bias and output.
    CHECK_EQ(input->shape[0], output->shape[0]);
    size_t batch_size = input->shape[0];
    CHECK_EQ(input->shape[1], kernel->shape[1]);
    size_t input_channels = input->shape[1];
    CHECK_EQ(output->shape[1], bias->shape[0]);
    CHECK_EQ(output->shape[1], kernel->shape[0]);
    size_t output_channels = output->shape[1];
    nnp_size input_size{static_cast<size_t>(input->shape[2]),
                        static_cast<size_t>(input->shape[3])};
    nnp_size kernel_size{static_cast<size_t>(kernel->shape[2]),
                         static_cast<size_t>(kernel->shape[3])};

    // NNPACK consumes raw dense buffers, so every tensor must use the
    // default compact layout (DLTensor strides == nullptr).  The output
    // tensor was previously left unchecked even though its raw pointer is
    // passed below — check it too.
    CHECK(input->strides == nullptr);
    CHECK(kernel->strides == nullptr);
    CHECK(bias->strides == nullptr);
    CHECK(output->strides == nullptr);

    CHECK(TypeMatch(input->dtype, kDLFloat, 32));
    CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
    CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
    CHECK(TypeMatch(output->dtype, kDLFloat, 32));

    // Propagate NNPACK failures (bad sizes, unsupported strides, OOM)
    // instead of silently returning an untouched output buffer.
    nnp_status status =
        nnp_convolution_output(nnp_convolution_algorithm_auto,
                               batch_size,
                               input_channels,
                               output_channels,
                               input_size,
                               input_padding,
                               kernel_size,
                               static_cast<float*>(input->data),
                               static_cast<float*>(kernel->data),
                               static_cast<float*>(bias->data),
                               static_cast<float*>(output->data),
                               NULL,
                               NULL,
                               nnp_activation_identity,
                               NULL,
                               entry->threadpool,
                               NULL);
    CHECK_EQ(status, nnp_status_success);
  });
}  // namespace contrib
}  // namespace tvm