# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""External function interface to MIOpen library."""
# pylint: disable-msg=C0103
import ctypes
import numpy as np
from .. import api as _api
from .. import intrin as _intrin
from .. import get_global_func as _get_global_func


def _get_np_int32_array_handle(arr):
    """Wrap the data buffer of an int32 numpy array in a ``void*`` handle.

    Parameters
    ----------
    arr: numpy.ndarray
        Array whose dtype must be ``numpy.int32``.

    Returns
    -------
    handle: ctypes.c_void_p
        Untyped pointer to the first element of ``arr``.
    """
    # Refuse any other dtype before its memory is exposed as int32.
    assert arr.dtype == np.int32
    int32_pointer_type = ctypes.POINTER(ctypes.c_int32)
    typed_pointer = arr.ctypes.data_as(int32_pointer_type)
    handle = ctypes.cast(typed_pointer, ctypes.c_void_p)
    return handle


def conv2d_forward(x,
                   w,
                   stride_h=1,
                   stride_w=1,
                   pad_h=0,
                   pad_w=0,
                   dilation_h=1,
                   dilation_w=1,
                   conv_mode=0,
                   data_type=1,
                   group_count=1):
    """Create an extern op that compute 2D convolution with MIOpen

    Parameters
    ----------
    x: Tensor
        input feature map; its first four shape entries are passed to the
        MIOpen setup function
    w: Tensor
        convolution weight; its first four shape entries are passed to the
        MIOpen setup function
    stride_h: int
        height stride
    stride_w: int
        width stride
    pad_h: int
        height pad
    pad_w: int
        width pad
    dilation_h: int
        height dilation
    dilation_w: int
        width dilation
    conv_mode: int
        0: miopenConvolution
        1: miopenTranspose
        2: miopenGroupConv (always used when group_count > 1, whatever
           value was passed in)
    data_type: int
        0: miopenHalf (fp16)
        1: miopenFloat (fp32)
    group_count: int
        number of groups

    Returns
    -------
    y: Tensor
        The result tensor
    """
    assert (0 <= conv_mode <= 2), "0: miopenConvolution / 1: miopenTranspose / 2: miopenGroupConv"
    # A grouped convolution overrides the requested mode.
    if group_count > 1:
        conv_mode = 2
    # Zero-initialised buffer with one entry per input dimension. Its handle
    # goes to the setup function, which is expected to write the output
    # shape into it (the buffer is used as the extern op's shape below).
    oshape = np.zeros((len(x.shape)), dtype=np.int32)
    xshape = x.shape
    wshape = w.shape
    setup_func = _get_global_func("tvm.contrib.miopen.conv2d.setup")
    # The value returned by setup is forwarded unchanged to the forward call.
    algo = setup_func(conv_mode,
                      data_type,
                      pad_h,
                      pad_w,
                      stride_h,
                      stride_w,
                      dilation_h,
                      dilation_w,
                      xshape[0].value,
                      xshape[1].value,
                      xshape[2].value,
                      xshape[3].value,
                      wshape[0].value,
                      wshape[1].value,
                      wshape[2].value,
                      wshape[3].value,
                      group_count,
                      _get_np_int32_array_handle(oshape))

    return _api.extern(
        list(oshape), [x, w],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.miopen.conv2d.forward",
            conv_mode,
            data_type,
            pad_h,
            pad_w,
            stride_h,
            stride_w,
            dilation_h,
            dilation_w,
            algo,
            ins[0],
            ins[1],
            outs[0]), name="y")