Commit 7b821851 by ziheng Committed by Tianqi Chen

[CONTRIB/NNPACK] Add NNPack Fully Connected Functions (#199)

* Add NNPack Fully Connected Inference

* Add NNPack fully_connected_output

* Fix lint

* Fix
parent 29d5ffbb
......@@ -107,6 +107,7 @@ ifdef LLVM_CONFIG
endif
include make/contrib/cblas.mk
include make/contrib/nnpack.mk
ifdef ADD_CFLAGS
CFLAGS += $(ADD_CFLAGS)
......
......@@ -52,6 +52,9 @@ USE_RPC = 0
# Whether use BLAS, choices: openblas, atlas, blas, apple
USE_BLAS = none
USE_NNPACK = 0
# NNPACK_PATH = none
# add the path to CUDA library to link and compile flag
# if you have already add them to environment variable.
# CUDA_PATH = /usr/local/cuda
NNPACK_CONTRIB_SRC = $(wildcard src/contrib/nnpack/*.cc)
NNPACK_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(NNPACK_CONTRIB_SRC))
ifeq ($(USE_NNPACK), 1)
ifndef NNPACK_PATH
NNPACK_PATH = $(ROOTDIR)/NNPACK
endif
PTHREAD_POOL_PATH = $(NNPACK_PATH)/deps/pthreadpool
CFLAGS += -DTVM_USE_NNPACK=1 -I$(NNPACK_PATH)/include -I$(PTHREAD_POOL_PATH)/include
LDFLAGS += -L$(NNPACK_PATH)/lib -lnnpack -lpthreadpool -lpthread
RUNTIME_DEP += $(NNPACK_CONTRIB_OBJ)
endif
"""External function interface to NNPACK libraroes."""
from __future__ import absolute_import as _abs
from .. import api as _api
from .. import intrin as _intrin
def fully_connected_inference(lhs, rhs):
"""Create an extern op that compute fully connected of 1D tensor lhs and
2D tensor rhs with nnpack.
Parameters
----------
lhs : Tensor
lhs 1D array input[input_channels] of FP32 elements
rhs : Tensor
lhs 2D matrix kernel[output_channels][input_channels] of FP32 elements
Returns
-------
C : Tensor
lhs 1D array out[output_channels] of FP32 elements.
"""
m = rhs.shape[0]
return _api.extern(
(m, ), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.fully_connected_inference",
ins[0], ins[1], outs[0]), name="C")
def fully_connected_output(lhs, rhs):
"""Create an extern op that compute fully connected of 2D tensor lhs and
2D tensor rhs with nnpack.
Parameters
----------
lhs : Tensor
lhs 2D matrix input[batch_size][input_channels] of FP32 elements
rhs : Tensor
lhs 2D matrix kernel[output_channels][input_channels] of FP32 elements
Returns
-------
C : Tensor
lhs 2D array out[batch_size][output_channels] of FP32 elements.
"""
n = lhs.shape[0]
m = rhs.shape[0]
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.fully_connected_output",
ins[0], ins[1], outs[0]), name="C")
/*!
* Copyright (c) 2017 by Contributors
* \file Use external nnpack library call.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
#include <nnpack.h>
namespace tvm {
namespace contrib {
using namespace runtime;
// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
.set_body([](TVMArgs args, TVMRetValue *ret) {
nnp_initialize();
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
CHECK_EQ(A->ndim, 1);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 1);
CHECK_EQ(B->shape[0], C->shape[0]);
CHECK_EQ(B->shape[1], A->shape[0]);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kFloat, 32));
CHECK(TypeMatch(B->dtype, kFloat, 32));
CHECK(TypeMatch(C->dtype, kFloat, 32));
nnp_fully_connected_inference(B->shape[1],
B->shape[0],
static_cast<float*>(A->data),
static_cast<float*>(B->data),
static_cast<float*>(C->data),
NULL);
});
TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
.set_body([](TVMArgs args, TVMRetValue *ret) {
nnp_initialize();
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK_EQ(B->shape[0], C->shape[1]);
CHECK_EQ(B->shape[1], A->shape[1]);
CHECK_EQ(A->shape[0], C->shape[0]);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kFloat, 32));
CHECK(TypeMatch(B->dtype, kFloat, 32));
CHECK(TypeMatch(C->dtype, kFloat, 32));
nnp_fully_connected_output(A->shape[0],
B->shape[1],
B->shape[0],
static_cast<float*>(A->data),
static_cast<float*>(B->data),
static_cast<float*>(C->data),
NULL,
NULL);
});
} // namespace contrib
} // namespace tvm
import tvm
import numpy as np
from tvm.contrib import nnpack
def test_fully_connected_output():
n = 1024
l = 128
m = 235
bias = tvm.var('bias', dtype=tvm.float32)
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B')
C = nnpack.fully_connected_inference(A, B)
C = nnpack.fully_connected_output(A, B)
D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
s = tvm.create_schedule(D.op)
def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_output", True):
print("skip because extern function is not avalable")
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], target)
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(m, l)).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
bb = 10.0
f(a, b, d, bb)
np.testing.assert_allclose(
d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy().T) + bb, rtol=1e-5)
verify()
def test_fully_connected_inference():
n = 1024
l = 128
m = 235
bias = tvm.var('bias', dtype=tvm.float32)
A = tvm.placeholder((l, ), name='A')
B = tvm.placeholder((m, l), name='B')
C = nnpack.fully_connected_inference(A, B)
D = tvm.compute(C.shape, lambda i: C[i] + bias, name="D")
s = tvm.create_schedule(D.op)
def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
print("skip because extern function is not avalable")
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], target)
a = tvm.nd.array(np.random.uniform(size=(l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(m, l)).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((m, ), dtype=D.dtype), ctx)
bb = 10.0
f(a, b, d, bb)
np.testing.assert_allclose(
d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy().T) + bb, rtol=1e-5)
verify()
if __name__ == "__main__":
test_fully_connected_inference()
test_fully_connected_output()
......@@ -20,3 +20,5 @@ python -m nose -v examples/graph_executor/tests || exit -1
TVM_FFI=cython python -m nose -v tests/python/integration || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/integration || exit -1
TVM_FFI=cython python -m nose -v tests/python/contrib || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/contrib || exit -1
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment