Commit 182a7852 by ziheng Committed by GitHub

[NNPACK] Add argument nthreads (#631)

parent 35485307
......@@ -16,7 +16,7 @@ def config(nthreads):
"""
_Config(nthreads)
def fully_connected_inference(lhs, rhs):
def fully_connected_inference(lhs, rhs, nthreads=1):
"""Create an extern op that compute fully connected of 1D tensor lhs and
2D tensor rhs with nnpack.
......@@ -37,9 +37,9 @@ def fully_connected_inference(lhs, rhs):
(m, ), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.fully_connected_inference",
ins[0], ins[1], outs[0]), name="C")
ins[0], ins[1], outs[0], nthreads), name="C")
def fully_connected_output(lhs, rhs):
def fully_connected_output(lhs, rhs, nthreads=1):
"""Create an extern op that compute fully connected of 2D tensor lhs and
2D tensor rhs with nnpack.
......@@ -61,9 +61,9 @@ def fully_connected_output(lhs, rhs):
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.fully_connected_output",
ins[0], ins[1], outs[0]), name="C")
ins[0], ins[1], outs[0], nthreads), name="C")
def convolution_inference(data, kernel, bias, padding, stride):
def convolution_inference(data, kernel, bias, padding, stride, nthreads=1):
"""Create an extern op to do inference convolution of 3D tensor data and
4D tensor kernel and 1D tensor bias with nnpack.
......@@ -104,9 +104,9 @@ def convolution_inference(data, kernel, bias, padding, stride):
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.convolution_inference", ins[0], ins[1], ins[2],
outs[0], padding[0], padding[1], padding[2], padding[3],
stride[0], stride[1]), name="C")
stride[0], stride[1], nthreads), name="C")
def convolution_output(data, kernel, bias, padding):
def convolution_output(data, kernel, bias, padding, nthreads=1):
"""Create an extern op to compute convolution of 4D tensor data and
4D tensor kernel and 1D tensor bias with nnpack.
......@@ -142,6 +142,6 @@ def convolution_output(data, kernel, bias, padding):
(batch, output_channels, output_height, output_width), [data, kernel, bias],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2],
outs[0], padding[0], padding[1], padding[2], padding[3]), name="C")
outs[0], padding[0], padding[1], padding[2], padding[3], nthreads), name="C")
_init_api("tvm.contrib.nnpack")
......@@ -24,6 +24,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
uint64_t stride_width = args[8], stride_height = args[9];
nnp_size stride_size{stride_width, stride_height};
NNPackConfig(args[10]);
CHECK_EQ(input->ndim, 3);
CHECK_EQ(kernel->ndim, 4);
......@@ -80,6 +81,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
DLTensor* output = args[3];
uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7];
nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
NNPackConfig(args[8]);
CHECK_EQ(input->ndim, 4);
CHECK_EQ(kernel->ndim, 4);
......
......@@ -21,6 +21,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
NNPackConfig(args[3]);
CHECK_EQ(A->ndim, 1);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 1);
......@@ -49,6 +51,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
NNPackConfig(args[3]);
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
......
......@@ -14,10 +14,8 @@ NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() {
return NNPackThreadLocalStore::Get();
}
TVM_REGISTER_GLOBAL("contrib.nnpack._Config")
.set_body([](TVMArgs args, TVMRetValue *ret) {
bool NNPackConfig(uint64_t nthreads) {
NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
size_t nthreads = args[0].operator uint64_t();
if (entry->threadpool != NULL &&
pthreadpool_get_threads_count(entry->threadpool) != nthreads) {
pthreadpool_destroy(entry->threadpool);
......@@ -26,6 +24,13 @@ TVM_REGISTER_GLOBAL("contrib.nnpack._Config")
if (entry->threadpool == NULL) {
entry->threadpool = pthreadpool_create(nthreads);
}
return true;
}
TVM_REGISTER_GLOBAL("contrib.nnpack._Config")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(NNPackConfig(args[0]));
});
} // namespace contrib
} // namespace tvm
......@@ -18,6 +18,8 @@ struct NNPackThreadLocalEntry {
pthreadpool_t threadpool{NULL};
static NNPackThreadLocalEntry* ThreadLocal();
};
bool NNPackConfig(uint64_t nthreads);
} // namespace contrib
} // namespace tvm
#endif // TVM_CONTRIB_NNPACK_NNPACK_UTILS_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment