Commit 1e48b02f by ziheng, committed by Tianqi Chen

[NNPACK] Add nnpack.convolution (#301)

* [NNPACK] Add nnpack.convolution

* Add intrinsic

* Fix lint
parent 1389d208
@@ -50,3 +50,84 @@ def fully_connected_output(lhs, rhs):
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.fully_connected_output",
ins[0], ins[1], outs[0]), name="C")
def convolution_inference(data, kernel, bias, padding, stride):
"""Create an extern op to do inference convolution of 3D tensor data and
4D tensor kernel and 1D tensor bias with nnpack.
Parameters
----------
data : Tensor
data 3D tensor input[input_channels][input_height][input_width] of
FP32 elements.
kernel : Tensor
kernel 4D tensor kernel[output_channels][input_channels][kernel_height]
[kernel_width] of FP32 elements.
bias : Tensor
bias 1D tensor bias[output_channels] of FP32 elements.
padding : list
padding A 4-dim list of [pad_top, pad_bottom, pad_left, pad_right],
which indicates the padding around the feature map.
stride : list
stride A 2-dim list of [stride_height, stride_width], which indicates
the stride.
Returns
-------
output : Tensor
output 3D tensor output[output_channels][output_height][output_width]
of FP32 elements.
"""
assert isinstance(padding, list) and len(padding) == 4
assert isinstance(stride, list) and len(stride) == 2
_, input_height, input_width = data.shape
output_channels, _, kernel_height, kernel_width = kernel.shape
# NOTE: the declared output extent assumes unit stride.
output_height = (input_height + padding[0] + padding[1] - kernel_height) + 1
output_width = (input_width + padding[2] + padding[3] - kernel_width) + 1
return _api.extern(
(output_channels, output_height, output_width), [data, kernel, bias],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.convolution_inference", ins[0], ins[1], ins[2],
outs[0], padding[0], padding[1], padding[2], padding[3],
stride[0], stride[1]), name="C")
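For reference, a minimal usage sketch of convolution_inference, mirroring the test added later in this commit (the shape constants are illustrative, not required by the API):

    import tvm
    from tvm.contrib import nnpack

    data = tvm.placeholder((16, 48, 48), name='data')        # [IC][IH][IW]
    kernel = tvm.placeholder((16, 16, 3, 3), name='kernel')  # [OC][IC][KH][KW]
    bias = tvm.placeholder((16,), name='bias')                # [OC]
    # pad 1 on every side, stride 1: 48 + 1 + 1 - 3 + 1 = 48, so output is (16, 48, 48)
    output = nnpack.convolution_inference(data, kernel, bias, [1, 1, 1, 1], [1, 1])
    s = tvm.create_schedule(output.op)
    f = tvm.build(s, [data, kernel, bias, output], "llvm")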
def convolution_output(data, kernel, bias, padding):
"""Create an extern op to compute convolution of 4D tensor data and
4D tensor kernel and 1D tensor bias with nnpack.
Parameters
----------
data : Tensor
data 4D tensor input[batch_size][input_channels][input_height]
[input_width] of FP32 elements.
kernel : Tensor
kernel 4D tensor kernel[output_channels][input_channels][kernel_height]
[kernel_width] of FP32 elements.
bias : Tensor
bias 1D tensor bias[output_channels] of FP32 elements.
padding : list
padding A 4-dim list of [pad_top, pad_bottom, pad_left, pad_right],
which indicates the padding around the feature map.
Returns
-------
output : Tensor
output 4D tensor output[batch_size][output_channels][output_height]
[output_width] of FP32 elements.
"""
assert isinstance(padding, list) and len(padding) == 4
batch, _, input_height, input_width = data.shape
output_channels, _, kernel_height, kernel_width = kernel.shape
output_height = (input_height + padding[0] + padding[1] - kernel_height) + 1
output_width = (input_width + padding[2] + padding[3] - kernel_width) + 1
return _api.extern(
(batch, output_channels, output_height, output_width), [data, kernel, bias],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2],
outs[0], padding[0], padding[1], padding[2], padding[3]), name="C")
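The batched variant is called the same way, with a leading batch dimension on data and output (again mirroring the test below; constants are illustrative):

    data = tvm.placeholder((32, 16, 48, 48), name='data')    # [N][IC][IH][IW]
    kernel = tvm.placeholder((16, 16, 3, 3), name='kernel')  # [OC][IC][KH][KW]
    bias = tvm.placeholder((16,), name='bias')                # [OC]
    output = nnpack.convolution_output(data, kernel, bias, [1, 1, 1, 1])
    s = tvm.create_schedule(output.op)  # output shape is (32, 16, 48, 48)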
/*!
* Copyright (c) 2017 by Contributors
* \file Use external nnpack library call.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
#include <nnpack.h>
namespace tvm {
namespace contrib {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
.set_body([](TVMArgs args, TVMRetValue *ret) {
nnp_initialize();
DLTensor* input = args[0];
DLTensor* kernel = args[1];
DLTensor* bias = args[2];
DLTensor* output = args[3];
// NNPACK declares nnp_padding as {top, right, bottom, left}, while the packed
// args arrive as [pad_top, pad_bottom, pad_left, pad_right]; reorder accordingly.
nnp_padding input_padding{args[4], args[7], args[5], args[6]};
// nnp_size is {width, height}, so pass stride_width before stride_height.
nnp_size stride_size{args[9], args[8]};
CHECK_EQ(input->ndim, 3);
CHECK_EQ(kernel->ndim, 4);
CHECK_EQ(bias->ndim, 1);
CHECK_EQ(output->ndim, 3);
CHECK_EQ(input->shape[0], kernel->shape[1]);
size_t input_channels = input->shape[0];
CHECK_EQ(output->shape[0], kernel->shape[0]);
CHECK_EQ(output->shape[0], bias->shape[0]);
size_t output_channels = output->shape[0];
// nnp_size is {width, height}: shape[1] is height, shape[2] is width for CHW input.
nnp_size input_size{static_cast<size_t>(input->shape[2]),
static_cast<size_t>(input->shape[1])};
nnp_size kernel_size{static_cast<size_t>(kernel->shape[3]),
static_cast<size_t>(kernel->shape[2])};
CHECK(input->strides == nullptr);
CHECK(kernel->strides == nullptr);
CHECK(bias->strides == nullptr);
CHECK(TypeMatch(input->dtype, kFloat, 32));
CHECK(TypeMatch(kernel->dtype, kFloat, 32));
CHECK(TypeMatch(bias->dtype, kFloat, 32));
CHECK(TypeMatch(output->dtype, kFloat, 32));
nnp_convolution_inference(nnp_convolution_algorithm_auto,
nnp_convolution_transform_strategy_block_based,
input_channels,
output_channels,
input_size,
input_padding,
kernel_size,
stride_size,
static_cast<float*>(input->data),
static_cast<float*>(kernel->data),
static_cast<float*>(bias->data),
static_cast<float*>(output->data),
NULL,
NULL,
nnp_activation_identity,
NULL,
NULL,
NULL);
});
TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
.set_body([](TVMArgs args, TVMRetValue *ret) {
nnp_initialize();
DLTensor* input = args[0];
DLTensor* kernel = args[1];
DLTensor* bias = args[2];
DLTensor* output = args[3];
// NNPACK declares nnp_padding as {top, right, bottom, left}; reorder the
// [pad_top, pad_bottom, pad_left, pad_right] packed args accordingly.
nnp_padding input_padding{args[4], args[7], args[5], args[6]};
CHECK_EQ(input->ndim, 4);
CHECK_EQ(kernel->ndim, 4);
CHECK_EQ(bias->ndim, 1);
CHECK_EQ(output->ndim, 4);
CHECK_EQ(input->shape[0], output->shape[0]);
size_t batch_size = input->shape[0];
CHECK_EQ(input->shape[1], kernel->shape[1]);
size_t input_channels = input->shape[1];
CHECK_EQ(output->shape[1], bias->shape[0]);
CHECK_EQ(output->shape[1], kernel->shape[0]);
size_t output_channels = output->shape[1];
// nnp_size is {width, height}: shape[2] is height, shape[3] is width for NCHW input.
nnp_size input_size{static_cast<size_t>(input->shape[3]),
static_cast<size_t>(input->shape[2])};
nnp_size kernel_size{static_cast<size_t>(kernel->shape[3]),
static_cast<size_t>(kernel->shape[2])};
CHECK(input->strides == nullptr);
CHECK(kernel->strides == nullptr);
CHECK(bias->strides == nullptr);
CHECK(TypeMatch(input->dtype, kFloat, 32));
CHECK(TypeMatch(kernel->dtype, kFloat, 32));
CHECK(TypeMatch(bias->dtype, kFloat, 32));
CHECK(TypeMatch(output->dtype, kFloat, 32));
nnp_convolution_output(nnp_convolution_algorithm_auto,
batch_size,
input_channels,
output_channels,
input_size,
input_padding,
kernel_size,
static_cast<float*>(input->data),
static_cast<float*>(kernel->data),
static_cast<float*>(bias->data),
static_cast<float*>(output->data),
nnp_activation_identity,
NULL,
NULL,
NULL);
});
} // namespace contrib
} // namespace tvm
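Once TVM is built with NNPACK support, the two functions registered above are reachable through the global registry; a quick check from Python, using the same lookup the tests below rely on:

    import tvm

    for name in ["tvm.contrib.nnpack.convolution_inference",
                 "tvm.contrib.nnpack.convolution_output"]:
        # The second argument makes the lookup return None instead of raising
        # when the function is missing (e.g. TVM built without NNPACK).
        print(name, "->", tvm.get_global_func(name, True))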
import tvm
import numpy as np
import scipy.signal
from tvm.contrib import nnpack
def test_fully_connected_output():
@@ -9,7 +10,6 @@ def test_fully_connected_output():
bias = tvm.var('bias', dtype=tvm.float32)
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B')
C = nnpack.fully_connected_output(A, B)
D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
s = tvm.create_schedule(D.op)
@@ -62,7 +62,135 @@ def test_fully_connected_inference():
d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy().T) + bb, rtol=1e-5)
verify()
def np_conv(na, nw, padding, stride=1):
batch, in_channel, in_height, in_width = na.shape
num_filter, _, kernel_h, kernel_w = nw.shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
if isinstance(padding, int):
pad_h = pad_w = padding * 2
else:
pad_h, pad_w = padding
pad_h *= 2
pad_w *= 2
pad_top = int(np.ceil(float(pad_h) / 2))
pad_bottom = pad_h - pad_top
pad_left = int(np.ceil(float(pad_w) / 2))
pad_right = pad_w - pad_left
out_channel = num_filter
out_height = (in_height - kernel_h + pad_h) // stride_h + 1
out_width = (in_width - kernel_w + pad_w) // stride_w + 1
nb = np.zeros((batch, out_channel, out_height, out_width))
for n in range(batch):
for f in range(out_channel):
for c in range(in_channel):
if pad_h > 0:
apad = np.zeros((in_height + pad_h, in_width + pad_w))
apad[pad_top:-pad_bottom, pad_left:-pad_right] = na[n, c]
else:
apad = na[n, c]
out = scipy.signal.convolve2d(
apad, np.rot90(np.rot90(nw[f, c])), mode='valid')
nb[n, f] += out[::stride_h, ::stride_w]
return nb
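As a quick sanity check on np_conv (illustrative shapes): a 1x1x5x5 input convolved with a 3x3 kernel at padding 1 and stride 2 should produce a 1x1x3x3 output, since out = (in + 2*pad - k) // stride + 1 = (5 + 2 - 3) // 2 + 1 = 3.

    na = np.random.uniform(size=(1, 1, 5, 5))
    nw = np.random.uniform(size=(1, 1, 3, 3))
    assert np_conv(na, nw, padding=1, stride=2).shape == (1, 1, 3, 3)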
def test_convolution_inference():
BATCH = 32
IH = 48
IW = 48
IC = 16
OC = 16
K = 3
PAD = 1
STRIDE = 1
OH = (IH + 2*PAD - K) + 1
OW = (IW + 2*PAD - K) + 1
dshape = (IC, IH, IW)
kshape = (OC, IC, K, K)
bshape = (OC, )
oshape = (OC, OH, OW)
data = tvm.placeholder(dshape, name='data')
kernel = tvm.placeholder(kshape, name='kernel')
bias = tvm.placeholder(bshape, name='bias')
output = nnpack.convolution_inference(data, kernel, bias,
[PAD, PAD, PAD, PAD], [STRIDE, STRIDE])
s = tvm.create_schedule(output.op)
def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.nnpack.convolution_inference", True):
print("skip because extern function is not available")
return
ctx = tvm.cpu(0)
f = tvm.build(s, [data, kernel, bias, output], target)
na = np.random.uniform(size=dshape).astype(data.dtype)
nb = np.random.uniform(size=kshape).astype(kernel.dtype)
nc = np.zeros(bshape, dtype=bias.dtype)
ta = tvm.nd.array(na, ctx)
tb = tvm.nd.array(nb, ctx)
tc = tvm.nd.array(nc, ctx)
td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
f(ta, tb, tc, td)
nd = np_conv(np.reshape(na, (1, IC, IH, IW)), nb, PAD, STRIDE)
np.testing.assert_allclose(
td.asnumpy(), nd.reshape(OC, OH, OW), rtol=1e-5)
verify()
def test_convolution_output():
BATCH = 32
IH = 48
IW = 48
IC = 16
OC = 16
K = 3
PAD = 1
OH = (IH + 2*PAD - K) + 1
OW = (IW + 2*PAD - K) + 1
dshape = (BATCH, IC, IH, IW)
kshape = (OC, IC, K, K)
bshape = (OC, )
oshape = (BATCH, OC, OH, OW)
data = tvm.placeholder(dshape, name='data')
kernel = tvm.placeholder(kshape, name='kernel')
bias = tvm.placeholder(bshape, name='bias')
output = nnpack.convolution_output(data, kernel, bias, [PAD, PAD, PAD, PAD])
s = tvm.create_schedule(output.op)
def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.nnpack.convolution_output", True):
print("skip because extern function is not available")
return
ctx = tvm.cpu(0)
f = tvm.build(s, [data, kernel, bias, output], target)
na = np.random.uniform(size=dshape).astype(data.dtype)
nb = np.random.uniform(size=kshape).astype(kernel.dtype)
nc = np.zeros(bshape, dtype=bias.dtype)
ta = tvm.nd.array(na, ctx)
tb = tvm.nd.array(nb, ctx)
tc = tvm.nd.array(nc, ctx)
td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
f(ta, tb, tc, td)
nd = np_conv(na, nb, PAD)
np.testing.assert_allclose(
td.asnumpy(), nd, rtol=1e-5)
verify()
if __name__ == "__main__":
test_fully_connected_inference()
test_fully_connected_output()
import nose
nose.runmodule()