Commit 182a7852 authored by ziheng, committed by GitHub

[NNPACK] Add argument nthreads (#631)

Adds an optional nthreads argument (default 1) to the NNPACK contrib operators and routes it, together with the existing global config path, through a new NNPackConfig helper on the C++ side.

parent 35485307
@@ -16,7 +16,7 @@ def config(nthreads):
     """
     _Config(nthreads)
 
-def fully_connected_inference(lhs, rhs):
+def fully_connected_inference(lhs, rhs, nthreads=1):
     """Create an extern op that compute fully connected of 1D tensor lhs and
     2D tensor rhs with nnpack.
 
@@ -37,9 +37,9 @@ def fully_connected_inference(lhs, rhs):
        (m, ), [lhs, rhs],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.nnpack.fully_connected_inference",
-            ins[0], ins[1], outs[0]), name="C")
+            ins[0], ins[1], outs[0], nthreads), name="C")
 
-def fully_connected_output(lhs, rhs):
+def fully_connected_output(lhs, rhs, nthreads=1):
     """Create an extern op that compute fully connected of 2D tensor lhs and
     2D tensor rhs with nnpack.
 
@@ -61,9 +61,9 @@ def fully_connected_output(lhs, rhs):
        (n, m), [lhs, rhs],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.nnpack.fully_connected_output",
-            ins[0], ins[1], outs[0]), name="C")
+            ins[0], ins[1], outs[0], nthreads), name="C")
 
-def convolution_inference(data, kernel, bias, padding, stride):
+def convolution_inference(data, kernel, bias, padding, stride, nthreads=1):
     """Create an extern op to do inference convolution of 3D tensor data and
     4D tensor kernel and 1D tensor bias with nnpack.
 
@@ -104,9 +104,9 @@ def convolution_inference(data, kernel, bias, padding, stride):
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.nnpack.convolution_inference", ins[0], ins[1], ins[2],
            outs[0], padding[0], padding[1], padding[2], padding[3],
-            stride[0], stride[1]), name="C")
+            stride[0], stride[1], nthreads), name="C")
 
-def convolution_output(data, kernel, bias, padding):
+def convolution_output(data, kernel, bias, padding, nthreads=1):
     """Create an extern op to compute convolution of 4D tensor data and
     4D tensor kernel and 1D tensor bias with nnpack.
 
@@ -142,6 +142,6 @@ def convolution_output(data, kernel, bias, padding):
        (batch, output_channels, output_height, output_width), [data, kernel, bias],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2],
-            outs[0], padding[0], padding[1], padding[2], padding[3]), name="C")
+            outs[0], padding[0], padding[1], padding[2], padding[3], nthreads), name="C")
 
 _init_api("tvm.contrib.nnpack")
@@ -24,6 +24,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
     nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
     uint64_t stride_width = args[8], stride_height = args[9];
     nnp_size stride_size{stride_width, stride_height};
+    NNPackConfig(args[10]);
 
     CHECK_EQ(input->ndim, 3);
     CHECK_EQ(kernel->ndim, 4);
@@ -80,6 +81,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
     DLTensor* output = args[3];
     uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7];
     nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
+    NNPackConfig(args[8]);
 
     CHECK_EQ(input->ndim, 4);
     CHECK_EQ(kernel->ndim, 4);
...
@@ -21,6 +21,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
     DLTensor* A = args[0];
     DLTensor* B = args[1];
     DLTensor* C = args[2];
+    NNPackConfig(args[3]);
+
     CHECK_EQ(A->ndim, 1);
     CHECK_EQ(B->ndim, 2);
     CHECK_EQ(C->ndim, 1);
@@ -49,6 +51,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
     DLTensor* A = args[0];
     DLTensor* B = args[1];
     DLTensor* C = args[2];
+    NNPackConfig(args[3]);
+
     CHECK_EQ(A->ndim, 2);
     CHECK_EQ(B->ndim, 2);
     CHECK_EQ(C->ndim, 2);
...
@@ -14,18 +14,23 @@ NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() {
   return NNPackThreadLocalStore::Get();
 }
 
+bool NNPackConfig(uint64_t nthreads) {
+  NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
+  if (entry->threadpool != NULL &&
+      pthreadpool_get_threads_count(entry->threadpool) != nthreads) {
+    pthreadpool_destroy(entry->threadpool);
+    entry->threadpool = NULL;
+  }
+  if (entry->threadpool == NULL) {
+    entry->threadpool = pthreadpool_create(nthreads);
+  }
+  return true;
+}
+
 TVM_REGISTER_GLOBAL("contrib.nnpack._Config")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
-    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
-    size_t nthreads = args[0].operator uint64_t();
-    if (entry->threadpool != NULL &&
-        pthreadpool_get_threads_count(entry->threadpool) != nthreads) {
-      pthreadpool_destroy(entry->threadpool);
-      entry->threadpool = NULL;
-    }
-    if (entry->threadpool == NULL) {
-      entry->threadpool = pthreadpool_create(nthreads);
-    }
+    CHECK(NNPackConfig(args[0]));
   });
 }  // namespace contrib
 }  // namespace tvm
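The hunk above also refactors the _Config handler: the thread-pool logic moves into a reusable NNPackConfig(nthreads) helper that destroys and recreates the pthreadpool only when the requested count differs from the current one, so repeated calls with an unchanged count reuse the existing pool. Both configuration paths now go through it. A hedged sketch of the resulting Python-visible behavior (the comments restate the C++ logic above, not a documented guarantee):

    from tvm.contrib import nnpack

    nnpack.config(4)  # global path: contrib.nnpack._Config -> NNPackConfig(4)
    # A later call that reaches NNPackConfig(4) reuses the 4-thread pool;
    # one that reaches NNPackConfig(8) tears it down and builds an 8-thread pool.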
@@ -18,6 +18,8 @@ struct NNPackThreadLocalEntry {
   pthreadpool_t threadpool{NULL};
   static NNPackThreadLocalEntry* ThreadLocal();
 };
+
+bool NNPackConfig(uint64_t nthreads);
 }  // namespace contrib
 }  // namespace tvm
 #endif  // TVM_CONTRIB_NNPACK_NNPACK_UTILS_H_