Commit c9a3e2ea by hlu1, committed by Tianqi Chen

[nnpack] Preallocate workspace buffer (#2369)

parent a5eb4451
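The change uses NNPACK's two-pass workspace protocol: calling `nnp_convolution_inference` with null data pointers makes it only write the required scratch size into `*workspace_size`. The commit queries that size once per operator invocation, allocates the buffer through TVM's `DeviceAPI::AllocWorkspace`, reuses it for every image in the batch, and frees it afterwards, instead of letting NNPACK allocate and free an internal buffer on every per-image call. Below is a minimal sketch of the pattern; the shapes, the function name, and the single-threaded `nullptr` threadpool are illustrative assumptions, not taken from the TVM source.

```cpp
// Sketch of the query-then-reuse workspace pattern from this commit.
// All geometry here is a hard-coded illustration.
#include <dmlc/logging.h>
#include <nnpack.h>
#include <tvm/runtime/device_api.h>

void BatchedConvSketch(const float* input, const float* kernel,
                       const float* bias, float* output, size_t batch) {
  CHECK_EQ(nnp_initialize(), nnp_status_success);

  // Illustrative fixed geometry: 3 -> 16 channels, 32x32 images,
  // 3x3 kernel, padding 1, stride 1 (so the output is also 32x32).
  const size_t input_channels = 3, output_channels = 16;
  const nnp_size input_size = {32, 32};
  const nnp_padding input_padding = {1, 1, 1, 1};
  const nnp_size kernel_size = {3, 3};
  const nnp_size stride = {1, 1};

  // Pass 1: with every data pointer null, NNPACK does no work and only
  // reports the scratch size it needs through workspace_size.
  size_t workspace_size = 0;
  nnp_status status = nnp_convolution_inference(
      nnp_convolution_algorithm_auto,
      nnp_convolution_transform_strategy_compute, input_channels,
      output_channels, input_size, input_padding, kernel_size, stride,
      nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
      nnp_activation_identity, nullptr, nullptr, nullptr);
  CHECK_EQ(status, nnp_status_success);

  // Pass 2: allocate the scratch once from TVM's CPU workspace pool ...
  TVMContext ctx;
  ctx.device_type = kDLCPU;
  ctx.device_id = 0;
  TVMType type_hint{kDLFloat, 32, 1};
  tvm::runtime::DeviceAPI* cpu_api = tvm::runtime::DeviceAPI::Get(ctx);
  void* workspace = cpu_api->AllocWorkspace(ctx, workspace_size, type_hint);
  CHECK(workspace != nullptr);

  // ... and hand the same buffer to every per-image call, so NNPACK no
  // longer mallocs and frees a transient buffer inside each invocation.
  for (size_t n = 0; n < batch; ++n) {
    status = nnp_convolution_inference(
        nnp_convolution_algorithm_auto,
        nnp_convolution_transform_strategy_compute, input_channels,
        output_channels, input_size, input_padding, kernel_size, stride,
        input + n * input_channels * 32 * 32, kernel, bias,
        output + n * output_channels * 32 * 32, workspace, &workspace_size,
        nnp_activation_identity, nullptr, nullptr, nullptr);
    CHECK_EQ(status, nnp_status_success);
  }
  cpu_api->FreeWorkspace(ctx, workspace);
}
```

The commit additionally rounds the queried size up to a whole number of `float` elements before calling `AllocWorkspace`, which simply keeps the allocation a multiple of `sizeof(float)`.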
@@ -2,6 +2,7 @@
  * Copyright (c) 2017 by Contributors
  * \file Use external nnpack library call.
  */
+#include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/util.h>
 #include <dmlc/logging.h>
@@ -72,6 +73,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
       zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
     }
+    size_t workspace_size = 0;
+    nnp_status status = nnp_convolution_inference(
+        algo, nnp_convolution_transform_strategy_compute, input_channels,
+        output_channels, input_size, input_padding, kernel_size, stride_size,
+        nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
+        nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+    CHECK_EQ(status, nnp_status_success);
+    // Division with rounding up, in case size is not a multiple of sizeof(float)
+    const size_t workspace_elements =
+        (workspace_size + sizeof(float) - 1) / sizeof(float);
+    TVMContext ctx = input->ctx;
+    TVMType type_hint = input->dtype;
+    DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
+    void* workspace_buffer =
+        cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
+    CHECK(workspace_buffer != nullptr);
     for (auto n = 0; n < input->shape[0]; ++n) {
       nnp_status status = nnp_convolution_inference(
           algo, nnp_convolution_transform_strategy_compute, input_channels,
@@ -85,10 +105,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
           static_cast<float *>(output->data) + n * output->shape[1] *
                                                    output->shape[2] *
                                                    output->shape[3],
-          NULL, NULL, nnp_activation_identity, NULL, entry->threadpool, NULL);
+          workspace_buffer, &workspace_size,
+          nnp_activation_identity, nullptr, entry->threadpool, nullptr);
       CHECK_EQ(status, nnp_status_success);
     }
+    cpu_api->FreeWorkspace(ctx, workspace_buffer);
   });

 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform")
@@ -147,6 +169,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform")
       zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
     }
+    size_t workspace_size = 0;
+    nnp_status status = nnp_convolution_inference(
+        algo, nnp_convolution_transform_strategy_reuse, input_channels,
+        output_channels, input_size, input_padding, kernel_size, stride_size,
+        nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
+        nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+    CHECK_EQ(status, nnp_status_success);
+    // Division with rounding up, in case size is not a multiple of sizeof(float)
+    const size_t workspace_elements =
+        (workspace_size + sizeof(float) - 1) / sizeof(float);
+    TVMContext ctx = input->ctx;
+    TVMType type_hint = input->dtype;
+    DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
+    void* workspace_buffer =
+        cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
+    CHECK(workspace_buffer != nullptr);
     for (auto n = 0; n < input->shape[0]; ++n) {
       nnp_status status = nnp_convolution_inference(
           algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels,
@@ -159,10 +200,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform")
           static_cast<float *>(output->data) + n * output->shape[1] *
                                                    output->shape[2] *
                                                    output->shape[3],
-          NULL, NULL,
-          nnp_activation_identity, NULL, entry->threadpool, NULL);
+          workspace_buffer, &workspace_size,
+          nnp_activation_identity, nullptr, entry->threadpool, nullptr);
       CHECK_EQ(status, nnp_status_success);
     }
+    cpu_api->FreeWorkspace(ctx, workspace_buffer);
   });
 TVM_REGISTER_GLOBAL(
...