Commit c9a3e2ea by hlu1, committed by Tianqi Chen

[nnpack] Preallocate workspace buffer (#2369)

parent a5eb4451
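Note on the change: this commit switches both convolution entry points to NNPACK's two-pass convention. nnp_convolution_inference is first called with null tensor and workspace pointers so it only reports the required workspace size; the buffer is then allocated once through TVM's DeviceAPI and reused for every image in the batch, instead of letting NNPACK allocate and free internally on each call. Below is a minimal standalone sketch of the same pattern (it assumes NNPACK is installed, uses std::malloc where the commit uses DeviceAPI::AllocWorkspace, and all shapes are invented for illustration; link flags vary by install):

// build (roughly): g++ -std=c++11 sketch.cc -lnnpack -lpthreadpool
#include <nnpack.h>

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
  if (nnp_initialize() != nnp_status_success) return 1;

  const size_t batch = 4, input_channels = 8, output_channels = 16;
  const nnp_size input_size = {32, 32};            // width, height
  const nnp_padding input_padding = {1, 1, 1, 1};  // keeps H x W at 32 x 32
  const nnp_size kernel_size = {3, 3};
  const nnp_size stride = {1, 1};
  const size_t hw = 32 * 32;

  std::vector<float> input(batch * input_channels * hw, 1.0f);
  std::vector<float> kernel(output_channels * input_channels * 3 * 3, 0.1f);
  std::vector<float> bias(output_channels, 0.0f);
  std::vector<float> output(batch * output_channels * hw);

  // Pass 1: null tensor and workspace pointers -> NNPACK only writes the
  // required workspace size in bytes, exactly as in the diff below.
  size_t workspace_size = 0;
  nnp_status status = nnp_convolution_inference(
      nnp_convolution_algorithm_auto,
      nnp_convolution_transform_strategy_compute,
      input_channels, output_channels, input_size, input_padding,
      kernel_size, stride, nullptr, nullptr, nullptr, nullptr,
      nullptr, &workspace_size, nnp_activation_identity, nullptr,
      /*threadpool=*/nullptr, /*profile=*/nullptr);
  if (status != nnp_status_success) return 1;

  // Allocate once up front; the commit uses DeviceAPI::AllocWorkspace here.
  void* workspace = std::malloc(workspace_size);

  // Pass 2: reuse the same buffer for every image in the batch.
  for (size_t n = 0; n < batch; ++n) {
    status = nnp_convolution_inference(
        nnp_convolution_algorithm_auto,
        nnp_convolution_transform_strategy_compute,
        input_channels, output_channels, input_size, input_padding,
        kernel_size, stride,
        input.data() + n * input_channels * hw, kernel.data(), bias.data(),
        output.data() + n * output_channels * hw,
        workspace, &workspace_size, nnp_activation_identity, nullptr,
        nullptr, nullptr);
    if (status != nnp_status_success) return 1;
  }

  std::free(workspace);
  std::printf("output[0] = %f\n", output[0]);
  return nnp_deinitialize() == nnp_status_success ? 0 : 1;
}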
@@ -2,6 +2,7 @@
  * Copyright (c) 2017 by Contributors
  * \file Use external nnpack library call.
  */
+#include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/util.h>
 #include <dmlc/logging.h>
@@ -72,6 +73,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
     zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
   }
+  size_t workspace_size = 0;
+  nnp_status status = nnp_convolution_inference(
+      algo, nnp_convolution_transform_strategy_compute, input_channels,
+      output_channels, input_size, input_padding, kernel_size, stride_size,
+      nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
+      nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+  CHECK_EQ(status, nnp_status_success);
+  // Division with rounding up, in case size is not multiple of sizeof(float)
+  const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float);
+  TVMContext ctx = input->ctx;
+  TVMType type_hint = input->dtype;
+  DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
+  void* workspace_buffer =
+      cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
+  CHECK(workspace_buffer != nullptr);
   for (auto n = 0; n < input->shape[0]; ++n) {
     nnp_status status = nnp_convolution_inference(
         algo, nnp_convolution_transform_strategy_compute, input_channels,
@@ -85,10 +105,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
         static_cast<float *>(output->data) + n * output->shape[1] *
                                                  output->shape[2] *
                                                  output->shape[3],
-        NULL, NULL, nnp_activation_identity, NULL, entry->threadpool, NULL);
+        workspace_buffer, &workspace_size,
+        nnp_activation_identity, nullptr, entry->threadpool, nullptr);
     CHECK_EQ(status, nnp_status_success);
   }
+  cpu_api->FreeWorkspace(ctx, workspace_buffer);
 });
 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform")
@@ -147,6 +169,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra
     zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
   }
+  size_t workspace_size = 0;
+  nnp_status status = nnp_convolution_inference(
+      algo, nnp_convolution_transform_strategy_reuse, input_channels,
+      output_channels, input_size, input_padding, kernel_size, stride_size,
+      nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
+      nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+  CHECK_EQ(status, nnp_status_success);
+  // Division with rounding up, in case size is not multiple of sizeof(float)
+  const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float);
+  TVMContext ctx = input->ctx;
+  TVMType type_hint = input->dtype;
+  DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
+  void* workspace_buffer =
+      cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
+  CHECK(workspace_buffer != nullptr);
   for (auto n = 0; n < input->shape[0]; ++n) {
     nnp_status status = nnp_convolution_inference(
         algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels,
@@ -159,10 +200,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra
         static_cast<float *>(output->data) + n * output->shape[1] *
                                                  output->shape[2] *
                                                  output->shape[3],
-        NULL, NULL,
-        nnp_activation_identity, NULL, entry->threadpool, NULL);
+        workspace_buffer, &workspace_size,
+        nnp_activation_identity, nullptr, entry->threadpool, nullptr);
     CHECK_EQ(status, nnp_status_success);
   }
+  cpu_api->FreeWorkspace(ctx, workspace_buffer);
 });
 TVM_REGISTER_GLOBAL(
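Finally, the per-image offsets in both loops (n * output->shape[1] * output->shape[2] * output->shape[3]) are plain NCHW batch striding: image n begins C*H*W elements past the base pointer. A small sketch restating that arithmetic (helper name and shapes are invented for illustration):

#include <cstddef>
#include <cstdio>
#include <vector>

// Image n of an NCHW float tensor starts n * C * H * W elements past the
// base pointer; this mirrors the n * shape[1] * shape[2] * shape[3]
// offsets applied to input->data and output->data in the diff above.
float* nchw_image(float* base, std::size_t n,
                  std::size_t C, std::size_t H, std::size_t W) {
  return base + n * C * H * W;
}

int main() {
  const std::size_t N = 2, C = 3, H = 4, W = 5;
  std::vector<float> tensor(N * C * H * W, 0.0f);
  // Image 1 begins 3 * 4 * 5 = 60 floats into the buffer.
  std::printf("offset of image 1: %td floats\n",
              nchw_image(tensor.data(), 1, C, H, W) - tensor.data());
  return 0;
}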