Commit f467377f by hlu1, committed by Tianqi Chen

[contrib][nnpack] remove training-optimized ops (#2224)

parent 9a00b7b8
@@ -34,30 +34,6 @@ def fully_connected_inference(lhs, rhs, nthreads=1):
"tvm.contrib.nnpack.fully_connected_inference",
ins[0], ins[1], outs[0], nthreads), name="C")
def fully_connected_output(lhs, rhs, nthreads=1):
"""Create an extern op that compute fully connected of 2D tensor lhs and
2D tensor rhs with nnpack.
Parameters
----------
lhs : Tensor
lhs 2D matrix input[batch_size][input_channels] of FP32 elements
rhs : Tensor
rhs 2D matrix kernel[output_channels][input_channels] of FP32 elements
Returns
-------
C : Tensor
C 2D array out[batch_size][output_channels] of FP32 elements.
"""
n = lhs.shape[0]
m = rhs.shape[0]
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.fully_connected_output",
ins[0], ins[1], outs[0], nthreads), name="C")
class ConvolutionAlgorithm:
AUTO = 0
@@ -204,43 +180,4 @@ def convolution_inference_weight_transform(
"tvm.contrib.nnpack.convolution_inference_weight_transform",
ins[0], outs[0], nthreads, algorithm), name="transform_kernel")
def convolution_output(data, kernel, bias, padding, nthreads=1):
"""Create an extern op to compute convolution of 4D tensor data and
4D tensor kernel and 1D tensor bias with nnpack.
Parameters
----------
data : Tensor
data 4D tensor input[batch_size][input_channels][input_height]
[input_width] of FP32 elements.
kernel : Tensor
kernel 4D tensor kernel[output_channels][input_channels][kernel_height]
[kernel_width] of FP32 elements.
bias : Tensor
bias 1D array bias[output_channels] of FP32 elements.
padding : list
padding A 4-element list of [pad_top, pad_bottom, pad_left, pad_right],
which indicates the padding around the feature map.
Returns
-------
output : Tensor
output 4D tensor output[batch_size][output_channels][output_height]
[output_width] of FP32 elements.
"""
assert isinstance(padding, list) and len(padding) == 4
batch, _, input_height, input_width = data.shape
output_channels, _, kernel_height, kernel_width = kernel.shape
output_height = (input_height + padding[0] + padding[1] - kernel_height) + 1
output_width = (input_width + padding[2] + padding[3] - kernel_width) + 1
return _api.extern(
(batch, output_channels, output_height, output_width), [data, kernel, bias],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2],
outs[0], padding[0], padding[1], padding[2], padding[3], nthreads), name="C")
_init_api("tvm.contrib.nnpack")
@@ -215,64 +215,5 @@ TVM_REGISTER_GLOBAL(
entry->threadpool, nullptr);
CHECK_EQ(status, nnp_status_success);
});
TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
.set_body([](TVMArgs args, TVMRetValue *ret) {
NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
nnp_initialize();
DLTensor* input = args[0];
DLTensor* kernel = args[1];
DLTensor* bias = args[2];
DLTensor* output = args[3];
uint64_t pad_top = args[4], pad_bottom = args[5], pad_left = args[6], pad_right = args[7];
nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
NNPackConfig(args[8]);
CHECK_EQ(input->ndim, 4);
CHECK_EQ(kernel->ndim, 4);
CHECK_EQ(bias->ndim, 1);
CHECK_EQ(output->ndim, 4);
CHECK_EQ(input->shape[0], output->shape[0]);
size_t batch_size = input->shape[0];
CHECK_EQ(input->shape[1], kernel->shape[1]);
size_t input_channels = input->shape[1];
CHECK_EQ(output->shape[1], bias->shape[0]);
CHECK_EQ(output->shape[1], kernel->shape[0]);
size_t output_channels = output->shape[1];
nnp_size input_size{static_cast<size_t>(input->shape[2]),
static_cast<size_t>(input->shape[3])};
nnp_size kernel_size{static_cast<size_t>(kernel->shape[2]),
static_cast<size_t>(kernel->shape[3])};
CHECK(input->strides == nullptr);
CHECK(kernel->strides == nullptr);
CHECK(bias->strides == nullptr);
CHECK(TypeMatch(input->dtype, kDLFloat, 32));
CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
CHECK(TypeMatch(output->dtype, kDLFloat, 32));
nnp_status status = nnp_convolution_output(nnp_convolution_algorithm_auto,
batch_size,
input_channels,
output_channels,
input_size,
input_padding,
kernel_size,
static_cast<float*>(input->data),
static_cast<float*>(kernel->data),
static_cast<float*>(bias->data),
static_cast<float*>(output->data),
nullptr,
nullptr,
nnp_activation_identity,
nullptr,
entry->threadpool,
nullptr);
CHECK_EQ(status, nnp_status_success);
});
} // namespace contrib
} // namespace tvm
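
For reference, the semantics the registration above forwards to nnp_convolution_output correspond to a unit-stride cross-correlation plus a per-channel bias (the usual deep-learning convolution convention). A self-contained NumPy sketch under that assumption, with illustrative names:

import numpy as np

def conv_output_ref(data, kernel, bias, pad):
    # data: [N][C][H][W], kernel: [K][C][R][S], bias: [K]
    # pad = [top, bottom, left, right]; stride is fixed at 1
    n, _, h, w = data.shape
    k, _, r, s = kernel.shape
    padded = np.pad(data, ((0, 0), (0, 0), (pad[0], pad[1]), (pad[2], pad[3])),
                    mode="constant")
    oh = h + pad[0] + pad[1] - r + 1
    ow = w + pad[2] + pad[3] - s + 1
    out = np.empty((n, k, oh, ow), dtype=data.dtype)
    for i in range(oh):
        for j in range(ow):
            patch = padded[:, :, i:i + r, j:j + s]  # [N][C][R][S] window
            out[:, :, i, j] = np.tensordot(patch, kernel, axes=([1, 2, 3], [1, 2, 3]))
    return out + bias.reshape(1, k, 1, 1)
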
@@ -43,38 +43,5 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
entry->threadpool);
});
TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
.set_body([](TVMArgs args, TVMRetValue *ret) {
NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
nnp_initialize();
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
NNPackConfig(args[3]);
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK_EQ(B->shape[0], C->shape[1]);
CHECK_EQ(B->shape[1], A->shape[1]);
CHECK_EQ(A->shape[0], C->shape[0]);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));
nnp_fully_connected_output(A->shape[0],
B->shape[1],
B->shape[0],
static_cast<float*>(A->data),
static_cast<float*>(B->data),
static_cast<float*>(C->data),
entry->threadpool,
nullptr);
});
} // namespace contrib
} // namespace tvm
@@ -3,38 +3,6 @@ import numpy as np
import scipy.signal
from tvm.contrib import nnpack
def test_fully_connected_output():
n = 1024
l = 128
m = 235
bias = tvm.var('bias', dtype=tvm.float32)
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B')
C = nnpack.fully_connected_output(A, B)
D = tvm.compute(C.shape, lambda i, j: C[i, j] + bias, name="D")
s = tvm.create_schedule(D.op)
def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_output", True):
print("skip because extern function is not available")
return
if not nnpack.is_available():
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], target)
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(m, l)).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
bb = 10.0
f(a, b, d, bb)
tvm.testing.assert_allclose(
d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy().T) + bb, rtol=1e-5)
verify()
def test_fully_connected_inference():
n = 1024
@@ -131,7 +99,7 @@ def test_convolution_inference():
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
if not tvm.get_global_func("tvm.contrib.nnpack.convolution_inference", True):
print("skip because extern function is not available")
return
if not nnpack.is_available():
@@ -195,7 +163,7 @@ def test_convolution_inference_without_weight_transform():
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
if not tvm.get_global_func("tvm.contrib.nnpack.convolution_inference_without_weight_transform", True):
print("skip because extern function is not available")
return
if not nnpack.is_available():
@@ -228,53 +196,6 @@ def test_convolution_inference_without_weight_transform():
for with_bias in [True, False]:
verify(algorithm=algorithm, with_bias=with_bias)
def test_convolution_output():
BATCH = 32
IH = 48
IW = 48
IC = 16
OC = 16
K = 3
PAD = 1
OH = (IH + 2*PAD - K) + 1
OW = (IW + 2*PAD - K) + 1
dshape = (BATCH, IC, IH, IW)
kshape = (OC, IC, K, K)
bshape = (OC, )
oshape = (BATCH, OC, OH, OW)
data = tvm.placeholder(dshape, name='data')
kernel = tvm.placeholder(kshape, name='kernel')
bias = tvm.placeholder(bshape, name='bias')
output = nnpack.convolution_output(data, kernel, bias, [PAD, PAD, PAD, PAD])
s = tvm.create_schedule(output.op)
def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
print("skip because extern function is not available")
return
if not nnpack.is_available():
return
ctx = tvm.cpu(0)
f = tvm.build(s, [data, kernel, bias, output], target)
na = np.random.uniform(size=dshape).astype(data.dtype)
nb = np.random.uniform(size=kshape).astype(kernel.dtype)
nc = np.zeros(bshape, dtype=bias.dtype)
ta = tvm.nd.array(na, ctx)
tb = tvm.nd.array(nb, ctx)
tc = tvm.nd.array(nc, ctx)
td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
f(ta, tb, tc, td)
nd = np_conv(na, nb, PAD)
tvm.testing.assert_allclose(
td.asnumpy(), nd, rtol=1e-5)
verify()
if __name__ == "__main__":
import nose