Commit 25ded693 by ziheng, committed by Tianqi Chen

[NNPack] Support for threadpool (#334)

* [NNPack] Support for threadpool

* fix lint

* fix lint

* Use static class function
parent 989e99e6
@@ -3,6 +3,18 @@ from __future__ import absolute_import as _abs
 from .. import api as _api
 from .. import intrin as _intrin
 from .._ffi.function import _init_api
+def config(nthreads):
+    """Configure the nnpack library.
+
+    Parameters
+    ----------
+    nthreads : int
+        The number of threads in the NNPACK thread pool; must be nonnegative.
+    """
+    _Config(nthreads)
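For reference, a minimal sketch of how this new entry point is meant to be called from user code (assuming a TVM build with NNPACK support enabled; the thread count here is illustrative):

```python
from tvm.contrib import nnpack

# Size the NNPACK thread pool before building/running any
# nnpack-backed extern ops on this thread.
nnpack.config(nthreads=4)
```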
 def fully_connected_inference(lhs, rhs):
     """Create an extern op that computes the fully connected result of 1D tensor lhs and
@@ -84,8 +96,8 @@ def convolution_inference(data, kernel, bias, padding, stride):
     assert isinstance(stride, list) and len(stride) == 2
     _, input_height, input_width = data.shape
     output_channels, _, kernel_height, kernel_width = kernel.shape
-    output_height = (input_height + padding[0] + padding[1] - kernel_height) + 1
-    output_width = (input_width + padding[0] + padding[1] - kernel_width) + 1
+    output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
+    output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1
     return _api.extern(
         (output_channels, output_height, output_width), [data, kernel, bias],
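A quick arithmetic check of the stride-aware shape formula above (sizes are hypothetical; `//` is used here for explicit integer division):

```python
# 32x32 input, 3x3 kernel, symmetric padding of 1, stride 2:
input_height, kernel_height = 32, 3
padding, stride = [1, 1, 1, 1], [2, 2]
output_height = (input_height + padding[0] + padding[1] - kernel_height) // stride[0] + 1
assert output_height == 16  # (32 + 1 + 1 - 3) // 2 + 1
```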
@@ -131,3 +143,5 @@ def convolution_output(data, kernel, bias, padding):
         lambda ins, outs: _intrin.call_packed(
             "tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2],
             outs[0], padding[0], padding[1], padding[2], padding[3]), name="C")
+
+_init_api("tvm.contrib.nnpack")
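Putting the pieces together, a hedged end-to-end sketch of driving the convolution extern op through the new pool (TVM 0.x-era API; shapes, thread count, and the `llvm` target are illustrative assumptions):

```python
import tvm
from tvm.contrib import nnpack

nnpack.config(4)  # per-thread NNPACK pool, sized to 4 workers

data = tvm.placeholder((3, 32, 32), name="data")
kernel = tvm.placeholder((8, 3, 3, 3), name="kernel")
bias = tvm.placeholder((8,), name="bias")
# padding is a 4-element list (as the shape code above suggests); stride is [h, w]
conv = nnpack.convolution_inference(data, kernel, bias, [1, 1, 1, 1], [1, 1])
s = tvm.create_schedule(conv.op)
func = tvm.build(s, [data, kernel, bias, conv], target="llvm")
```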
@@ -6,6 +6,7 @@
 #include <tvm/runtime/util.h>
 #include <dmlc/logging.h>
 #include <nnpack.h>
+#include "./nnpack_utils.h"
 namespace tvm {
 namespace contrib {
@@ -13,6 +14,7 @@ using namespace runtime;
 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
+    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
     nnp_initialize();
     DLTensor* input = args[0];
     DLTensor* kernel = args[1];
@@ -61,13 +63,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
                               NULL,
                               nnp_activation_identity,
                               NULL,
-                              NULL,
+                              entry->threadpool,
                               NULL);
   });
 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
+    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
     nnp_initialize();
     DLTensor* input = args[0];
     DLTensor* kernel = args[1];
@@ -112,9 +115,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
                          static_cast<float*>(kernel->data),
                          static_cast<float*>(bias->data),
                          static_cast<float*>(output->data),
                          nnp_activation_identity,
                          NULL,
-                         NULL);
+                         entry->threadpool,
+                         NULL);
   });
 } // namespace contrib
@@ -5,8 +5,8 @@
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/util.h>
 #include <dmlc/logging.h>
-#include <nnpack.h>
+#include "./nnpack_utils.h"
 namespace tvm {
 namespace contrib {
@@ -16,6 +16,7 @@ using namespace runtime;
 // matrix multiplication for row major
 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
+    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
     nnp_initialize();
     DLTensor* A = args[0];
     DLTensor* B = args[1];
@@ -37,12 +38,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
                                static_cast<float*>(A->data),
                                static_cast<float*>(B->data),
                                static_cast<float*>(C->data),
-                               NULL);
+                               entry->threadpool);
   });
 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
+    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
     nnp_initialize();
     DLTensor* A = args[0];
     DLTensor* B = args[1];
@@ -66,7 +68,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
                                static_cast<float*>(A->data),
                                static_cast<float*>(B->data),
                                static_cast<float*>(C->data),
-                               NULL,
+                               entry->threadpool,
                                NULL);
   });
/*!
 *  Copyright (c) 2017 by Contributors
 * \file Use external nnpack library call.
 */
#include "./nnpack_utils.h"

namespace tvm {
namespace contrib {
using namespace runtime;

// Each OS thread keeps its own NNPackThreadLocalEntry, and thus its own pool.
typedef dmlc::ThreadLocalStore<NNPackThreadLocalEntry> NNPackThreadLocalStore;

NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() {
  return NNPackThreadLocalStore::Get();
}

TVM_REGISTER_GLOBAL("contrib.nnpack._Config")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
    size_t nthreads = args[0];
    // Recreate the pool only if the requested size differs from the current one.
    if (entry->threadpool != NULL &&
        pthreadpool_get_threads_count(entry->threadpool) != nthreads) {
      pthreadpool_destroy(entry->threadpool);
      entry->threadpool = NULL;
    }
    if (entry->threadpool == NULL) {
      entry->threadpool = pthreadpool_create(nthreads);
    }
  });
}  // namespace contrib
}  // namespace tvm
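The `_Config` handler destroys and recreates the pool only when the requested size actually changes, so repeated calls with the same value are cheap. A sketch of the observable behaviour from the Python side:

```python
from tvm.contrib import nnpack

nnpack.config(4)  # creates a 4-thread pthreadpool for the calling thread
nnpack.config(4)  # no-op: the existing pool already has 4 threads
nnpack.config(8)  # destroys the 4-thread pool, creates an 8-thread one
```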
/*!
 *  Copyright (c) 2017 by Contributors
 * \file Use external nnpack library call.
 */
#ifndef TVM_CONTRIB_NNPACK_NNPACK_UTILS_H_
#define TVM_CONTRIB_NNPACK_NNPACK_UTILS_H_
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <nnpack.h>

namespace tvm {
namespace contrib {
using namespace runtime;

struct NNPackThreadLocalEntry {
  pthreadpool_t threadpool{NULL};
  static NNPackThreadLocalEntry* ThreadLocal();
};
}  // namespace contrib
}  // namespace tvm
#endif  // TVM_CONTRIB_NNPACK_NNPACK_UTILS_H_
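Because the entry lives in a `dmlc::ThreadLocalStore`, each OS thread gets its own `NNPackThreadLocalEntry` and hence its own `pthreadpool_t`. One hedged implication for multi-threaded callers (the worker body is hypothetical):

```python
import threading
from tvm.contrib import nnpack

def worker():
    # config() only sizes the calling thread's pool, so every thread
    # that runs NNPACK-backed ops must configure its own.
    nnpack.config(2)
    # ... build and run nnpack extern ops here ...

threads = [threading.Thread(target=worker) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```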