Commit 25ded693 by ziheng, committed by Tianqi Chen

[NNPack] Support for threadpool (#334)

* [NNPack] Support for threadpool

* fix lint

* fix lint

* Use static class function
parent 989e99e6
@@ -3,6 +3,18 @@ from __future__ import absolute_import as _abs
 from .. import api as _api
 from .. import intrin as _intrin
+from .._ffi.function import _init_api
+
+
+def config(nthreads):
+    """Configure the nnpack library.
+
+    Parameters
+    ----------
+    nthreads : int
+        Number of threads in the NNPACK thread pool; must be nonnegative.
+    """
+    _Config(nthreads)
+
 def fully_connected_inference(lhs, rhs):
     """Create an extern op that compute fully connected of 1D tensor lhs and
@@ -84,8 +96,8 @@ def convolution_inference(data, kernel, bias, padding, stride):
     assert isinstance(stride, list) and len(stride) == 2
     _, input_height, input_width = data.shape
     output_channels, _, kernel_height, kernel_width = kernel.shape
-    output_height = (input_height + padding[0] + padding[1] - kernel_height) + 1
-    output_width = (input_width + padding[0] + padding[1] - kernel_width) + 1
+    output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
+    output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1
     return _api.extern(
         (output_channels, output_height, output_width), [data, kernel, bias],
@@ -131,3 +143,5 @@ def convolution_output(data, kernel, bias, padding):
         lambda ins, outs: _intrin.call_packed(
             "tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2],
             outs[0], padding[0], padding[1], padding[2], padding[3]), name="C")
+
+_init_api("tvm.contrib.nnpack")
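
With these bindings in place, the thread-pool size can be set from Python before the nnpack extern ops are built. A minimal usage sketch, assuming the 2017-era tvm.placeholder / tvm.create_schedule API and purely illustrative tensor shapes:

    import tvm
    from tvm.contrib import nnpack

    # Illustrative shapes: one 3x224x224 input, 64 output channels, 3x3 kernels.
    data = tvm.placeholder((3, 224, 224), name="data")
    kernel = tvm.placeholder((64, 3, 3, 3), name="kernel")
    bias = tvm.placeholder((64,), name="bias")

    nnpack.config(4)  # size the NNPACK thread pool used by the calling thread
    conv = nnpack.convolution_inference(data, kernel, bias,
                                        padding=[1, 1, 1, 1], stride=[1, 1])
    s = tvm.create_schedule(conv.op)

Because the pool is stored thread-locally (see nnpack_utils.cc below), config affects only the thread that calls it and can be called again to resize or rebuild the pool.
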
@@ -6,6 +6,7 @@
 #include <tvm/runtime/util.h>
 #include <dmlc/logging.h>
 #include <nnpack.h>
+#include "./nnpack_utils.h"
 namespace tvm {
 namespace contrib {

@@ -13,6 +14,7 @@ using namespace runtime;
 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
+    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
     nnp_initialize();
     DLTensor* input = args[0];
     DLTensor* kernel = args[1];
@@ -61,13 +63,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
         NULL,
         nnp_activation_identity,
         NULL,
-        NULL,
+        entry->threadpool,
         NULL);
   });

 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
+    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
     nnp_initialize();
     DLTensor* input = args[0];
     DLTensor* kernel = args[1];
@@ -112,9 +115,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
         static_cast<float*>(kernel->data),
         static_cast<float*>(bias->data),
         static_cast<float*>(output->data),
-        nnp_activation_identity,
         NULL,
         NULL,
+        nnp_activation_identity,
+        NULL,
+        entry->threadpool,
         NULL);
   });
 } // namespace contrib
@@ -5,8 +5,8 @@
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/util.h>
 #include <dmlc/logging.h>
 #include <nnpack.h>
+#include "./nnpack_utils.h"
 namespace tvm {
 namespace contrib {
@@ -16,6 +16,7 @@ using namespace runtime;
 // matrix multiplication for row major
 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
+    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
     nnp_initialize();
     DLTensor* A = args[0];
     DLTensor* B = args[1];
@@ -37,12 +38,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
         static_cast<float*>(A->data),
         static_cast<float*>(B->data),
         static_cast<float*>(C->data),
-        NULL);
+        entry->threadpool);
   });

 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
+    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
     nnp_initialize();
     DLTensor* A = args[0];
     DLTensor* B = args[1];
@@ -66,7 +68,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
         static_cast<float*>(A->data),
         static_cast<float*>(B->data),
         static_cast<float*>(C->data),
-        NULL,
+        entry->threadpool,
         NULL);
   });
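
Stripped of the TVM packing, the change above boils down to handing NNPACK a pthreadpool instead of NULL as the last argument. A standalone sketch of that pattern, assuming nnp_fully_connected_inference takes (input_channels, output_channels, input, kernel, output, threadpool) as in the call above, with illustrative sizes:

    #include <vector>
    #include <nnpack.h>
    #include <pthreadpool.h>

    int main() {
      const size_t in_ch = 256, out_ch = 128;             // illustrative sizes
      std::vector<float> input(in_ch, 1.0f);
      std::vector<float> kernel(out_ch * in_ch, 0.5f);    // row-major out_ch x in_ch
      std::vector<float> output(out_ch, 0.0f);

      nnp_initialize();                                   // required before any nnp_* compute call
      pthreadpool_t pool = pthreadpool_create(4);         // 4 workers; 0 means one per core

      // Passing the pool (instead of NULL) lets NNPACK split the work across its threads.
      nnp_fully_connected_inference(in_ch, out_ch,
                                    input.data(), kernel.data(), output.data(),
                                    pool);

      pthreadpool_destroy(pool);
      nnp_deinitialize();
      return 0;
    }
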
/*!
* Copyright (c) 2017 by Contributors
* \file Use external nnpack library call.
*/
#include "./nnpack_utils.h"
namespace tvm {
namespace contrib {
using namespace runtime;
typedef dmlc::ThreadLocalStore<NNPackThreadLocalEntry> NNPackThreadLocalStore;

NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() {
  return NNPackThreadLocalStore::Get();
}

TVM_REGISTER_GLOBAL("contrib.nnpack._Config")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
    size_t nthreads = args[0];
    if (entry->threadpool != NULL &&
        pthreadpool_get_threads_count(entry->threadpool) != nthreads) {
      pthreadpool_destroy(entry->threadpool);
      entry->threadpool = NULL;
    }
    if (entry->threadpool == NULL) {
      entry->threadpool = pthreadpool_create(nthreads);
    }
  });
} // namespace contrib
} // namespace tvm
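
The helper relies on dmlc::ThreadLocalStore, so each OS thread that calls ThreadLocal() lazily gets its own NNPackThreadLocalEntry, and therefore its own pthreadpool; different runtime threads can configure and use their pools without any locking. A minimal sketch of that pattern in isolation, using a hypothetical Counter struct in place of the entry type:

    #include <dmlc/thread_local.h>
    #include <thread>
    #include <iostream>

    // Hypothetical payload; NNPackThreadLocalEntry plays this role above.
    struct Counter {
      int hits{0};
      static Counter* ThreadLocal() {
        // One Counter per calling thread, constructed on first use.
        return dmlc::ThreadLocalStore<Counter>::Get();
      }
    };

    int main() {
      auto work = [] {
        Counter::ThreadLocal()->hits += 1;                  // touches only this thread's copy
        std::cout << Counter::ThreadLocal()->hits << "\n";  // each thread prints 1
      };
      std::thread t1(work), t2(work);
      t1.join();
      t2.join();
      return 0;
    }
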
/*!
* Copyright (c) 2017 by Contributors
* \file Use external nnpack library call.
*/
#ifndef TVM_CONTRIB_NNPACK_NNPACK_UTILS_H_
#define TVM_CONTRIB_NNPACK_NNPACK_UTILS_H_
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <nnpack.h>
namespace tvm {
namespace contrib {
using namespace runtime;
struct NNPackThreadLocalEntry {
  pthreadpool_t threadpool{NULL};
  static NNPackThreadLocalEntry* ThreadLocal();
};
} // namespace contrib
} // namespace tvm
#endif // TVM_CONTRIB_NNPACK_NNPACK_UTILS_H_