Commit aede4820 by Tatsuya Nishiyama Committed by Tianqi Chen

Fix error during running on nvptx with cuda9 (#1162)

parent 1b14fd19
...@@ -104,6 +104,28 @@ def find_cuda_path(): ...@@ -104,6 +104,28 @@ def find_cuda_path():
raise RuntimeError("Cannot find cuda path") raise RuntimeError("Cannot find cuda path")
def get_cuda_version(cuda_path):
"""Utility function to get cuda version
Parameters
----------
cuda_path : str
Path to cuda root.
Returns
-------
version : float
The cuda version
"""
version_file_path = os.path.join(cuda_path, "version.txt")
try:
with open(version_file_path) as f:
version_str = f.readline().replace('\n', '').replace('\r', '')
return float(version_str.split(" ")[2][:2])
except:
raise RuntimeError("Cannot read cuda version file")
@register_func("tvm_callback_libdevice_path") @register_func("tvm_callback_libdevice_path")
def find_libdevice_path(arch): def find_libdevice_path(arch):
"""Utility function to find libdevice """Utility function to find libdevice
...@@ -112,22 +134,31 @@ def find_libdevice_path(arch): ...@@ -112,22 +134,31 @@ def find_libdevice_path(arch):
---------- ----------
arch : int arch : int
The compute architecture in int The compute architecture in int
Returns
-------
path : str
Path to libdevice.
""" """
cuda_path = find_cuda_path() cuda_path = find_cuda_path()
lib_path = os.path.join(cuda_path, "nvvm/libdevice") lib_path = os.path.join(cuda_path, "nvvm/libdevice")
selected_ver = 0 selected_ver = 0
selected_path = None selected_path = None
cuda_ver = get_cuda_version(cuda_path)
for fn in os.listdir(lib_path): if cuda_ver == 9.0 or cuda_ver == 9.1:
if not fn.startswith("libdevice"): path = os.path.join(lib_path, "libdevice.10.bc")
continue else:
ver = int(fn.split(".")[-3].split("_")[-1]) for fn in os.listdir(lib_path):
if ver > selected_ver and ver <= arch: if not fn.startswith("libdevice"):
selected_ver = ver continue
selected_path = fn ver = int(fn.split(".")[-3].split("_")[-1])
if selected_path is None: if ver > selected_ver and ver <= arch:
raise RuntimeError("Cannot find libdevice for arch {}".format(arch)) selected_ver = ver
return os.path.join(lib_path, selected_path) selected_path = fn
if selected_path is None:
raise RuntimeError("Cannot find libdevice for arch {}".format(arch))
path = os.path.join(lib_path, selected_path)
return path
def callback_libdevice_path(arch): def callback_libdevice_path(arch):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment