# test_topi_depthwise_conv2d_back_input.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20 21 22 23
import tvm
import topi
import numpy as np
from tvm.contrib.pickle_memoize import memoize
from scipy import signal
from topi.util import get_const_tuple
from topi.nn.util import get_pad_tuple
24
import topi.testing
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc


def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multiplier, filter_h, stride_h, padding_h):
    """Check the NHWC depthwise conv2d input-gradient kernel against a NumPy/SciPy reference.

    The spatial configuration is square and symmetric: the input width, filter
    width, horizontal stride and horizontal padding are taken to be equal to
    their vertical counterparts.

    Parameters
    ----------
    batch : int
        Batch size.
    in_channel : int
        Number of input channels (also the filter's channel dimension).
    in_h : int
        Input height; the input width is set to the same value.
    channel_multiplier : int
        Depthwise channel multiplier; out_channel = in_channel * channel_multiplier.
    filter_h : int
        Filter height; the filter width is set to the same value.
    stride_h : int
        Vertical stride; the horizontal stride is set to the same value.
    padding_h : int
        Vertical padding of the forward convolution; the horizontal padding is
        set to the same value.

    Raises
    ------
    AssertionError
        If the kernel output differs from the reference beyond rtol=1e-5 on
        any device that is available.
    """
    in_w = in_h
    filter_channel = in_channel
    filter_w = filter_h
    stride_w = stride_h
    padding_w = padding_h

    # Forward-convolution output size. Integer floor division replaces the
    # former np.int(...) truncation: the np.int alias was removed from NumPy,
    # and both forms agree for these non-negative operands.
    out_h = (in_h + 2 * padding_h - filter_h) // stride_h + 1
    out_w = (in_w + 2 * padding_w - filter_w) // stride_w + 1
    out_channel = in_channel * channel_multiplier

    ishape = [batch, in_h, in_w, in_channel]
    oshape = [batch, out_h, out_w, out_channel]

    # placeholder
    Out_grad = tvm.placeholder(oshape, name='Out_grad')
    Filter = tvm.placeholder((filter_h, filter_w, filter_channel, channel_multiplier))
    # declare
    In_grad = topi.nn.depthwise_conv2d_backward_input_nhwc(
        Filter, Out_grad, oshape, ishape,
        stride=[stride_h, stride_w], padding=[padding_h, padding_w])
    # schedule
    schedule = schedule_depthwise_conv2d_backward_input_nhwc(In_grad)

    def check_device(device):
        """Build and run the kernel on `device` and compare with the reference."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        # build the kernel
        f = tvm.build(schedule, [Filter, Out_grad, In_grad], device)
        # prepare pod type for test data closure
        dtype = Out_grad.dtype
        out_grad_shape = get_const_tuple(Out_grad.shape)
        filter_shape = get_const_tuple(Filter.shape)

        # use memoize to pickle the test data for next time use
        @memoize("topi.tests.test_topi_depthwise_conv2d_backward_input.nhwc")
        def get_ref_data():
            out_grad_np = np.random.uniform(size=out_grad_shape).astype(dtype)
            filter_np = np.random.uniform(size=filter_shape).astype(dtype)
            # Dilate the output gradient along H and W by the forward stride.
            dilated_out_grad_np = topi.testing.dilate_python(out_grad_np, [1, stride_h, stride_w, 1])
            dilated_h = dilated_out_grad_np.shape[1]
            dilated_w = dilated_out_grad_np.shape[2]
            # padding params in forward propagation
            fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
                [padding_h, padding_w], (filter_h, filter_w))
            # padding params in backward propagation; the bottom/right sides
            # get an extra (stride - 1) rows/columns and the convolution
            # result is cropped back to (in_h, in_w) below.
            bpad_top = filter_h - 1 - fpad_top
            bpad_bottom = (filter_h - 1 - fpad_bottom) + (stride_h - 1)
            bpad_left = filter_w - 1 - fpad_left
            bpad_right = (filter_w - 1 - fpad_right) + (stride_w - 1)

            padded_out_grad = np.zeros((batch,
                                        dilated_h + bpad_top + bpad_bottom,
                                        dilated_w + bpad_left + bpad_right,
                                        out_channel))
            padded_out_grad[:,
                            bpad_top:dilated_h + bpad_top,
                            bpad_left:dilated_w + bpad_left,
                            :] = dilated_out_grad_np

            # Each input channel accumulates the contributions of its
            # channel_multiplier output channels.
            in_grad_np = np.zeros((batch, in_h, in_w, in_channel))
            for b in range(batch):
                for c in range(in_channel):
                    for m in range(channel_multiplier):
                        in_grad_np[b, :, :, c] += signal.convolve2d(
                            padded_out_grad[b, :, :, c * channel_multiplier + m],
                            filter_np[:, :, c, m], mode='valid')[0:in_h, 0:in_w]
            return (out_grad_np, filter_np, in_grad_np)

        (out_grad_np, filter_np, in_grad_np) = get_ref_data()

        out_grad_tvm = tvm.nd.array(out_grad_np, ctx)
        filter_tvm = tvm.nd.array(filter_np, ctx)
        in_grad_tvm = tvm.nd.array(np.zeros(shape=ishape, dtype=dtype), ctx)
        # launch the kernel through the time evaluator (a single run); the
        # measured cost itself is not used by this test.
        timer = f.time_evaluator(f.entry_name, ctx, number=1)
        timer(filter_tvm, out_grad_tvm, in_grad_tvm)
        tvm.testing.assert_allclose(in_grad_np, in_grad_tvm.asnumpy(), rtol=1e-5)

    for device in ("opencl", "cuda", "metal", "rocm", "vulkan", "nvptx"):
        check_device(device)
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129

def test_topi_depthwise_conv2d_backward_input_nhwc():
    """Run the backward-input check over a grid of filter/stride/padding setups.

    Every case uses batch 16, 256 input channels and a 56x56 input. The grid
    covers filter sizes 3 and 5, strides 1 and 2, channel multipliers 1 and 2,
    first with padding of filter_size // 2 and then with no padding.
    """
    batch, in_channel, in_size = 16, 256, 56
    for padded in (True, False):
        for stride in (1, 2):
            for kernel in (3, 5):
                # 3 -> 1 and 5 -> 2 in the padded group, 0 otherwise.
                pad = kernel // 2 if padded else 0
                for multiplier in (1, 2):
                    verify_depthwise_conv2d_back_input(
                        batch, in_channel, in_size, multiplier, kernel, stride, pad)


# Allow running this test file directly as a script, without a test runner.
if __name__ == "__main__":
    test_topi_depthwise_conv2d_backward_input_nhwc()