Unverified Commit 8d945872 by Alex Gladkov Committed by GitHub

Optimize x86 conv3d_ndhwc using data packing approach. (#4866)

Add tuneable conv3d_ndhwc schedule
parent 70c63829
...@@ -133,6 +133,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None, ...@@ -133,6 +133,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul], tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw], tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
tvm.relay.op.nn.conv1d_transpose: [topi.nn.conv1d_transpose_ncw], tvm.relay.op.nn.conv1d_transpose: [topi.nn.conv1d_transpose_ncw],
tvm.relay.op.nn.conv3d: [topi.nn.conv3d],
} }
topi_funcs = [] topi_funcs = []
......
...@@ -94,6 +94,7 @@ class TaskExtractEnv: ...@@ -94,6 +94,7 @@ class TaskExtractEnv:
topi.nn.bitserial_dense: "topi_nn_bitserial_dense", topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw", topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
topi.nn.conv1d_transpose_ncw: "topi_nn_conv1d_transpose_ncw", topi.nn.conv1d_transpose_ncw: "topi_nn_conv1d_transpose_ncw",
topi.nn.conv3d: "topi_nn_conv3d",
} }
self.topi_to_schedule = { self.topi_to_schedule = {
...@@ -112,6 +113,7 @@ class TaskExtractEnv: ...@@ -112,6 +113,7 @@ class TaskExtractEnv:
topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense], topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw], topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
topi.nn.conv1d_transpose_ncw: [topi.generic.schedule_conv1d_transpose_ncw], topi.nn.conv1d_transpose_ncw: [topi.generic.schedule_conv1d_transpose_ncw],
topi.nn.conv3d: [topi.generic.schedule_conv3d_ndhwc],
} }
# function reflection for tracing # function reflection for tracing
...@@ -129,6 +131,7 @@ class TaskExtractEnv: ...@@ -129,6 +131,7 @@ class TaskExtractEnv:
topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x), topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x), topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x),
topi.nn.conv1d_transpose_ncw: lambda x: setattr(topi.nn, 'conv1d_transpose_ncw', x), topi.nn.conv1d_transpose_ncw: lambda x: setattr(topi.nn, 'conv1d_transpose_ncw', x),
topi.nn.conv3d: lambda x: setattr(topi.nn, 'conv3d', x),
} }
self.allow_duplicate = allow_duplicate self.allow_duplicate = allow_duplicate
...@@ -231,6 +234,15 @@ class TaskExtractEnv: ...@@ -231,6 +234,15 @@ class TaskExtractEnv:
s = topi.generic.schedule_conv1d_transpose_ncw([C]) s = topi.generic.schedule_conv1d_transpose_ncw([C])
return s, [A, W, C] return s, [A, W, C]
@register("topi_nn_conv3d")
def _topi_nn_conv3d(*args, **kwargs):
    """AutoTVM template wrapper: build topi.nn.conv3d and its NDHWC schedule."""
    assert not kwargs, "Do not support kwargs in template function call"
    workload = deserialize_args(args)
    data, kernel = workload[0], workload[1]
    out = topi.nn.conv3d(*workload, **kwargs)
    sched = topi.generic.schedule_conv3d_ndhwc([out])
    return sched, [data, kernel, out]
@register("topi_nn_dense") @register("topi_nn_dense")
def _topi_nn_dense(*args, **kwargs): def _topi_nn_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call" assert not kwargs, "Do not support kwargs in template function call"
......
...@@ -47,6 +47,42 @@ def infer_pad(data, data_pad): ...@@ -47,6 +47,42 @@ def infer_pad(data, data_pad):
wpad = (TW - IW) // 2 wpad = (TW - IW) // 2
return get_const_int(hpad), get_const_int(wpad) return get_const_int(hpad), get_const_int(wpad)
def infer_pad3d(data, data_pad, layout):
    """Infer the 3D padding from the data and pad stages in reverse.

    NOTE: unlike ``infer_pad``, which returns the per-side (halved)
    padding, this returns the *total* padding along each spatial axis,
    i.e. padded extent minus original extent.

    Parameters
    ----------
    data : Tensor
        data stage.

    data_pad : Tensor
        pad stage.

    layout : str
        Layout of the data, either "NDHWC" or "NCDHW".

    Returns
    -------
    dpad : int
        total padding along depth

    hpad : int
        total padding along height

    wpad : int
        total padding along width

    Raises
    ------
    ValueError
        If `layout` is neither "NDHWC" nor "NCDHW".
    """
    # No pad stage means the convolution was unpadded.
    if data_pad is None:
        return 0, 0, 0

    if layout == "NDHWC":
        _, ID, IH, IW, _ = data.shape
        _, TD, TH, TW, _ = data_pad.shape
    elif layout == "NCDHW":
        _, _, ID, IH, IW = data.shape
        _, _, TD, TH, TW = data_pad.shape
    else:
        raise ValueError("Layout {} is not supported".format(layout))
    dpad = TD - ID
    hpad = TH - IH
    wpad = TW - IW
    return get_const_int(dpad), get_const_int(hpad), get_const_int(wpad)
def infer_stride(data, kernel, out): def infer_stride(data, kernel, out):
"""Infer the stride from stages in reverse. """Infer the stride from stages in reverse.
......
...@@ -36,7 +36,6 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride, ...@@ -36,7 +36,6 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A') A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A')
W = tvm.placeholder((kernel_depth, kernel_height, kernel_width, in_channel, num_filter), name='W') W = tvm.placeholder((kernel_depth, kernel_height, kernel_width, in_channel, num_filter), name='W')
B = topi.nn.conv3d_ndhwc(A, W, stride, padding, dilation)
a_shape = get_const_tuple(A.shape) a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape) w_shape = get_const_tuple(W.shape)
...@@ -57,6 +56,7 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride, ...@@ -57,6 +56,7 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
return return
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
B = topi.nn.conv3d(A, W, stride, padding, dilation, layout="NDHWC")
s = topi.generic.schedule_conv3d_ndhwc([B]) s = topi.generic.schedule_conv3d_ndhwc([B])
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment