Unverified commit 8d945872 by Alex Gladkov, committed by GitHub

Optimize x86 conv3d_ndhwc using data packing approach. (#4866)

Add a tunable conv3d_ndhwc schedule.
parent 70c63829
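
The commit message is terse, so here is a minimal NumPy sketch of the general data-packing idea (illustrative only; the helper name and block size are assumptions, not code from this commit): the innermost channel axis is split into fixed-width blocks so that the hot inner loop of the convolution runs over a short, contiguous, SIMD-friendly axis.

```python
import numpy as np

def pack_channels(data, block=8):
    # Hypothetical helper, not from this commit: split NDHWC's innermost
    # channel axis C into (C // block, block) so the last axis has a fixed,
    # vector-friendly width that an x86 schedule can vectorize over.
    n, d, h, w, c = data.shape
    assert c % block == 0, "channel count must be divisible by the block size"
    return data.reshape(n, d, h, w, c // block, block)

x = np.random.rand(1, 4, 16, 16, 32).astype("float32")
print(pack_channels(x).shape)  # (1, 4, 16, 16, 4, 8)
```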
--- a/python/tvm/autotvm/task/relay_integration.py
+++ b/python/tvm/autotvm/task/relay_integration.py
@@ -133,6 +133,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
         tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
         tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
         tvm.relay.op.nn.conv1d_transpose: [topi.nn.conv1d_transpose_ncw],
+        tvm.relay.op.nn.conv3d: [topi.nn.conv3d],
     }
     topi_funcs = []
--- a/python/tvm/autotvm/task/topi_integration.py
+++ b/python/tvm/autotvm/task/topi_integration.py
@@ -94,6 +94,7 @@ class TaskExtractEnv:
             topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
             topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
             topi.nn.conv1d_transpose_ncw: "topi_nn_conv1d_transpose_ncw",
+            topi.nn.conv3d: "topi_nn_conv3d",
         }
         self.topi_to_schedule = {
@@ -112,6 +113,7 @@ class TaskExtractEnv:
             topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
             topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
             topi.nn.conv1d_transpose_ncw: [topi.generic.schedule_conv1d_transpose_ncw],
+            topi.nn.conv3d: [topi.generic.schedule_conv3d_ndhwc],
         }
         # function reflection for tracing
@@ -129,6 +131,7 @@ class TaskExtractEnv:
             topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
             topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x),
             topi.nn.conv1d_transpose_ncw: lambda x: setattr(topi.nn, 'conv1d_transpose_ncw', x),
+            topi.nn.conv3d: lambda x: setattr(topi.nn, 'conv3d', x),
         }
         self.allow_duplicate = allow_duplicate
@@ -231,6 +234,15 @@ class TaskExtractEnv:
             s = topi.generic.schedule_conv1d_transpose_ncw([C])
             return s, [A, W, C]

+        @register("topi_nn_conv3d")
+        def _topi_nn_conv3d(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            A, W = args[:2]
+            C = topi.nn.conv3d(*args, **kwargs)
+            s = topi.generic.schedule_conv3d_ndhwc([C])
+            return s, [A, W, C]
+
         @register("topi_nn_dense")
         def _topi_nn_dense(*args, **kwargs):
             assert not kwargs, "Do not support kwargs in template function call"
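
With these registrations in place, conv3d tasks can be extracted from a Relay program and tuned like any other AutoTVM template. A rough workflow sketch, assuming the AutoTVM API of this era; `mod` and `params` stand for a hypothetical Relay module and its parameters:

```python
from tvm import autotvm, relay

# Extract tunable conv3d tasks from a Relay module (mod and params assumed given).
tasks = autotvm.task.extract_from_program(
    mod["main"], params=params, ops=(relay.op.nn.conv3d,), target="llvm")

# Tune the first extracted task and log the best configurations.
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=10))
tuner = autotvm.tuner.XGBTuner(tasks[0])
tuner.tune(n_trial=100, measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file("conv3d_ndhwc.log")])
```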
--- a/topi/python/topi/nn/util.py
+++ b/topi/python/topi/nn/util.py
@@ -47,6 +47,42 @@ def infer_pad(data, data_pad):
     wpad = (TW - IW) // 2
     return get_const_int(hpad), get_const_int(wpad)

+
+def infer_pad3d(data, data_pad, layout):
+    """Infer the padding from stages in reverse.
+
+    Parameters
+    ----------
+    data : Tensor
+        data stage.
+
+    data_pad : Tensor
+        pad stage.
+
+    layout : str
+        layout of the data, "NDHWC" or "NCDHW".
+
+    Returns
+    -------
+    dpad : int
+        total padding depth (front + back)
+
+    hpad : int
+        total padding height (top + bottom)
+
+    wpad : int
+        total padding width (left + right)
+    """
+    if data_pad is None:
+        return 0, 0, 0
+
+    if layout == "NDHWC":
+        _, ID, IH, IW, _ = data.shape
+        _, TD, TH, TW, _ = data_pad.shape
+    elif layout == "NCDHW":
+        _, _, ID, IH, IW = data.shape
+        _, _, TD, TH, TW = data_pad.shape
+    else:
+        raise ValueError("Layout {} is not supported".format(layout))
+
+    dpad = TD - ID
+    hpad = TH - IH
+    wpad = TW - IW
+    return get_const_int(dpad), get_const_int(hpad), get_const_int(wpad)
+

 def infer_stride(data, kernel, out):
     """Infer the stride from stages in reverse.
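
Note that, unlike infer_pad above (which halves the height/width difference), infer_pad3d returns the total padding per axis. A small usage sketch, assuming this era's tvm.placeholder and topi.nn.pad APIs:

```python
import tvm
import topi
from topi.nn.util import infer_pad3d

A = tvm.placeholder((1, 8, 8, 8, 16), name="A")  # NDHWC
# Pad depth, height, and width by 1 on both sides; leave N and C untouched.
PadA = topi.nn.pad(A, pad_before=[0, 1, 1, 1, 0], pad_after=[0, 1, 1, 1, 0])

dpad, hpad, wpad = infer_pad3d(A, PadA, "NDHWC")
assert (dpad, hpad, wpad) == (2, 2, 2)  # totals: front+back, top+bottom, left+right
```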
--- a/topi/tests/python/test_topi_conv3d_ndhwc.py
+++ b/topi/tests/python/test_topi_conv3d_ndhwc.py
@@ -36,7 +36,6 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
     A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A')
     W = tvm.placeholder((kernel_depth, kernel_height, kernel_width, in_channel, num_filter), name='W')
-    B = topi.nn.conv3d_ndhwc(A, W, stride, padding, dilation)

     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -57,6 +56,7 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
+            B = topi.nn.conv3d(A, W, stride, padding, dilation, layout="NDHWC")
             s = topi.generic.schedule_conv3d_ndhwc([B])
         ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
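
For reference, the tail of verify_conv3d_ndhwc (elided from this hunk) builds the schedule, runs it, and checks the result against a NumPy reference. A sketch of that pattern, continuing with the names bound above (w assumed to be created from w_np the same way a is from a_np) and assuming topi.testing.conv3d_ndhwc_python as the reference implementation:

```python
# b_np is the NumPy reference result; b receives the device output.
b_np = topi.testing.conv3d_ndhwc_python(a_np, w_np, stride, padding)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], device)
func(a, w, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
```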