nvcc.py 7.37 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
# pylint: disable=invalid-name
18
"""Utility to invoke nvcc compiler in the system"""
19
from __future__ import absolute_import as _abs
20

21
import subprocess
22 23
import os
import warnings
24
from . import util
25
from .. import ndarray as nd
26 27
from ..api import register_func
from .._ffi.base import py_str
28

def compile_cuda(code,
                 target="ptx",
                 arch=None,
                 options=None,
                 path_target=None):
    """Compile cuda code with NVCC from env.

    Parameters
    ----------
    code : str
        The cuda code.

    target : str
        The target format: one of "cubin", "ptx" or "fatbin".

    arch : str, optional
        The architecture passed to ``nvcc -arch`` (e.g. "sm_61"). When None it
        is derived from the compute version of ``nd.gpu(0)``.

    options : str or list/tuple of str, optional
        Additional options appended to the nvcc command line. A single str is
        passed to nvcc as one argument.

    path_target : str, optional
        Output file. When not given, the output is written into a temporary
        directory.

    Returns
    -------
    data : bytearray
        The contents of the compiled output file.

    Raises
    ------
    ValueError
        If ``target`` or ``options`` is invalid, or ``arch`` is None and no GPU
        is available to detect it from.
    RuntimeError
        If nvcc fails (the message contains nvcc's output) or produces an
        empty file.
    """
    # Validate every argument before touching the filesystem.
    if target not in ["cubin", "ptx", "fatbin"]:
        raise ValueError("target must be in cubin, ptx, fatbin")

    if not options:
        extra_options = []
    elif isinstance(options, str):
        extra_options = [options]
    elif isinstance(options, (list, tuple)):
        extra_options = list(options)
    else:
        raise ValueError("options must be str or list of str")

    if arch is None:
        device = nd.gpu(0)
        if not device.exist:
            raise ValueError("arch(sm_xy) is not passed, and we cannot detect it from env")
        # auto detect the compute arch argument, e.g. compute version "6.1" -> "sm_61"
        arch = "sm_" + "".join(device.compute_version.split('.'))

    # `temp` must stay referenced until the output has been read back below.
    temp = util.tempdir()
    temp_code = temp.relpath("my_kernel.cu")
    temp_target = temp.relpath("my_kernel.%s" % target)
    with open(temp_code, "w") as out_file:
        out_file.write(code)

    file_target = path_target if path_target else temp_target

    cmd = ["nvcc", "--%s" % target, "-O3", "-arch", arch]
    cmd += extra_options
    cmd += ["-o", file_target, temp_code]

    # stderr is merged into stdout so a failure message carries all diagnostics.
    proc = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    (out, _) = proc.communicate()

    if proc.returncode != 0:
        msg = "Compilation error:\n"
        msg += py_str(out)
        raise RuntimeError(msg)

    # Use a context manager so the result file handle is closed promptly.
    with open(file_target, "rb") as result_file:
        data = bytearray(result_file.read())
    if not data:
        raise RuntimeError(
            "Compilation error: empty result is generated")
    return data


