Commit 17351875 by Meghan Cowan Committed by Tianqi Chen

[TOPI] bitserial_conv2d move to autotvm template and updates (#2819)

parent cefe07e2
......@@ -205,7 +205,7 @@ def args_to_workload(x, topi_compute_func=None):
workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float)):
workload = x
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
workload = x.value
elif x is None:
workload = 0
......
......@@ -68,6 +68,8 @@ class TaskExtractEnv:
topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
topi.nn.dense: "topi_nn_dense",
topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
}
......@@ -79,6 +81,8 @@ class TaskExtractEnv:
topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
topi.nn.dense: [topi.generic.schedule_dense],
topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
}
......@@ -174,6 +178,24 @@ class TaskExtractEnv:
return s, [data, weight, bias, C]
return s, [data, weight, C]
@register("topi_nn_bitserial_conv2d_nhwc")
def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
    """Build the NHWC bitserial conv2d compute and its generic schedule.

    Positional arguments are passed through ``deserialize_args`` and then
    forwarded, together with any keyword arguments, to
    ``topi.nn.bitserial_conv2d_nhwc``.

    Returns
    -------
    tuple
        ``(schedule, [data, kernel, output])`` where ``data`` and ``kernel``
        are the first two deserialized arguments and ``output`` is the
        result of the compute call.
    """
    tensors = deserialize_args(args)
    out = topi.nn.bitserial_conv2d_nhwc(*tensors, **kwargs)
    sched = topi.generic.nn.schedule_bitserial_conv2d_nhwc([out])
    return sched, [tensors[0], tensors[1], out]
@register("topi_nn_bitserial_conv2d_nchw")
def _topi_bitserial_conv2d_nchw(*args, **kwargs):
    """Build the NCHW bitserial conv2d compute and its generic schedule.

    Positional arguments are passed through ``deserialize_args`` and then
    forwarded, together with any keyword arguments, to
    ``topi.nn.bitserial_conv2d_nchw``.

    Returns
    -------
    tuple
        ``(schedule, [data, kernel, output])`` where ``data`` and ``kernel``
        are the first two deserialized arguments and ``output`` is the
        result of the compute call.
    """
    tensors = deserialize_args(args)
    out = topi.nn.bitserial_conv2d_nchw(*tensors, **kwargs)
    sched = topi.generic.nn.schedule_bitserial_conv2d_nchw([out])
    return sched, [tensors[0], tensors[1], out]
@register("topi_nn_deformable_conv2d_nchw")
def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
......
......@@ -11,16 +11,16 @@ def generate_quantized_np(shape, bits, out_dtype):
return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
activation_bits, weight_bits, unipolar):
in_height = in_width = in_size
input_type = 'uint32'
input_dtype = 'uint32'
out_dtype = 'int32'
with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, layout="NCHW", dorefa=dorefa)
A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_dtype, name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_dtype, name='W')
B = topi.nn.bitserial_conv2d_nchw(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, unipolar=unipolar)
s = topi.generic.schedule_bitserial_conv2d_nchw([B])
a_shape = get_const_tuple(A.shape)
......@@ -28,9 +28,9 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
@memoize("topi.tests.test_topi_bitseral_conv2d_nchw")
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type)
if dorefa:
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype)
if unipolar:
w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1
......@@ -49,16 +49,16 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
activation_bits, weight_bits, unipolar):
in_height = in_width = in_size
input_type='uint32'
input_dtype='uint32'
out_dtype='int32'
with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
layout="NHWC", dorefa=dorefa)
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_dtype, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_dtype, name='W')
B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, unipolar=unipolar)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
a_shape = get_const_tuple(A.shape)
......@@ -66,9 +66,9 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
@memoize("topi.tests.test_topi_bitseral_conv2d_nhwc")
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type)
if dorefa:
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype)
if unipolar:
w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1
......
......@@ -4,6 +4,7 @@ import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
def generate_quantized_np(shape, bits, out_dtype):
np.random.seed(0)
......@@ -13,19 +14,20 @@ def generate_quantized_np(shape, bits, out_dtype):
# Verify that certain special instructions from the tensorize pass exist
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
activation_bits, weight_bits, unipolar):
in_height = in_width = in_size
input_type = 'uint32'
out_dtype = 'int32'
out_dtype = 'int16'
with tvm.target.arm_cpu('rasp3b'):
device = 'llvm -device=arm_cpu -model=bcm2837 -target=armv7l-linux-gnueabihf -mattr=+neon'
with tvm.target.create(device):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
layout="NHWC", dorefa=dorefa)
B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
pack_dtype='uint8', out_dtype='int16', unipolar=unipolar)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
func = tvm.build(s, [A, W, B], tvm.target.arm_cpu('rasp3b'))
func = tvm.build(s, [A, W, B], device)
assembly = func.get_source('asm')
matches = re.findall("vpadal", assembly)
......@@ -35,6 +37,33 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
matches = re.findall("vpadd", assembly)
assert (len(matches) > 0)
ctx = tvm.context(device, 0)
if 'arm' not in os.uname()[4]:
print ("Skipped running code, not an arm device")
return
print("Running on target: %s" % device)
def get_ref_data():
    """Generate quantized inputs and the NumPy reference convolution result.

    Reads ``A``, ``W``, ``activation_bits``, ``weight_bits``, ``input_type``,
    ``out_dtype``, ``unipolar``, ``stride`` and ``padding`` from the
    enclosing scope.

    Returns
    -------
    tuple
        ``(a_np, w_np, b_np)``: the activation array, the weight array as
        generated (not remapped), and the reference output cast to
        ``out_dtype``.
    """
    act_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type)
    wgt_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
    if unipolar:
        # For the reference only: weights equal to 1 stay 1, every other
        # value becomes -1. The comparison is done after the cast to
        # out_dtype, and the returned weights keep their original values.
        ref_w = np.where(wgt_np.astype(out_dtype) == 1, 1, -1).astype(out_dtype)
    else:
        ref_w = wgt_np
    ref_np = topi.testing.conv2d_nhwc_python(act_np, ref_w, stride, padding).astype(out_dtype)
    return act_np, wgt_np, ref_np
a_np, w_np, b_np = get_ref_data()
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], device)
func(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def test_bitserial_conv2d():
in_size = 56
ic, oc = 64, 64
......@@ -45,6 +74,9 @@ def test_bitserial_conv2d():
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, True)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, True)
if __name__ == "__main__":
    # Run the bitserial conv2d checks when this file is executed as a script.
    test_bitserial_conv2d()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment