# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""External function interface to MIOpen library."""
# pylint: disable-msg=C0103
import ctypes
import numpy as np
import tvm
import tvm._ffi

from tvm import te


def _get_np_int32_array_handle(arr):
    """Expose the data buffer of an int32 numpy array as an untyped pointer

    Parameters
    ----------
    arr: numpy.NDArray
        source numpy array; its dtype must be ``np.int32``

    Returns
    -------
    ptr:  ctypes.c_void_p
        pointer to the first element of the array's data
    """
    assert arr.dtype == np.int32
    # Obtain a typed int32 pointer first, then erase the element type so the
    # result can be handed over as a generic handle.
    int32_ptr_type = ctypes.POINTER(ctypes.c_int32)
    typed_ptr = arr.ctypes.data_as(int32_ptr_type)
    return ctypes.cast(typed_ptr, ctypes.c_void_p)


def conv2d_forward(x,
                   w,
                   stride_h=1,
                   stride_w=1,
                   pad_h=0,
                   pad_w=0,
                   dilation_h=1,
                   dilation_w=1,
                   conv_mode=0,
                   data_type=1,
                   group_count=1):
    """Create an extern op that compute 2D convolution with MIOpen

    The first four entries of ``x.shape`` and ``w.shape`` are read through
    their ``.value`` attribute, so both tensors need at least four
    dimensions whose extents expose a ``.value``.

    Parameters
    ----------
    x: Tensor
        input feature map
    w: Tensor
        convolution weight
    stride_h: int
        height stride
    stride_w: int
        width stride
    pad_h: int
        height pad
    pad_w: int
        width pad
    dilation_h: int
        height dilation
    dilation_w: int
        width dilation
    conv_mode: int
        0: miopenConvolution
        1: miopenTranspose
        2: miopenGroupConv
        Overridden to 2 whenever ``group_count`` is greater than 1.
    data_type: int
        0: miopenHalf (fp16)
        1: miopenFloat (fp32)
    group_count: int
        number of groups

    Returns
    -------
    y: Tensor
        The result tensor
    """
    assert (0 <= conv_mode <= 2), "0: miopenConvolution / 1: miopenTranspose / 2: miopenGroupConv"
    # NOTE(review): a grouped convolution replaces whatever mode the caller
    # asked for, including miopenTranspose (1) -- confirm this is intended.
    if group_count > 1:
        conv_mode = 2

    # Output-shape buffer, one int32 per input dimension. Its address is
    # handed to the setup function below; presumably that call fills it in
    # with the real output shape -- TODO confirm against the C++ side.
    oshape = np.zeros((len(x.shape)), dtype=np.int32)
    xshape = x.shape
    wshape = w.shape

    setup_func = tvm._ffi.get_global_func("tvm.contrib.miopen.conv2d.setup")
    # The setup call's return value is forwarded unchanged as the ``algo``
    # argument of the forward call.
    algo = setup_func(conv_mode,
                      data_type,
                      pad_h,
                      pad_w,
                      stride_h,
                      stride_w,
                      dilation_h,
                      dilation_w,
                      xshape[0].value,
                      xshape[1].value,
                      xshape[2].value,
                      xshape[3].value,
                      wshape[0].value,
                      wshape[1].value,
                      wshape[2].value,
                      wshape[3].value,
                      group_count,
                      _get_np_int32_array_handle(oshape))

    return te.extern(
        list(oshape), [x, w],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.miopen.conv2d.forward",
            conv_mode,
            data_type,
            pad_h,
            pad_w,
            stride_h,
            stride_w,
            dilation_h,
            dilation_w,
            algo,
            ins[0],
            ins[1],
            outs[0]), name="y")