Commit fc83c7f2 by Andrew Tulloch, committed by Tianqi Chen

[TVM] [NNPACK] Modernize and improve NNPACK bindings (#2084)

parent 9f441d81
@@ -9,6 +9,10 @@ if(USE_NNPACK)
   include_directories(${PTHREAD_POOL_PATH}/include)
   find_library(NNPACK_CONTRIB_LIB nnpack ${NNPACK_PATH}/lib)
   find_library(NNPACK_PTHREAD_CONTRIB_LIB pthreadpool ${NNPACK_PATH}/lib)
+  find_library(NNPACK_CPUINFO_CONTRIB_LIB cpuinfo ${NNPACK_PATH}/lib)
+  find_library(NNPACK_CLOG_CONTRIB_LIB clog ${NNPACK_PATH}/lib)
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_CONTRIB_LIB})
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_PTHREAD_CONTRIB_LIB})
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_CPUINFO_CONTRIB_LIB})
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_CLOG_CONTRIB_LIB})
 endif(USE_NNPACK)
@@ -63,14 +63,32 @@ def fully_connected_output(lhs, rhs, nthreads=1):
         "tvm.contrib.nnpack.fully_connected_output",
         ins[0], ins[1], outs[0], nthreads), name="C")
 
-def convolution_inference(data, kernel, bias, padding, stride, nthreads=1):
-    """Create an extern op to do inference convolution of 3D tensor data and
+class ConvolutionAlgorithm:
+    AUTO = 0
+    FFT_8x8 = 1
+    FFT_16x16 = 2
+    WT_8x8 = 3
+    IMPLICIT_GEMM = 4
+    DIRECT = 5
+    WT_8x8_FP16 = 6
+
+
+class ConvolutionTransformStrategy:
+    COMPUTE = 1
+    PRECOMPUTE = 2
+
+
+def convolution_inference(
+        data, kernel, bias, padding, stride, nthreads=1,
+        algorithm=ConvolutionAlgorithm.AUTO):
+    """Create an extern op to do inference convolution of 4D tensor data and
     4D tensor kernel and 1D tensor bias with nnpack.
 
     Parameters
     ----------
     data : Tensor
-        data 3D tensor input[input_channels][input_height][input_width] of
+        data 4D tensor input[batch][input_channels][input_height][input_width] of
         FP32 elements.
 
     kernel : Tensor
         kernel 4D tensor kernel[output_channels][input_channels][kernel_height]
@@ -88,23 +106,108 @@ def convolution_inference(data, kernel, bias, padding, stride, nthreads=1):
     Returns
     -------
     output : Tensor
-        output 3D tensor output[output_channels][output_height][output_width]
+        output 4D tensor output[batch][output_channels][output_height][output_width]
         of FP32 elements.
     """
     assert isinstance(padding, list) and len(padding) == 4
     assert isinstance(stride, list) and len(stride) == 2
-    _, input_height, input_width = data.shape
+    batch, _, input_height, input_width = data.shape
     output_channels, _, kernel_height, kernel_width = kernel.shape
     output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
     output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1
 
     return _api.extern(
-        (output_channels, output_height, output_width), [data, kernel, bias],
+        (batch, output_channels, output_height, output_width),
+        [data, kernel, bias] if bias is not None else [data, kernel],
         lambda ins, outs: _intrin.call_packed(
-            "tvm.contrib.nnpack.convolution_inference", ins[0], ins[1], ins[2],
+            "tvm.contrib.nnpack.convolution_inference",
+            ins[0],
+            ins[1],
+            ins[2] if bias is not None else 0,
             outs[0], padding[0], padding[1], padding[2], padding[3],
-            stride[0], stride[1], nthreads), name="C")
+            stride[0], stride[1], nthreads, algorithm), name="C")
+
+
+def convolution_inference_without_weight_transform(
+        data, transformed_kernel, bias, padding, stride, nthreads=1,
+        algorithm=ConvolutionAlgorithm.AUTO):
+    """Create an extern op to do inference convolution of 4D tensor data and
+    4D pre-transformed tensor kernel and 1D tensor bias with nnpack.
+
+    Parameters
+    ----------
+    data : Tensor
+        data 4D tensor input[batch][input_channels][input_height][input_width] of
+        FP32 elements.
+
+    transformed_kernel : Tensor
+        transformed_kernel 4D tensor kernel[output_channels][input_channels][tile]
+        [tile] of FP32 elements.
+
+    bias : Tensor
+        bias 1D array bias[output_channels] of FP32 elements.
+
+    padding : list
+        padding A 4-dim list of [pad_top, pad_bottom, pad_left, pad_right],
+        which indicates the padding around the feature map.
+
+    stride : list
+        stride A 2-dim list of [stride_height, stride_width], which indicates
+        the stride.
+
+    Returns
+    -------
+    output : Tensor
+        output 4D tensor output[batch][output_channels][output_height][output_width]
+        of FP32 elements.
+    """
+    assert algorithm in (ConvolutionAlgorithm.WT_8x8,
+                         ConvolutionAlgorithm.WT_8x8_FP16)
+    assert isinstance(padding, list) and len(padding) == 4
+    assert isinstance(stride, list) and len(stride) == 2
+    batch, _, input_height, input_width = data.shape
+    output_channels, _, _, _ = transformed_kernel.shape
+    kernel_height, kernel_width = (3, 3)
+    output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
+    output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1
+    return _api.extern(
+        (batch, output_channels, output_height, output_width),
+        [data, transformed_kernel, bias] if bias is not None else [data, transformed_kernel],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.nnpack.convolution_inference_without_weight_transform",
+            ins[0],
+            ins[1],
+            ins[2] if bias is not None else 0,
+            outs[0], padding[0], padding[1], padding[2], padding[3],
+            stride[0], stride[1], nthreads, algorithm), name="C")
+
+
+def convolution_inference_weight_transform(
+        kernel, nthreads=1,
+        algorithm=ConvolutionAlgorithm.AUTO):
+    """Create an extern op to pre-transform a 4D tensor kernel for inference
+    convolution with nnpack.
+
+    Parameters
+    ----------
+    kernel : Tensor
+        kernel 4D tensor kernel[output_channels][input_channels][kernel_height]
+        [kernel_width] of FP32 elements.
+
+    Returns
+    -------
+    output : Tensor
+        output 4D tensor output[output_channels][input_channels][tile][tile]
+        of FP32 elements.
+    """
+    assert algorithm in (ConvolutionAlgorithm.WT_8x8, ConvolutionAlgorithm.WT_8x8_FP16)
+    output_channels, input_channels, _, _ = kernel.shape
+    transform_tile_size = 8
+    return _api.extern(
+        (output_channels, input_channels, transform_tile_size, transform_tile_size),
+        [kernel],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.nnpack.convolution_inference_weight_transform",
+            ins[0], outs[0], nthreads, algorithm), name="transform_kernel")
 
 def convolution_output(data, kernel, bias, padding, nthreads=1):
     """Create an extern op to compute convolution of 4D tensor data and
@@ -144,4 +247,5 @@ def convolution_output(data, kernel, bias, padding, nthreads=1):
             "tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2],
             outs[0], padding[0], padding[1], padding[2], padding[3], nthreads), name="C")
 
 _init_api("tvm.contrib.nnpack")
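For orientation, here is a usage sketch of the updated Python API. It is an illustration, not part of the commit: the placeholder shapes are arbitrary, and a TVM build with USE_NNPACK enabled is assumed.

import tvm
from tvm.contrib import nnpack

# NCHW placeholders: input[batch][channels][height][width],
# kernel[out_channels][in_channels][kernel_h][kernel_w].
data = tvm.placeholder((1, 16, 48, 48), name='data')
kernel = tvm.placeholder((16, 16, 3, 3), name='kernel')
bias = tvm.placeholder((16,), name='bias')

# Direct path: pick an algorithm explicitly; pass bias=None to skip the bias add.
conv = nnpack.convolution_inference(
    data, kernel, bias, [1, 1, 1, 1], [1, 1],
    algorithm=nnpack.ConvolutionAlgorithm.FFT_16x16)

# Winograd path: transform the kernel once, then reuse it across inferences.
# Only WT_8x8 and WT_8x8_FP16 pass the asserts in the functions above.
tkernel = nnpack.convolution_inference_weight_transform(
    kernel, algorithm=nnpack.ConvolutionAlgorithm.WT_8x8)
wconv = nnpack.convolution_inference_without_weight_transform(
    data, tkernel, None, [1, 1, 1, 1], [1, 1],
    algorithm=nnpack.ConvolutionAlgorithm.WT_8x8)

s = tvm.create_schedule(wconv.op)
f = tvm.build(s, [data, kernel, wconv], target="llvm")

Scheduling wconv.op pulls the kernel-transform stage into the same build, which mirrors what the tests below do.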
@@ -10,20 +10,30 @@ using namespace runtime;
 typedef dmlc::ThreadLocalStore<NNPackThreadLocalEntry> NNPackThreadLocalStore;
 
 NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() {
   return NNPackThreadLocalStore::Get();
 }
 
 bool NNPackConfig(uint64_t nthreads) {
   NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
-  if (entry->threadpool != NULL &&
-      pthreadpool_get_threads_count(entry->threadpool) != nthreads) {
+  if (entry->threadpool && pthreadpool_get_threads_count(entry->threadpool) == nthreads) {
+    CHECK_NE(nthreads, 1);
+    return true;
+  }
+  if (entry->threadpool) {
     pthreadpool_destroy(entry->threadpool);
-    entry->threadpool = NULL;
+    entry->threadpool = nullptr;
   }
-  if (entry->threadpool == NULL) {
-    entry->threadpool = pthreadpool_create(nthreads);
+
+  if (nthreads == 1) {
+    // a null threadpool means the function is invoked on the calling thread,
+    // which is the desired logic for nthreads == 1
+    CHECK(!entry->threadpool);
+    return true;
   }
+
+  entry->threadpool = pthreadpool_create(nthreads);
   return true;
 }
...
@@ -15,7 +15,7 @@ namespace contrib {
 using namespace runtime;
 
 struct NNPackThreadLocalEntry {
-  pthreadpool_t threadpool{NULL};
+  pthreadpool_t threadpool{nullptr};
   static NNPackThreadLocalEntry* ThreadLocal();
 };
...
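The rewritten NNPackConfig gives nthreads a simple contract: a cached pthreadpool is reused while the requested size matches, nthreads == 1 leaves the threadpool null so NNPACK runs on the calling thread, and any other size destroys a stale pool and creates a fresh one. A sketch of how this surfaces in the Python API, assuming (as the runtime code suggests) that each NNPACK operator calls NNPackConfig with its nthreads argument when it executes; shapes are arbitrary:

import tvm
from tvm.contrib import nnpack

data = tvm.placeholder((1, 16, 48, 48), name='data')
kernel = tvm.placeholder((16, 16, 3, 3), name='kernel')

# When this op runs, NNPackConfig(1) keeps the thread-local threadpool null,
# so the NNPACK kernel executes serially on the calling thread.
serial = nnpack.convolution_inference(
    data, kernel, None, [1, 1, 1, 1], [1, 1], nthreads=1)

# NNPackConfig(4) creates a 4-thread pthreadpool and caches it in thread-local
# storage; later ops asking for 4 threads reuse it, and a different count
# tears the pool down and rebuilds it at the new size.
parallel = nnpack.convolution_inference(
    data, kernel, None, [1, 1, 1, 1], [1, 1], nthreads=4)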
...@@ -290,10 +290,10 @@ variable-rgx=[a-z_][a-z0-9_]{2,30}$ ...@@ -290,10 +290,10 @@ variable-rgx=[a-z_][a-z0-9_]{2,30}$
variable-name-hint=[a-z_][a-z0-9_]{2,30}$ variable-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct function names # Regular expression matching correct function names
function-rgx=[a-z_][a-z0-9_]{2,30}$ function-rgx=[a-z_][a-z0-9_]{2,48}$
# Naming hint for function names # Naming hint for function names
function-name-hint=[a-z_][a-z0-9_]{2,30}$ function-name-hint=[a-z_][a-z0-9_]{2,48}$
# Regular expression matching correct class names # Regular expression matching correct class names
class-rgx=[A-Z_][a-zA-Z0-9]+$ class-rgx=[A-Z_][a-zA-Z0-9]+$
......
...@@ -100,7 +100,7 @@ def np_conv(na, nw, padding, stride=1): ...@@ -100,7 +100,7 @@ def np_conv(na, nw, padding, stride=1):
return nb return nb
def test_convolution_inference(): def test_convolution_inference():
BATCH = 32 BATCH = 8
IH = 48 IH = 48
IW = 48 IW = 48
IC = 16 IC = 16
@@ -111,40 +111,111 @@ def test_convolution_inference():
     OH = (IH + 2*PAD - K) + 1
     OW = (IW + 2*PAD - K) + 1
-    dshape = (IC, IH, IW)
+    dshape = (BATCH, IC, IH, IW)
     kshape = (OC, IC, K, K)
     bshape = (OC, )
-    oshape = (OC, OH, OW)
+    oshape = (BATCH, OC, OH, OW)
 
     data = tvm.placeholder(dshape, name='data')
     kernel = tvm.placeholder(kshape, name='kernel')
     bias = tvm.placeholder(bshape, name='bias')
-    output = nnpack.convolution_inference(data, kernel, bias,
-                                          [PAD, PAD, PAD, PAD], [STRIDE, STRIDE])
-    s = tvm.create_schedule(output.op)
-
-    def verify(target="llvm"):
+    def verify(target="llvm",
+               algorithm=nnpack.ConvolutionAlgorithm.AUTO,
+               with_bias=True):
         if not tvm.module.enabled(target):
             print("skip because %s is not enabled..." % target)
             return
         if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
             print("skip because extern function is not available")
             return
         ctx = tvm.cpu(0)
+        output = nnpack.convolution_inference(
+            data, kernel, bias if with_bias else None,
+            [PAD, PAD, PAD, PAD], [STRIDE, STRIDE],
+            algorithm=algorithm)
+        s = tvm.create_schedule(output.op)
         f = tvm.build(s, [data, kernel, bias, output], target)
         na = np.random.uniform(size=dshape).astype(data.dtype)
         nb = np.random.uniform(size=kshape).astype(kernel.dtype)
         nc = np.zeros(bshape, dtype=bias.dtype)
         ta = tvm.nd.array(na, ctx)
         tb = tvm.nd.array(nb, ctx)
         tc = tvm.nd.array(nc, ctx)
         td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
         f(ta, tb, tc, td)
-        nd = np_conv(np.reshape(na, (1, IC, IH, IW)), nb, PAD, STRIDE)
+        nd = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE) + nc.reshape(1, bshape[0], 1, 1)
         tvm.testing.assert_allclose(
-            td.asnumpy(), nd.reshape(IC, IH, IW), rtol=1e-5)
-    verify()
+            td.asnumpy(), nd.reshape(BATCH, IC, IH, IW), rtol=1e-5)
+
+    for algorithm in [
+            nnpack.ConvolutionAlgorithm.AUTO,
+            nnpack.ConvolutionAlgorithm.FFT_8x8,
+            nnpack.ConvolutionAlgorithm.FFT_16x16,
+            nnpack.ConvolutionAlgorithm.WT_8x8,
+            nnpack.ConvolutionAlgorithm.IMPLICIT_GEMM,
+            nnpack.ConvolutionAlgorithm.WT_8x8_FP16,
+    ]:
+        for with_bias in [True, False]:
+            verify(algorithm=algorithm, with_bias=with_bias)
+
+
+def test_convolution_inference_without_weight_transform():
+    BATCH = 6
+    IH = 48
+    IW = 48
+    IC = 16
+    OC = 16
+    K = 3
+    PAD = 1
+    STRIDE = 1
+
+    OH = (IH + 2*PAD - K) + 1
+    OW = (IW + 2*PAD - K) + 1
+    dshape = (BATCH, IC, IH, IW)
+    kshape = (OC, IC, K, K)
+    bshape = (OC, )
+    oshape = (BATCH, OC, OH, OW)
+
+    data = tvm.placeholder(dshape, name='data')
+    kernel = tvm.placeholder(kshape, name='kernel')
+    bias = tvm.placeholder(bshape, name='bias')
+    def verify(target="llvm",
+               algorithm=nnpack.ConvolutionAlgorithm.AUTO,
+               with_bias=True):
+        if not tvm.module.enabled(target):
+            print("skip because %s is not enabled..." % target)
+            return
+        if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
+            print("skip because extern function is not available")
+            return
+        ctx = tvm.cpu(0)
+        transformed_kernel = nnpack.convolution_inference_weight_transform(
+            kernel, algorithm=algorithm)
+        output = nnpack.convolution_inference_without_weight_transform(
+            data, transformed_kernel, bias if with_bias else None,
+            [PAD, PAD, PAD, PAD], [STRIDE, STRIDE],
+            algorithm=algorithm)
+        s = tvm.create_schedule(output.op)
+        f = tvm.build(s, [data, kernel, bias, output], target)
+        na = np.random.uniform(size=dshape).astype(data.dtype)
+        nb = np.random.uniform(size=kshape).astype(kernel.dtype)
+        nc = np.random.uniform(size=bshape).astype(bias.dtype) if with_bias else np.zeros(bshape, dtype=bias.dtype)
+        ta = tvm.nd.array(na, ctx)
+        tb = tvm.nd.array(nb, ctx)
+        tc = tvm.nd.array(nc, ctx)
+        td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
+        f(ta, tb, tc, td)
+        nd = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE) + nc.reshape(1, bshape[0], 1, 1)
+        tvm.testing.assert_allclose(
+            td.asnumpy(), nd.reshape(BATCH, IC, IH, IW), rtol=1e-5)
+
+    for algorithm in [nnpack.ConvolutionAlgorithm.WT_8x8]:
+        for with_bias in [True, False]:
+            verify(algorithm=algorithm, with_bias=with_bias)
 
 def test_convolution_output():
     BATCH = 32
...