rocm.py 4.22 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18
"""Utility for ROCm backend"""
import subprocess
19
from os.path import join, exists
20
from . import util
21
from .._ffi.base import py_str
22
from .. import codegen
23
from ..api import register_func, convert
24

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
def find_lld(required=True):
    """Find ld.lld in system.

    Parameters
    ----------
    required : bool
        Whether it is required,
        runtime error will be raised if the linker is required
        and none of the candidates can be located.

    Returns
    -------
    valid_list : list of str
        List of possible paths, in the order the candidates were tried.

    Raises
    ------
    RuntimeError
        If ``required`` is true and no candidate was found.

    Note
    ----
    This function will first search ld.lld that
    matches the major llvm version that built with tvm
    """
    lld_list = []
    if hasattr(codegen, "llvm_version_major"):
        major = codegen.llvm_version_major()
        lld_list += ["ld.lld-%d.0" % major]
        lld_list += ["ld.lld-%d" % major]
    # Try the unversioned ELF linker before the bare "lld" name: per the LLD
    # documentation "lld" is only a generic driver and the ELF flavor has to
    # be invoked as "ld.lld". "lld" stays as the last candidate so setups
    # that resolved through it before still get a result.
    lld_list += ["ld.lld", "lld"]
    valid_list = [util.which(x) for x in lld_list]
    # Drop candidates for which the lookup returned a falsy value.
    valid_list = [x for x in valid_list if x]
    if not valid_list and required:
        raise RuntimeError(
            "cannot find ld.lld, candidates are: " + str(lld_list))
    return valid_list


def rocm_link(in_file, out_file, lld=None):
    """Turn a relocatable ELF object into a shared ELF object with lld.

    Parameters
    ----------
    in_file : str
        Path of the relocatable ELF object file to read.

    out_file : str
        Path of the shared ELF object file to produce.

    lld : str, optional
        Linker executable to run. When omitted, the first entry
        returned by ``find_lld()`` is used.

    Raises
    ------
    RuntimeError
        If the linker process exits with a non-zero status; the message
        carries the captured linker output.
    """
    linker = find_lld()[0] if lld is None else lld
    command = [linker, "-shared", in_file, "-o", out_file]

    # stderr is folded into stdout so one stream holds every diagnostic.
    process = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    output = process.communicate()[0]

    if process.returncode == 0:
        return
    raise RuntimeError("Linking error using ld.lld:\n" + py_str(output))

85

86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
@register_func("tvm_callback_rocm_link")
def callback_rocm_link(obj_bin):
    """Links object file generated from LLVM to HSA Code Object

    The object bytes are written to a file inside a temporary directory,
    linked with ``rocm_link`` and the resulting file is read back.

    Parameters
    ----------
    obj_bin : bytearray
        The object file

    Return
    ------
    cobj_bin : bytearray
        The HSA Code Object
    """
    tmp_dir = util.tempdir()
    tmp_obj = tmp_dir.relpath("rocm_kernel.o")
    tmp_cobj = tmp_dir.relpath("rocm_kernel.co")
    with open(tmp_obj, "wb") as out_file:
        out_file.write(bytes(obj_bin))
    rocm_link(tmp_obj, tmp_cobj)
    # Read through a context manager so the file handle is closed right
    # away instead of whenever the file object happens to be collected.
    with open(tmp_cobj, "rb") as in_file:
        cobj_bin = bytearray(in_file.read())
    return cobj_bin
108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135

@register_func("tvm_callback_rocm_bitcode_path")
def callback_rocm_bitcode_path(rocdl_dir="/opt/rocm/lib/"):
    """Collect the ROCm device library bitcode files present on disk.

    Parameters
    ----------
    rocdl_dir : str
        Directory that is searched for the bitcode files.
        Defaults to "/opt/rocm/lib/".

    Returns
    -------
    The result of ``convert`` applied to the list of existing bitcode
    paths, kept in the order of the file names below.
    """
    # The sequence below is deliberate: link order appears to matter.
    bitcode_names = [
        "oclc_daz_opt_on.amdgcn.bc",
        "ocml.amdgcn.bc",
        "hc.amdgcn.bc",
        "irif.amdgcn.bc",
        "ockl.amdgcn.bc",
        "oclc_correctly_rounded_sqrt_off.amdgcn.bc",
        "oclc_correctly_rounded_sqrt_on.amdgcn.bc",
        "oclc_daz_opt_off.amdgcn.bc",
        "oclc_finite_only_off.amdgcn.bc",
        "oclc_finite_only_on.amdgcn.bc",
        "oclc_isa_version_803.amdgcn.bc",
        "oclc_isa_version_900.amdgcn.bc",
        "oclc_unsafe_math_off.amdgcn.bc",
        "oclc_unsafe_math_on.amdgcn.bc",
    ]
    candidates = (join(rocdl_dir, name) for name in bitcode_names)
    # Only files that actually exist are handed on.
    present = list(filter(exists, candidates))
    return convert(present)