Commit ffff1e49 by wetliu Committed by Tianqi Chen

[TOPI] Depth wise convolution backward methods for NHWC (#434)

* rename the nchw and pass the unit test; going to do it for nhwc depthwise

* bug with fusion

* nchw works fine; nhwc float32 problem remains

* still cannot bind them together

* fusion works

* syntax fix

* all bugs fixed; test cases pass

* minor fix on nn.h

* back wrt input

* backward wrt input nhwc; only test case in recipe

* test case for depthwise back wrt input

* test case for depthwise backward wrt weight

* tags

* minor fixes

* pylint test; add arch=3.7

* modify scheduler

* better backward depthwise w.r.t weight scheduler

* updated scheduler

* test_topi_depthwise_conv2d_back_input.py and test_topi_depthwise_conv2d_back_weight.py success

* all test cases wrt input pass

* update

* new test cases and scheduler

* not working 1 and 2

* good wrt weight, bad wrt input

* test cases added

* remove tf lines

* minor fix

* compute arch changed

* remove compile hook

* minor change

* pylint

* fix the float for python case

* fix cases for python3 case

* except for memoize

* fix most; memoize still wrong

* memoize added

* unexpected layout cases added for scheduler

* error message layout other than NHWC added

* improve padding

* fix as pr requests

* remove dilate in backward wrt weight
parent f2ab736b
......@@ -15,6 +15,8 @@ constexpr auto kConv2dNCHW = "conv2d_nchw";
constexpr auto kConv2dHWCN = "conv2d_hwcn";
constexpr auto kDepthwiseConv2dNCHW = "depthwise_conv2d_nchw";
constexpr auto kDepthwiseConv2dNHWC = "depthwise_conv2d_nhwc";
constexpr auto kDepthwiseConv2dBackInputNHWC = "depthwise_conv2d_back_input_nhwc";
constexpr auto kDepthwiseConv2dBackWeightNHWC = "depthwise_conv2d_back_weight_nhwc";
constexpr auto kGroupConv2d = "group_conv2d";
} // namespace topi
......
......@@ -5,6 +5,8 @@ from __future__ import absolute_import as _abs
from .conv2d_nchw import schedule_conv2d_nchw
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
from .reduction import schedule_reduce
from .broadcast import schedule_broadcast_to
from .softmax import schedule_softmax
......
......@@ -186,3 +186,101 @@ def schedule_depthwise_conv2d_nhwc(outs):
traverse(outs[0].op)
return s
def schedule_depthwise_conv2d_backward_input_nhwc(outs):
"""Schedule for depthwise_conv2d nhwc backward wrt input.
Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise_conv2d
backward wrt input in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for depthwise_conv2d backward
wrt input with layout nhwc.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(Padded_out_grad, In_grad):
s[Padded_out_grad].compute_inline()
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
_, h, w, c = In_grad.op.axis
fused_hwc = s[In_grad].fuse(h, w, c)
xoc, xic = s[In_grad].split(fused_hwc, factor=128)
s[In_grad].bind(xoc, block_x)
s[In_grad].bind(xic, thread_x)
def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if OP.tag == 'depthwise_conv2d_backward_input_nhwc':
Padded_out_grad = OP.input_tensors[0]
Dilated_out_grad = Padded_out_grad.op.input_tensors[0]
s[Dilated_out_grad].compute_inline()
In_grad = OP.output(0)
_schedule(Padded_out_grad, In_grad)
else:
raise ValueError("Depthwise conv backward wrt input for non-NHWC is not supported.")
traverse(outs[0].op)
return s
def schedule_depthwise_conv2d_backward_weight_nhwc(outs):
"""Schedule for depthwise_conv2d nhwc backward wrt weight.
Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise_conv2d
backward wrt weight in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for depthwise_conv2d backward
wrt weight with layout nhwc.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(Weight_grad):
block_x = tvm.thread_axis("blockIdx.x")
thread_y = tvm.thread_axis("threadIdx.y")
thread_x = tvm.thread_axis("threadIdx.x")
db, dh, dw = Weight_grad.op.reduce_axis
fused_dbdhdw = s[Weight_grad].fuse(db, dh, dw)
_, ki = s[Weight_grad].split(fused_dbdhdw, factor=8)
BF = s.rfactor(Weight_grad, ki)
fused_fwcm = s[Weight_grad].fuse(*s[Weight_grad].op.axis)
xo, xi = s[Weight_grad].split(fused_fwcm, factor=32)
s[Weight_grad].bind(xi, thread_x)
s[Weight_grad].bind(xo, block_x)
s[Weight_grad].bind(s[Weight_grad].op.reduce_axis[0], thread_y)
s[BF].compute_at(s[Weight_grad], s[Weight_grad].op.reduce_axis[0])
def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if OP.tag == 'depthwise_conv2d_backward_weight_nhwc':
Padded_in = OP.input_tensors[1]
s[Padded_in].compute_inline()
Weight_grad = OP.output(0)
_schedule(Weight_grad)
else:
raise ValueError("Depthwise conv backward wrt weight for non-NHWC is not supported.")
traverse(outs[0].op)
return s
# pylint: disable=invalid-name, unused-variable, too-many-locals
"""Depthwise Convolution operators"""
"""Depthwise convolution operators"""
from __future__ import absolute_import as _abs
import tvm
from .dilate import dilate
from .pad import pad
from .util import get_pad_tuple
from ..util import simplify
......@@ -55,6 +57,7 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding):
name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
return Output
def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
"""Depthwise convolution nhwc forward operator.
......@@ -66,8 +69,8 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
Filter : tvm.Tensor
4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
Stride : tvm.Tensor
1-D of size 2
stride : tuple of two ints
The spatial stride along height and width
padding : int or str
Padding size, or ['VALID', 'SAME']
......@@ -102,3 +105,105 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
return Output
def depthwise_conv2d_backward_input_nhwc(Filter, Out_grad, oshape, ishape, stride, padding):
"""Depthwise convolution nhwc backward wrt input operator.
Parameters
----------
Filter : tvm.Tensor
4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
Out_grad : tvm.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
stride : tuple of two ints
The spatial stride along height and width
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, in_height, in_width, in_channel]
"""
batch, in_h, in_w, in_c = ishape
_, out_h, out_w, out_c = oshape
filter_h, filter_w, _, channel_multiplier = Filter.shape
stride_h, stride_w = stride
dilated_out_grad = dilate(Out_grad, [1, stride_h, stride_w, 1], name='dilated_out_grad')
# padding params in forward propagation
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
# padding params in backward propagation
bpad_top = filter_h - 1 - fpad_top
bpad_bottom = (filter_h - 1 - fpad_bottom) + (stride_h - 1)
bpad_left = filter_w - 1 - fpad_left
bpad_right = (filter_w - 1 - fpad_right) + (stride_w - 1)
padded_out_grad = pad(dilated_out_grad, \
[0, bpad_top, bpad_left, 0], \
[0, bpad_bottom, bpad_right, 0], \
name='padded_out_grad')
dh = tvm.reduce_axis((0, filter_h), name='dh')
dw = tvm.reduce_axis((0, filter_w), name='dw')
dc = tvm.reduce_axis((0, channel_multiplier), name='dc')
In_grad = tvm.compute(
(batch, in_h, in_w, in_c),
lambda b, h, w, c: tvm.sum(padded_out_grad[b, h+dh, w+dw, c*channel_multiplier + dc] * \
Filter[filter_h-1-dh, filter_w-1-dw, c, dc],
axis=[dh, dw, dc]), tag='depthwise_conv2d_backward_input_nhwc')
return In_grad
def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, stride, padding):
"""Depthwise convolution nhwc backward wrt weight operator.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_height, in_width, in_channel]
Out_grad : tvm.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
stride : tuple of two ints
The spatial stride along height and width
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
Output : tvm.Tensor
4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
"""
batch, out_h, out_w, out_c = oshape
filter_h, filter_w, _, channel_multiplier = fshape
in_c = Input.shape[3].value
stride_h, stride_w = stride
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (filter_h, filter_w))
padded_in = pad(Input, \
[0, pad_top, pad_left, 0], \
[0, pad_bottom, pad_right, 0], \
name='padded_in')
dh = tvm.reduce_axis((0, Out_grad.shape[1].value), name='dh')
dw = tvm.reduce_axis((0, Out_grad.shape[2].value), name='dw')
db = tvm.reduce_axis((0, batch), name='db')
Weight_grad = tvm.compute(
(filter_h, filter_w, in_c, channel_multiplier),
lambda fh, fw, c, m: tvm.sum(
Out_grad[db, dh, dw, c*channel_multiplier+m%channel_multiplier] *
padded_in[db, fh+dh*stride_h, fw+dw*stride_w, c], axis=[db, dh, dw]),
tag='depthwise_conv2d_backward_weight_nhwc')
return Weight_grad
......@@ -15,11 +15,10 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
# placeholder
Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
Stride = [stride_h, stride_w]
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, Stride, padding)
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, stride=[stride_h, stride_w], padding=padding)
ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule
......@@ -97,11 +96,10 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
# placeholder
Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input')
Filter = tvm.placeholder((filter_height, filter_width,filter_channel, channel_multiplier), name='Filter')
Stride = [stride_h, stride_w]
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, Stride, padding)
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, stride=[stride_h, stride_w], padding=padding)
ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule
......
import tvm
import topi
import numpy as np
from tvm.contrib.pickle_memoize import memoize
from scipy import signal
from topi.util import get_const_tuple
from topi.nn.util import get_pad_tuple
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multiplier, filter_h, stride_h, padding_h):
in_w = in_h
filter_channel = in_channel
filter_w = filter_h
stride_w = stride_h
padding_w = padding_h
out_h = np.int((in_h+2*padding_h-filter_h)/stride_h+1)
out_w = np.int((in_w+2*padding_w-filter_w)/stride_w+1)
out_channel = in_channel * channel_multiplier
ishape = [batch, in_h, in_w, in_channel]
oshape = [batch, out_h, out_w, out_channel]
# placeholder
Out_grad = tvm.placeholder(oshape, name='Out_grad')
Filter = tvm.placeholder((filter_h, filter_w, filter_channel, channel_multiplier))
# declare
In_grad = topi.nn.depthwise_conv2d_backward_input_nhwc(Filter, Out_grad, oshape, ishape,
stride=[stride_h, stride_w], padding=[padding_h, padding_w])
# schedule
schedule = schedule_depthwise_conv2d_backward_input_nhwc(In_grad)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.context(device, 0)
# build the kernel
f = tvm.build(schedule, [Filter, Out_grad, In_grad], device)
# prepare pod type for test data closure
dtype = Out_grad.dtype
out_grad_shape = get_const_tuple(Out_grad.shape)
filter_shape = get_const_tuple(Filter.shape)
# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_depthwise_conv2d_backward_input.nhwc")
def get_ref_data():
out_grad_np = np.random.uniform(size=out_grad_shape).astype(dtype)
filter_np = np.random.uniform(size=filter_shape).astype(dtype)
dilated_out_grad_np = topi.testing.dilate_python(out_grad_np, [1, stride_h, stride_w, 1])
# padding params in forward propagation
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple([padding_h, padding_w], (filter_h, filter_w))
# padding params in backward propagation
bpad_top = filter_h - 1 - fpad_top
bpad_bottom = (filter_h - 1 - fpad_bottom) + (stride_h - 1)
bpad_left = filter_w - 1 - fpad_left
bpad_right = (filter_w - 1 - fpad_right) + (stride_w - 1)
padded_out_grad = np.zeros((batch, dilated_out_grad_np.shape[1]+bpad_top+bpad_bottom,
dilated_out_grad_np.shape[2]+bpad_left+bpad_right, out_channel))
padded_out_grad[:, bpad_top:dilated_out_grad_np.shape[1]+bpad_top,
bpad_left:dilated_out_grad_np.shape[2]+bpad_left, :] = dilated_out_grad_np
in_grad_np = np.zeros((batch, in_h, in_w, in_channel))
for b in range(batch):
for c in range(in_channel):
for m in range(channel_multiplier):
in_grad_np[b, :, :, c] += signal.convolve2d(padded_out_grad[b, :, :, c*channel_multiplier+m], \
filter_np[:, :, c, m], mode='valid')[0:in_h, 0:in_w]
return (out_grad_np, filter_np, in_grad_np)
(out_grad_np, filter_np, in_grad_np) = get_ref_data()
out_grad_tvm = tvm.nd.array(out_grad_np, ctx)
filter_tvm = tvm.nd.array(filter_np, ctx)
in_grad_tvm = tvm.nd.array(np.zeros(shape=ishape, dtype=dtype), ctx)
# launch the kernel
timer = f.time_evaluator(f.entry_name, ctx, number=1)
tcost = timer(filter_tvm, out_grad_tvm, in_grad_tvm).mean
np.testing.assert_allclose(in_grad_np, in_grad_tvm.asnumpy(), rtol=1e-5)
check_device("opencl")
check_device("cuda")
check_device("metal")
def test_topi_depthwise_conv2d_backward_input_nhwc():
verify_depthwise_conv2d_back_input(16, 256, 56, 1, 3, 1, 1)
verify_depthwise_conv2d_back_input(16, 256, 56, 2, 3, 1, 1)
verify_depthwise_conv2d_back_input(16, 256, 56, 1, 5, 1, 2)
verify_depthwise_conv2d_back_input(16, 256, 56, 2, 5, 1, 2)
verify_depthwise_conv2d_back_input(16, 256, 56, 1, 3, 2, 1)
verify_depthwise_conv2d_back_input(16, 256, 56, 2, 3, 2, 1)
verify_depthwise_conv2d_back_input(16, 256, 56, 1, 5, 2, 2)
verify_depthwise_conv2d_back_input(16, 256, 56, 2, 5, 2, 2)
verify_depthwise_conv2d_back_input(16, 256, 56, 1, 3, 1, 0)
verify_depthwise_conv2d_back_input(16, 256, 56, 2, 3, 1, 0)
verify_depthwise_conv2d_back_input(16, 256, 56, 1, 5, 1, 0)
verify_depthwise_conv2d_back_input(16, 256, 56, 2, 5, 1, 0)
verify_depthwise_conv2d_back_input(16, 256, 56, 1, 3, 2, 0)
verify_depthwise_conv2d_back_input(16, 256, 56, 2, 3, 2, 0)
verify_depthwise_conv2d_back_input(16, 256, 56, 1, 5, 2, 0)
verify_depthwise_conv2d_back_input(16, 256, 56, 2, 5, 2, 0)
if __name__ == "__main__":
test_topi_depthwise_conv2d_backward_input_nhwc()
import tvm
import topi
import numpy as np
from tvm.contrib.pickle_memoize import memoize
from scipy import signal
from topi.util import get_const_tuple
from topi.nn.util import get_pad_tuple
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multiplier, filter_h, stride_h, padding_h):
in_w = in_h
filter_channel = in_channel
filter_w = filter_h
stride_w = stride_h
padding_w = padding_h
out_h = np.int((in_h+2*padding_h-filter_h)/stride_h+1)
out_w = np.int((in_w+2*padding_w-filter_w)/stride_w+1)
out_channel = in_channel * channel_multiplier
oshape = [batch, out_h, out_w, out_channel]
fshape = [filter_h, filter_w, in_channel, channel_multiplier]
# placeholder
Out_grad = tvm.placeholder(oshape, name='Out_grad')
Input = tvm.placeholder((batch, in_h, in_w, in_channel), name='In_grad')
# declare
Weight_grad = topi.nn.depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape,
stride=[stride_h, stride_w], padding=[padding_h, padding_w])
# schedule
schedule = schedule_depthwise_conv2d_backward_weight_nhwc(Weight_grad)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.context(device, 0)
# build the kernel
f = tvm.build(schedule, [Input, Out_grad, Weight_grad], device)
# prepare pod type for test data closure
dtype = Out_grad.dtype
out_grad_shape = get_const_tuple(Out_grad.shape)
in_shape = get_const_tuple(Input.shape)
# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_depthwise_conv2d_backward_weight.nhwc")
def get_ref_data():
out_grad_np = np.random.uniform(size=out_grad_shape).astype(dtype)
input_np = np.random.uniform(size=in_shape).astype(dtype)
dilated_out_grad_np = topi.testing.dilate_python(out_grad_np, [1, stride_h, stride_w, 1])
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple([padding_h, padding_w], (filter_h, filter_w))
padded_input_np = np.zeros((batch, in_h+pad_top+pad_bottom, in_w+pad_left+pad_right, in_channel))
padded_input_np[:, pad_top:in_h+pad_top, pad_left:in_w+pad_left, :] = input_np
weight_grad_np = np.zeros((filter_h, filter_w, in_channel, channel_multiplier))
for c in range(in_channel):
for m in range(channel_multiplier):
for b in range(batch):
weight_grad_np[:, :, c, m] += signal.convolve2d(padded_input_np[b, :, :, c], \
np.rot90(dilated_out_grad_np[b, :, :, c*channel_multiplier+m%channel_multiplier], 2), \
mode='valid')[0:filter_h, 0:filter_w]
return (out_grad_np, input_np, weight_grad_np)
(out_grad_np, input_np, weight_grad_np) = get_ref_data()
out_grad_tvm = tvm.nd.array(out_grad_np, ctx)
input_tvm = tvm.nd.array(input_np, ctx)
weight_grad_tvm = tvm.nd.array(np.zeros(shape=fshape, dtype=dtype), ctx)
# launch the kernel
timer = f.time_evaluator(f.entry_name, ctx, number=1)
tcost = timer(input_tvm, out_grad_tvm, weight_grad_tvm).mean
np.testing.assert_allclose(weight_grad_np, weight_grad_tvm.asnumpy(), rtol=1e-4)
check_device("opencl")
check_device("cuda")
check_device("metal")
def test_topi_depthwise_conv2d_backward_weight_nhwc():
verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 1, 1)
verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 3, 1, 1)
verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 5, 1, 2)
verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 5, 1, 2)
verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 2, 1)
verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 3, 2, 1)
verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 5, 2, 2)
verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 5, 2, 2)
verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 1, 0)
verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 3, 1, 0)
verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 5, 1, 0)
verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 5, 1, 0)
verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 2, 0)
verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 3, 2, 0)
verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 5, 2, 0)
verify_depthwise_conv2d_back_weight(15, 256, 56, 2, 5, 2, 0)
if __name__ == "__main__":
test_topi_depthwise_conv2d_backward_weight_nhwc()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment