Commit fc83c7f2 by Andrew Tulloch, committed by Tianqi Chen

[TVM] [NNPACK] Modernize and improve NNPACK bindings (#2084)

parent 9f441d81
......@@ -9,6 +9,10 @@ if(USE_NNPACK)
include_directories(${PTHREAD_POOL_PATH}/include)
find_library(NNPACK_CONTRIB_LIB nnpack ${NNPACK_PATH}/lib)
find_library(NNPACK_PTHREAD_CONTRIB_LIB pthreadpool ${NNPACK_PATH}/lib)
find_library(NNPACK_CPUINFO_CONTRIB_LIB cpuinfo ${NNPACK_PATH}/lib)
find_library(NNPACK_CLOG_CONTRIB_LIB clog ${NNPACK_PATH}/lib)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_CONTRIB_LIB})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_PTHREAD_CONTRIB_LIB})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_CPUINFO_CONTRIB_LIB})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_CLOG_CONTRIB_LIB})
endif(USE_NNPACK)
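
If the runtime is built with USE_NNPACK and these libraries are found, the packed functions registered in this commit become callable at runtime. A minimal Python sketch to confirm the bindings were linked in (the function name comes from the registrations below):

import tvm

# Passing True for allow_missing makes this return None, rather than
# raise, when the runtime was built without NNPACK.
func = tvm.get_global_func("tvm.contrib.nnpack.convolution_inference", True)
print("NNPACK bindings available:", func is not None)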
......@@ -63,14 +63,32 @@ def fully_connected_output(lhs, rhs, nthreads=1):
"tvm.contrib.nnpack.fully_connected_output",
ins[0], ins[1], outs[0], nthreads), name="C")
class ConvolutionAlgorithm:
AUTO = 0
FFT_8x8 = 1
FFT_16x16 = 2
WT_8x8 = 3
IMPLICIT_GEMM = 4
DIRECT = 5
WT_8x8_FP16 = 6
class ConvolutionTransformStrategy:
COMPUTE = 1
PRECOMPUTE = 2
def convolution_inference(
data, kernel, bias, padding, stride, nthreads=1,
algorithm=ConvolutionAlgorithm.AUTO):
"""Create an extern op to do inference convolution of 4D tensor data and
4D tensor kernel and an optional 1D tensor bias with nnpack.
Parameters
----------
data : Tensor
data 4D tensor input[batch][input_channels][input_height][input_width] of
FP32 elements.
kernel : Tensor
kernel 4D tensor kernel[output_channels][input_channels][kernel_height]
......@@ -88,23 +106,108 @@ def convolution_inference(data, kernel, bias, padding, stride, nthreads=1):
Returns
-------
output : Tensor
output 4D tensor output[batch][output_channels][output_height][output_width]
of FP32 elements.
"""
assert isinstance(padding, list) and len(padding) == 4
assert isinstance(stride, list) and len(stride) == 2
batch, _, input_height, input_width = data.shape
output_channels, _, kernel_height, kernel_width = kernel.shape
output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
output_width = (input_width + padding[2] + padding[3] - kernel_width) / stride[1] + 1
return _api.extern(
(batch, output_channels, output_height, output_width),
[data, kernel, bias] if bias is not None else [data, kernel],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.convolution_inference", ins[0], ins[1], ins[2],
"tvm.contrib.nnpack.convolution_inference",
ins[0],
ins[1],
ins[2] if bias is not None else 0,
outs[0], padding[0], padding[1], padding[2], padding[3],
stride[0], stride[1], nthreads, algorithm), name="C")
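
For illustration, a minimal sketch of driving the op above (shapes are arbitrary; NCHW layout and symmetric padding, as in the tests added by this commit):

import tvm
from tvm.contrib import nnpack

data = tvm.placeholder((1, 16, 32, 32), name='data')     # [batch][IC][IH][IW]
kernel = tvm.placeholder((32, 16, 3, 3), name='kernel')  # [OC][IC][KH][KW]
bias = tvm.placeholder((32,), name='bias')

# bias may also be None; the runtime then substitutes a zero bias.
output = nnpack.convolution_inference(
    data, kernel, bias, [1, 1, 1, 1], [1, 1],
    algorithm=nnpack.ConvolutionAlgorithm.AUTO)
s = tvm.create_schedule(output.op)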
def convolution_inference_without_weight_transform(
data, transformed_kernel, bias, padding, stride, nthreads=1,
algorithm=ConvolutionAlgorithm.AUTO):
"""Create an extern op to do inference convolution of 4D tensor data and
4D pre-transformed tensor kernel and an optional 1D tensor bias with nnpack.
Parameters
----------
data : Tensor
data 4D tensor input[batch][input_channels][input_height][input_width] of
FP32 elements.
transformed_kernel : Tensor
transformed_kernel 4D tensor kernel[output_channels][input_channels][tile]
[tile] of FP32 elements.
bias : Tensor
bias 1D array bias[output_channels] of FP32 elements, or None.
padding : list
padding A 4-dim list of [pad_top, pad_bottom, pad_left, pad_right],
which indicates the padding around the feature map.
stride : list
stride A 2-dim list of [stride_height, stride_width], which indicates
the stride.
Returns
-------
output : Tensor
output 4D tensor output[batch][output_channels][output_height][output_width]
of FP32 elements.
"""
assert algorithm in (ConvolutionAlgorithm.WT_8x8,
ConvolutionAlgorithm.WT_8x8_FP16)
assert isinstance(padding, list) and len(padding) == 4
assert isinstance(stride, list) and len(stride) == 2
batch, _, input_height, input_width = data.shape
output_channels, _, _, _ = transformed_kernel.shape
kernel_height, kernel_width = (3, 3)
output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
output_width = (input_width + padding[2] + padding[3] - kernel_width) / stride[1] + 1
return _api.extern(
(batch, output_channels, output_height, output_width),
[data, transformed_kernel, bias] if bias is not None else [data, transformed_kernel],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.convolution_inference_without_weight_transform",
ins[0],
ins[1],
ins[2] if bias is not None else 0,
outs[0], padding[0], padding[1], padding[2], padding[3],
stride[0], stride[1], nthreads, algorithm), name="C")
def convolution_inference_weight_transform(
kernel, nthreads=1,
algorithm=ConvolutionAlgorithm.AUTO):
"""Create an extern op to do inference convolution of 3D tensor data and
4D tensor kernel and 1D tensor bias with nnpack.
Parameters
----------
kernel : Tensor
kernel 4D tensor kernel[output_channels][input_channels][kernel_height]
[kernel_width] of FP32 elements.
Returns
-------
output : Tensor
output 4D tensor output[output_channels][input_channels][tile][tile]
of FP32 elements.
"""
assert algorithm in (ConvolutionAlgorithm.WT_8x8, ConvolutionAlgorithm.WT_8x8_FP16)
output_channels, input_channels, _, _ = kernel.shape
transform_tile_size = 8
return _api.extern(
(output_channels, input_channels, transform_tile_size, transform_tile_size),
[kernel],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.convolution_inference_weight_transform",
ins[0], outs[0], nthreads, algorithm), name="transform_kernel")
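
The two ops above are meant to be used as a pair: transform the 3x3 kernel once, then reuse the transformed tiles across inference calls. A sketch of the two-stage flow, mirroring the test added in this commit (only the Winograd algorithms are accepted, per the asserts):

import tvm
from tvm.contrib import nnpack

data = tvm.placeholder((1, 16, 32, 32), name='data')
kernel = tvm.placeholder((16, 16, 3, 3), name='kernel')  # must be 3x3

# Stage 1: pre-transform the kernel into 8x8 Winograd tiles.
transformed_kernel = nnpack.convolution_inference_weight_transform(
    kernel, algorithm=nnpack.ConvolutionAlgorithm.WT_8x8)

# Stage 2: convolve against the pre-transformed kernel (bias omitted here).
output = nnpack.convolution_inference_without_weight_transform(
    data, transformed_kernel, None, [1, 1, 1, 1], [1, 1],
    algorithm=nnpack.ConvolutionAlgorithm.WT_8x8)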
def convolution_output(data, kernel, bias, padding, nthreads=1):
"""Create an extern op to compute convolution of 4D tensor data and
......@@ -144,4 +247,5 @@ def convolution_output(data, kernel, bias, padding, nthreads=1):
"tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2],
outs[0], padding[0], padding[1], padding[2], padding[3], nthreads), name="C")
_init_api("tvm.contrib.nnpack")
......@@ -13,61 +13,207 @@ namespace contrib {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
.set_body([](TVMArgs args, TVMRetValue *ret) {
NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
static std::once_flag flag;
std::call_once(flag,
[]() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
DLTensor *input = args[0];
DLTensor *kernel = args[1];
DLTensor *bias = nullptr;
if (args[2].type_code() == kArrayHandle) {
bias = args[2];
}
DLTensor *output = args[3];
uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6],
pad_left = args[7];
nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
uint64_t stride_width = args[8], stride_height = args[9];
nnp_size stride_size{stride_width, stride_height};
NNPackConfig(args[10]);
uint64_t algo_ = args[11];
nnp_convolution_algorithm algo =
static_cast<nnp_convolution_algorithm>(algo_);
CHECK_EQ(input->ndim, 4);
CHECK_EQ(kernel->ndim, 4);
if (bias) {
CHECK_EQ(bias->ndim, 1);
}
CHECK_EQ(output->ndim, 4);
CHECK_EQ(input->shape[1], kernel->shape[1]);
CHECK_EQ(input->shape[0], output->shape[0]);
size_t input_channels = input->shape[1];
CHECK_EQ(output->shape[1], kernel->shape[0]);
if (bias) {
CHECK_EQ(output->shape[1], bias->shape[0]);
}
size_t output_channels = output->shape[1];
nnp_size input_size{static_cast<size_t>(input->shape[2]),
static_cast<size_t>(input->shape[3])};
nnp_size kernel_size{static_cast<size_t>(kernel->shape[2]),
static_cast<size_t>(kernel->shape[3])};
CHECK(input->strides == nullptr);
CHECK(kernel->strides == nullptr);
if (bias) {
CHECK(bias->strides == nullptr);
}
CHECK(TypeMatch(input->dtype, kDLFloat, 32));
CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
if (bias) {
CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
}
CHECK(TypeMatch(output->dtype, kDLFloat, 32));
// Allocate a zero-bias if we don't pass one in.
std::unique_ptr<std::vector<float>> zero_bias;
if (!bias) {
zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
}
for (auto n = 0; n < input->shape[0]; ++n) {
nnp_status status = nnp_convolution_inference(
algo, nnp_convolution_transform_strategy_compute, input_channels,
output_channels, input_size, input_padding, kernel_size,
stride_size,
static_cast<float *>(input->data) + n * input->shape[1] *
input->shape[2] *
input->shape[3],
static_cast<float *>(kernel->data),
bias ? static_cast<float *>(bias->data) : zero_bias->data(),
static_cast<float *>(output->data) + n * output->shape[1] *
output->shape[2] *
output->shape[3],
NULL, NULL, nnp_activation_identity, NULL, entry->threadpool, NULL);
CHECK_EQ(status, nnp_status_success);
}
});
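
nnp_convolution_inference operates on a single image, so the handler above loops over the batch dimension and advances the input and output pointers by the flattened size of one sample. The offset arithmetic, restated as a small Python helper for clarity (illustrative only, not part of the bindings):

# Flat offset of element (n, 0, 0, 0) in a contiguous NCHW tensor.
def sample_offset(n, shape):
    _, channels, height, width = shape
    return n * channels * height * width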
TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform")
.set_body([](TVMArgs args, TVMRetValue *ret) {
NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
static std::once_flag flag;
std::call_once(flag,
[]() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
DLTensor *input = args[0];
DLTensor *transformed_kernel = args[1];
DLTensor *bias = nullptr;
if (args[2].type_code() == kArrayHandle) {
bias = args[2];
}
DLTensor *output = args[3];
uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6],
pad_left = args[7];
nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
uint64_t stride_width = args[8], stride_height = args[9];
nnp_size stride_size{stride_width, stride_height};
NNPackConfig(args[10]);
uint64_t algo_ = args[11];
nnp_convolution_algorithm algo =
static_cast<nnp_convolution_algorithm>(algo_);
CHECK_EQ(input->ndim, 4);
if (bias) {
CHECK_EQ(bias->ndim, 1);
}
CHECK_EQ(output->ndim, 4);
CHECK_EQ(input->shape[0], output->shape[0]);
size_t input_channels = input->shape[1];
if (bias) {
CHECK_EQ(output->shape[1], bias->shape[0]);
}
size_t output_channels = output->shape[1];
nnp_size input_size{static_cast<size_t>(input->shape[2]),
static_cast<size_t>(input->shape[3])};
nnp_size kernel_size{3, 3};
CHECK(input->strides == nullptr);
CHECK(transformed_kernel->strides == nullptr);
if (bias) {
CHECK(bias->strides == nullptr);
}
CHECK(TypeMatch(input->dtype, kDLFloat, 32));
CHECK(TypeMatch(transformed_kernel->dtype, kDLFloat, 32));
if (bias) {
CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
}
CHECK(TypeMatch(output->dtype, kDLFloat, 32));
// Allocate a zero-bias if we don't pass one in.
std::unique_ptr<std::vector<float>> zero_bias;
if (!bias) {
zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
}
for (auto n = 0; n < input->shape[0]; ++n) {
nnp_status status = nnp_convolution_inference(
algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels,
input_size, input_padding, kernel_size, stride_size,
static_cast<float *>(input->data) + n * input->shape[1] *
input->shape[2] *
input->shape[3],
static_cast<float *>(transformed_kernel->data),
bias ? static_cast<float *>(bias->data) : zero_bias->data(),
static_cast<float *>(output->data) + n * output->shape[1] *
output->shape[2] *
output->shape[3],
NULL, NULL,
nnp_activation_identity, NULL, entry->threadpool, NULL);
CHECK_EQ(status, nnp_status_success);
}
});
TVM_REGISTER_GLOBAL(
"tvm.contrib.nnpack.convolution_inference_weight_transform")
.set_body([](TVMArgs args, TVMRetValue *ret) {
NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
static std::once_flag flag;
std::call_once(flag,
[]() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
DLTensor *kernel = args[0];
DLTensor *transformed_kernel = args[1];
// Dummy sizes
nnp_padding input_padding{1, 1, 1, 1};
nnp_size stride_size{1, 1};
nnp_size input_size{100, 100};
NNPackConfig(args[2]);
uint64_t algo_ = args[3];
nnp_convolution_algorithm algo =
static_cast<nnp_convolution_algorithm>(algo_);
CHECK_EQ(kernel->ndim, 4);
size_t input_channels = kernel->shape[1];
size_t output_channels = kernel->shape[0];
CHECK_EQ(kernel->shape[2], 3);
CHECK_EQ(kernel->shape[3], 3);
nnp_size kernel_size{static_cast<size_t>(kernel->shape[2]),
static_cast<size_t>(kernel->shape[3])};
CHECK(kernel->strides == nullptr);
CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
size_t transformed_kernel_size = 0;
nnp_status status;
status = nnp_convolution_inference(
algo, nnp_convolution_transform_strategy_precompute, input_channels,
output_channels, input_size, input_padding, kernel_size, stride_size,
nullptr, nullptr, nullptr, nullptr, nullptr, &transformed_kernel_size,
nnp_activation_identity, nullptr, entry->threadpool, nullptr);
CHECK_EQ(status, nnp_status_success);
CHECK_LE(transformed_kernel_size, GetDataSize(*transformed_kernel));
status = nnp_convolution_inference(
algo, nnp_convolution_transform_strategy_precompute, input_channels,
output_channels, input_size, input_padding, kernel_size, stride_size,
nullptr, static_cast<float *>(kernel->data), nullptr, nullptr,
static_cast<float *>(transformed_kernel->data),
&transformed_kernel_size, nnp_activation_identity, nullptr,
entry->threadpool, nullptr);
CHECK_EQ(status, nnp_status_success);
});
......@@ -109,7 +255,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
CHECK(TypeMatch(output->dtype, kDLFloat, 32));
nnp_status status = nnp_convolution_output(nnp_convolution_algorithm_auto,
batch_size,
input_channels,
output_channels,
......@@ -126,6 +272,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
NULL,
entry->threadpool,
NULL);
CHECK_EQ(status, nnp_status_success);
});
} // namespace contrib
} // namespace tvm
......@@ -10,20 +10,30 @@ using namespace runtime;
typedef dmlc::ThreadLocalStore<NNPackThreadLocalEntry> NNPackThreadLocalStore;
NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() {
return NNPackThreadLocalStore::Get();
}
bool NNPackConfig(uint64_t nthreads) {
NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
if (entry->threadpool && pthreadpool_get_threads_count(entry->threadpool) == nthreads) {
CHECK_NE(nthreads, 1);
return true;
}
if (entry->threadpool) {
pthreadpool_destroy(entry->threadpool);
entry->threadpool = nullptr;
}
if (nthreads == 1) {
// a null threadpool means the function is invoked on the calling thread,
// which is the desired logic for nthreads == 1
CHECK(!entry->threadpool);
return true;
}
entry->threadpool = pthreadpool_create(nthreads);
return true;
}
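
The nthreads argument on each Python op flows into this function: nthreads == 1 leaves the threadpool null, so NNPACK runs on the calling thread, while nthreads > 1 creates (or reuses) a pthreadpool of that size. A hedged usage sketch:

import tvm
from tvm.contrib import nnpack

data = tvm.placeholder((1, 8, 16, 16), name='data')
kernel = tvm.placeholder((8, 8, 3, 3), name='kernel')

# Request a 4-thread pool; nthreads=1 (the default) stays on the
# calling thread.
output = nnpack.convolution_inference(
    data, kernel, None, [1, 1, 1, 1], [1, 1], nthreads=4)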
......
......@@ -15,7 +15,7 @@ namespace contrib {
using namespace runtime;
struct NNPackThreadLocalEntry {
pthreadpool_t threadpool{nullptr};
static NNPackThreadLocalEntry* ThreadLocal();
};
......
......@@ -290,10 +290,10 @@ variable-rgx=[a-z_][a-z0-9_]{2,30}$
variable-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct function names
function-rgx=[a-z_][a-z0-9_]{2,48}$
# Naming hint for function names
function-name-hint=[a-z_][a-z0-9_]{2,48}$
# Regular expression matching correct class names
class-rgx=[A-Z_][a-zA-Z0-9]+$
......
......@@ -100,7 +100,7 @@ def np_conv(na, nw, padding, stride=1):
return nb
def test_convolution_inference():
BATCH = 8
IH = 48
IW = 48
IC = 16
......@@ -111,40 +111,111 @@ def test_convolution_inference():
OH = (IH + 2*PAD - K) + 1
OW = (IW + 2*PAD - K) + 1
dshape = (BATCH, IC, IH, IW)
kshape = (OC, IC, K, K)
bshape = (OC, )
oshape = (BATCH, OC, OH, OW)
data = tvm.placeholder(dshape, name='data')
kernel = tvm.placeholder(kshape, name='kernel')
bias = tvm.placeholder(bshape, name='bias')
def verify(target="llvm",
algorithm=nnpack.ConvolutionAlgorithm.AUTO,
with_bias=True):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
print("skip because extern function is not available")
return
ctx = tvm.cpu(0)
output = nnpack.convolution_inference(
data, kernel, bias if with_bias else None,
[PAD, PAD, PAD, PAD], [STRIDE, STRIDE],
algorithm=algorithm)
s = tvm.create_schedule(output.op)
def verify(target="llvm"):
f = tvm.build(s, [data, kernel, bias, output], target)
na = np.random.uniform(size=dshape).astype(data.dtype)
nb = np.random.uniform(size=kshape).astype(kernel.dtype)
nc = np.zeros(bshape, dtype=bias.dtype)
ta = tvm.nd.array(na, ctx)
tb = tvm.nd.array(nb, ctx)
tc = tvm.nd.array(nc, ctx)
td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
f(ta, tb, tc, td)
nd = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE) + nc.reshape(1, bshape[0], 1, 1)
tvm.testing.assert_allclose(
td.asnumpy(), nd.reshape(BATCH, OC, OH, OW), rtol=1e-5)
for algorithm in [
nnpack.ConvolutionAlgorithm.AUTO,
nnpack.ConvolutionAlgorithm.FFT_8x8,
nnpack.ConvolutionAlgorithm.FFT_16x16,
nnpack.ConvolutionAlgorithm.WT_8x8,
nnpack.ConvolutionAlgorithm.IMPLICIT_GEMM,
nnpack.ConvolutionAlgorithm.WT_8x8_FP16,
]:
for with_bias in [True, False]:
verify(algorithm=algorithm, with_bias=with_bias)
def test_convolution_inference_without_weight_transform():
BATCH = 6
IH = 48
IW = 48
IC = 16
OC = 16
K = 3
PAD = 1
STRIDE = 1
OH = (IH + 2*PAD - K) + 1
OW = (IW + 2*PAD - K) + 1
dshape = (BATCH, IC, IH, IW)
kshape = (OC, IC, K, K)
bshape = (OC, )
oshape = (BATCH, OC, OH, OW)
data = tvm.placeholder(dshape, name='data')
kernel = tvm.placeholder(kshape, name='kernel')
bias = tvm.placeholder(bshape, name='bias')
def verify(target="llvm",
algorithm=nnpack.ConvolutionAlgorithm.AUTO,
with_bias=True):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
print("skip because extern function is not available")
return
ctx = tvm.cpu(0)
transformed_kernel = nnpack.convolution_inference_weight_transform(
kernel, algorithm=algorithm)
output = nnpack.convolution_inference_without_weight_transform(
data, transformed_kernel, bias if with_bias else None,
[PAD, PAD, PAD, PAD], [STRIDE, STRIDE],
algorithm=algorithm)
s = tvm.create_schedule(output.op)
f = tvm.build(s, [data, kernel, bias, output], target)
na = np.random.uniform(size=dshape).astype(data.dtype)
nb = np.random.uniform(size=kshape).astype(kernel.dtype)
nc = np.random.uniform(size=bshape).astype(bias.dtype) if with_bias else np.zeros(bshape, dtype=bias.dtype)
ta = tvm.nd.array(na, ctx)
tb = tvm.nd.array(nb, ctx)
tc = tvm.nd.array(nc, ctx)
td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
f(ta, tb, tc, td)
nd = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE) + nc.reshape(1, bshape[0], 1, 1)
tvm.testing.assert_allclose(
td.asnumpy(), nd.reshape(BATCH, OC, OH, OW), rtol=1e-5)
for algorithm in [nnpack.ConvolutionAlgorithm.WT_8x8]:
for with_bias in [True, False]:
verify(algorithm=algorithm, with_bias=with_bias)
def test_convolution_output():
BATCH = 32
......