def find_cuda_path():
    """Utility function to find cuda path

    Lookup order:

    1. the ``CUDA_PATH`` environment variable, returned as-is;
    2. two directory levels above the ``nvcc`` reported by ``which nvcc``
       (resolved with ``os.path.realpath``);
    3. ``/usr/local/cuda``, if ``bin/nvcc`` exists underneath it.

    Returns
    -------
    path : str
        Path to cuda root.

    Raises
    ------
    RuntimeError
        If none of the above locates a CUDA installation.
    """
    if "CUDA_PATH" in os.environ:
        return os.environ["CUDA_PATH"]

    proc = None
    try:
        proc = subprocess.Popen(
            ["which", "nvcc"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        (out, _) = proc.communicate()
    except OSError:
        # The `which` executable itself could not be launched; fall through to
        # the default location instead of crashing.
        proc = None

    if proc is not None and proc.returncode == 0:
        nvcc_path = py_str(out).strip()
        # <root>/bin/nvcc -> <root>
        return os.path.realpath(os.path.join(nvcc_path, "../.."))

    cuda_path = "/usr/local/cuda"
    if os.path.exists(os.path.join(cuda_path, "bin/nvcc")):
        return cuda_path
    raise RuntimeError("Cannot find cuda path")


def get_cuda_version(cuda_path):
    """Utility function to get cuda version

    Parameters
    ----------
    cuda_path : str
        Path to cuda root.

    Returns
    -------
    version : float
        The cuda version, parsed from the first line of
        ``<cuda_path>/version.txt`` (e.g. "CUDA Version 10.0.130"). Only the
        first two characters of the version token are used, so for one- or
        two-digit majors this is effectively the major version:
        "10.0.130" -> 10.0, "9.2.148" -> 9.0.

    Raises
    ------
    RuntimeError
        If the version file cannot be opened, read or parsed.
    """
    version_file_path = os.path.join(cuda_path, "version.txt")
    try:
        with open(version_file_path) as f:
            version_str = f.readline().replace('\n', '').replace('\r', '')
        # NOTE: "[:2]" truncates the version token ("9.1.85" -> "9." -> 9.0).
        # find_libdevice_path compares against these truncated values.
        return float(version_str.split(" ")[2][:2])
    except (IOError, OSError, IndexError, ValueError):
        # Only I/O and parse failures are translated; a bare `except` would
        # also swallow KeyboardInterrupt, SystemExit and programming errors.
        raise RuntimeError("Cannot read cuda version file")


def find_libdevice_path(arch):
    """Utility function to find libdevice

    Parameters
    ----------
    arch : int
        The compute architecture in int. It is compared against the number
        after "compute_" in per-arch file names such as
        ``libdevice.compute_35.10.bc``.

    Returns
    -------
    path : str
        Path to libdevice.

    Raises
    ------
    RuntimeError
        If the cuda root or its version cannot be determined, the libdevice
        directory cannot be listed, or no suitable libdevice file exists.
    """
    cuda_path = find_cuda_path()
    lib_path = os.path.join(cuda_path, "nvvm/libdevice")
    generic_name = "libdevice.10.bc"

    cuda_ver = get_cuda_version(cuda_path)
    if cuda_ver in (9.0, 9.1, 10.0):
        return os.path.join(lib_path, generic_name)

    try:
        file_names = os.listdir(lib_path)
    except OSError:
        raise RuntimeError("Cannot list libdevice directory {}".format(lib_path))

    selected_ver = 0
    selected_path = None
    for fn in file_names:
        if not fn.startswith("libdevice"):
            continue
        try:
            # expected pattern: libdevice.compute_XX.10.bc
            ver = int(fn.split(".")[-3].split("_")[-1])
        except (IndexError, ValueError):
            # Not a per-arch file (e.g. the generic "libdevice.10.bc").
            continue
        if selected_ver < ver <= arch:
            selected_ver = ver
            selected_path = fn

    if selected_path is None:
        # No usable per-arch file: fall back to the generic libdevice if the
        # installation ships one, instead of failing outright.
        if os.path.isfile(os.path.join(lib_path, generic_name)):
            selected_path = generic_name
        else:
            raise RuntimeError("Cannot find libdevice for arch {}".format(arch))
    return os.path.join(lib_path, selected_path)


@register_func("tvm_callback_libdevice_path")
def callback_libdevice_path(arch):
    """Registered hook ``tvm_callback_libdevice_path``.

    Wraps :func:`find_libdevice_path`. On RuntimeError it emits a warning and
    returns "" instead of raising.
    NOTE(review): this assumes the native caller treats "" as "no libdevice";
    confirm against the codegen that invokes this hook.

    Parameters
    ----------
    arch : int
        The compute architecture in int.

    Returns
    -------
    path : str
        Path to libdevice, or "" if it could not be found.
    """
    try:
        return find_libdevice_path(arch)
    except RuntimeError:
        warnings.warn("Cannot find libdevice path")
        return ""


def parse_compute_version(compute_version):
    """Split a compute capability string into its major and minor numbers.

    Only the first two dot-separated fields are used, so "7.5.1" -> (7, 5).

    Parameters
    ----------
    compute_version : str
        compute capability of a GPU (e.g. "6.0")

    Returns
    -------
    major : int
        major version number
    minor : int
        minor version number

    Raises
    ------
    RuntimeError
        If either field is missing or is not an integer.
    """
    fields = compute_version.split('.')
    try:
        # evaluated left to right: the major field is checked first
        major, minor = int(fields[0]), int(fields[1])
    except (IndexError, ValueError) as err:
        raise RuntimeError("Compute version parsing error: " + str(err))
    return major, minor


def have_fp16(compute_version):
    """Either fp16 support is provided in the compute capability or not

    Parameters
    ----------
    compute_version: str
        compute capability of a GPU (e.g. "6.0")

    Returns
    -------
    supported : bool
        True for 5.3, 6.x except 6.1, and every compute capability >= 7.0.
    """
    major, minor = parse_compute_version(compute_version)
    # fp 16 support in reference to:
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#arithmetic-instructions
    if (major, minor) == (5, 3):
        return True
    # NOTE: exclude compute capability 6.1 devices although it is actually available
    #       to compute fp16, because these devices only have low-rate fp16 performance.
    if major == 6:
        return minor != 1
    # Every architecture from 7.0 on has high-rate fp16 in the table above;
    # checking only `major == 7` wrongly rejected 8.x and newer GPUs.
    return major >= 7


def have_int8(compute_version):
    """Either int8 support is provided in the compute capability or not

    Parameters
    ----------
    compute_version : str
        compute capability of a GPU (e.g. "6.1")

    Returns
    -------
    supported : bool
        True for every compute capability >= 6.1.
    """
    major, minor = parse_compute_version(compute_version)
    # The int8 dot-product intrinsics (__dp4a / __dp2a) require compute
    # capability 6.1 or higher; accepting only exactly 6.1 wrongly rejected
    # 6.2, 7.x and newer devices.
    return (major, minor) >= (6, 1)


def have_tensorcore(compute_version):
    """Either TensorCore support is provided in the compute capability or not

    Parameters
    ----------
    compute_version : str
        compute capability of a GPU (e.g. "7.0")

    Returns
    -------
    supported : bool
        True for every compute capability >= 7.0.
    """
    major, _ = parse_compute_version(compute_version)
    # Warp matrix (WMMA / Tensor Core) functions are available from compute
    # capability 7.0 onwards; checking only `major == 7` rejected 8.x and newer.
    return major >= 7