Commit 36d3a41e by Meghan Cowan Committed by Tianqi Chen

[TOPI] Bitserial low-precision convolution (#1332)

parent e806cd15
......@@ -154,6 +154,31 @@ def call_extern(dtype, func_name, *args):
dtype, func_name, convert(args), _Call.Extern, None, 0)
def call_llvm_intrin(dtype, name, *args):
"""Build expression by calling an llvm intrinsic function
Parameters
----------
dtype : str
The data type of the result.
name : str
The name of the llvm intrinsic function.
args : list
Poistional arguments.
Returns
-------
call : Expr
The call expression.
"""
import tvm
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
def exp(x):
"""Take exponetial of input x.
......
......@@ -282,6 +282,15 @@ class LLVMModuleNode final : public runtime::ModuleNode {
std::shared_ptr<llvm::LLVMContext> ctx_;
};
unsigned LookupLLVMIntrinsic(const std::string& name) {
return llvm::Function::lookupIntrinsicID(name);
}
TVM_REGISTER_API("codegen.llvm_lookup_intrinsic_id")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
});
TVM_REGISTER_API("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
......
......@@ -17,6 +17,16 @@ def test_llvm_intrin():
func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
fcode = tvm.build(func, None, "llvm")
def test_llvm_lookup_intrin():
ib = tvm.ir_builder.create()
m = tvm.var("m")
A = ib.pointer("uint8x8", name="A")
x = tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.const(1, 'uint32'), A)
ib.emit(x)
body = ib.get()
func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
fcode = tvm.build(func, None, "llvm")
def test_llvm_add_pipeline():
nn = 1024
n = tvm.convert(nn)
......@@ -324,3 +334,4 @@ if __name__ == "__main__":
test_llvm_flip_pipeline()
test_llvm_madd_pipeline()
test_llvm_temp_space()
test_llvm_lookup_intrin()
......@@ -143,6 +143,41 @@ def schedule_depthwise_conv2d_nhwc(outs):
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_bitserial_conv2d_nchw(outs):
"""Schedule for bitserial_conv2d_nchw
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial_conv2d_nchw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_bitserial_conv2d_nhwc(outs):
"""Schedule for bitserial_conv2d_nhwc
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial_conv2d_nchw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.override_native_generic_func("schedule_reduce")
def schedule_reduce(outs):
......
......@@ -16,4 +16,5 @@ from .conv2d_transpose import *
from .bnn import *
from .upsampling import *
from .local_response_norm import *
from .bitserial_conv2d import *
from .l2_normalize import *
......@@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs
from .conv2d import schedule_conv2d_nchw
from .depthwise_conv2d import schedule_depthwise_conv2d_nchw
from .bitserial_conv2d import schedule_bitserial_conv2d_nhwc
......@@ -8,3 +8,4 @@ from .binary_dense import schedule_binary_dense
from .nn import *
from .injective import *
from .pooling import schedule_pool, schedule_global_pool
from .bitserial_conv2d import schedule_bitserial_conv2d
import os
import numpy as np
import tvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from tvm.contrib import util
from tvm.contrib.pickle_memoize import memoize
def generate_quantized_np(shape, bits, out_dtype):
min_val = 0
max_val = 1 << bits
return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
in_height = in_width = in_size
input_type='uint32'
out_dtype='int32'
with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, layout="NCHW", dorefa=dorefa)
s = topi.generic.schedule_bitserial_conv2d_nchw([B])
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
if dorefa:
w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1
b_np = topi.testing.conv2d_nchw_python(a_np.astype(out_dtype), w_, stride, padding)
else:
b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
return a_np, w_np, b_np
a_np, w_np, b_np = get_ref_data()
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], "llvm")
func(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
in_height = in_width = in_size
input_type='uint32'
out_dtype='int32'
with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
layout="NHWC", dorefa=dorefa)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
if dorefa:
w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1
b_np = topi.testing.conv2d_nhwc_python(a_np, w_, stride, padding).astype(out_dtype)
else:
b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding).astype(out_dtype)
return a_np, w_np, b_np
a_np, w_np, b_np = get_ref_data()
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], 'llvm')
func(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def test_bitserial_conv2d():
in_size = 56
ic, oc = 64, 64
k = 3
stride = 1
pad = 1
verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 1, 1, True)
verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 2, 1, True)
verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 2, 2, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, True)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, True)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 2, False)
if __name__ == "__main__":
test_bitserial_conv2d()
\ No newline at end of file
import os
import re
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
from tvm.contrib import util
target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'
def generate_quantized_np(shape, bits, out_dtype):
np.random.seed(0)
min_val = 0
max_val = 1 << bits
return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
# Verify that certain special instructions from the tensorize pass exist
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
in_height = in_width = in_size
input_type='uint32'
out_dtype='int32'
with tvm.target.rasp():
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
layout="NHWC", dorefa=dorefa)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
func = tvm.build(s, [A, W, B], target)
assembly = func.get_source('asm')
matches = re.findall("vpadal", assembly)
assert (len(matches) > 0)
matches = re.findall("vcnt", assembly)
assert (len(matches) > 0)
matches = re.findall("vpadd", assembly)
assert (len(matches) > 0)
def test_bitserial_conv2d():
in_size = 56
ic, oc = 64, 64
k = 3
stride = 1
pad = 1
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
if __name__ == "__main__":
test_bitserial_conv2d()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment