Commit c468558e by Tianqi Chen Committed by GitHub

[CUDA] auto detect compatibility when arch is not passed (#490)

parent c6a20452
...@@ -18,7 +18,8 @@ namespace runtime { ...@@ -18,7 +18,8 @@ namespace runtime {
enum DeviceAttrKind : int { enum DeviceAttrKind : int {
kExist = 0, kExist = 0,
kMaxThreadsPerBlock = 1, kMaxThreadsPerBlock = 1,
kWarpSize = 2 kWarpSize = 2,
kComputeVersion = 3
}; };
/*! \brief Number of bytes each allocation must align to */ /*! \brief Number of bytes each allocation must align to */
......
...@@ -131,6 +131,20 @@ class TVMContext(ctypes.Structure): ...@@ -131,6 +131,20 @@ class TVMContext(ctypes.Structure):
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 2) self.device_type, self.device_id, 2)
@property
def compute_version(self):
"""Get compute verison number in string.
Currently used to get compute capability of CUDA device.
Returns
-------
version : str
The version string in `major.minor` format.
"""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 3)
def sync(self): def sync(self):
"""Synchronize until jobs finished at the context.""" """Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None)) check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
......
...@@ -39,11 +39,8 @@ def create_shared(output, ...@@ -39,11 +39,8 @@ def create_shared(output,
if options: if options:
cmd += options cmd += options
args = ' '.join(cmd)
proc = subprocess.Popen( proc = subprocess.Popen(
args, shell=True, cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate() (out, _) = proc.communicate()
if proc.returncode != 0: if proc.returncode != 0:
......
# pylint: disable=invalid-name # pylint: disable=invalid-name
"""Utility to invoke nvcc compiler in the system""" """Utility to invoke nvcc compiler in the system"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import sys
import subprocess import subprocess
from . import util from . import util
from .. import ndarray as nd
def compile_cuda(code, target="ptx", arch=None, def compile_cuda(code,
options=None, path_target=None): target="ptx",
arch=None,
options=None,
path_target=None):
"""Compile cuda code with NVCC from env. """Compile cuda code with NVCC from env.
Parameters Parameters
...@@ -39,32 +43,32 @@ def compile_cuda(code, target="ptx", arch=None, ...@@ -39,32 +43,32 @@ def compile_cuda(code, target="ptx", arch=None,
with open(temp_code, "w") as out_file: with open(temp_code, "w") as out_file:
out_file.write(code) out_file.write(code)
if target == "cubin" and arch is None:
raise ValueError("arch(sm_xy) must be passed for generating cubin") if arch is None:
if nd.gpu(0).exist:
# auto detect the compute arch argument
arch = "sm_" + "".join(nd.gpu(0).compute_version.split('.'))
else:
raise ValueError("arch(sm_xy) is not passed, and we cannot detect it from env")
file_target = path_target if path_target else temp_target file_target = path_target if path_target else temp_target
cmd = ["nvcc"] cmd = ["nvcc"]
cmd += ["--%s" % target, "-O3"] cmd += ["--%s" % target, "-O3"]
if arch: cmd += ["-arch", arch]
cmd += ["-arch", arch]
cmd += ["-o", file_target] cmd += ["-o", file_target]
if options: if options:
cmd += options cmd += options
cmd += [temp_code] cmd += [temp_code]
args = ' '.join(cmd)
proc = subprocess.Popen( proc = subprocess.Popen(
args, shell=True, cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate() (out, _) = proc.communicate()
if proc.returncode != 0: if proc.returncode != 0:
sys.stderr.write("Compilation error:\n") msg = "Compilation error:\n"
sys.stderr.write(str(out)) msg += out
sys.stderr.flush() raise RuntimeError(msg)
cubin = None
else: return bytearray(open(file_target, "rb").read())
cubin = bytearray(open(file_target, "rb").read())
return cubin
...@@ -40,6 +40,17 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -40,6 +40,17 @@ class CUDADeviceAPI final : public DeviceAPI {
&value, cudaDevAttrWarpSize, ctx.device_id)); &value, cudaDevAttrWarpSize, ctx.device_id));
break; break;
} }
case kComputeVersion: {
std::ostringstream os;
CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrComputeCapabilityMajor, ctx.device_id));
os << value << ".";
CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrComputeCapabilityMinor, ctx.device_id));
os << value;
*rv = os.str();
return;
}
} }
*rv = value; *rv = value;
} }
......
...@@ -39,6 +39,7 @@ void MetalWorkspace::GetAttr( ...@@ -39,6 +39,7 @@ void MetalWorkspace::GetAttr(
*rv = 1; *rv = 1;
break; break;
} }
case kComputeVersion: return;
case kExist: break; case kExist: break;
} }
} }
......
...@@ -45,6 +45,7 @@ void OpenCLWorkspace::GetAttr( ...@@ -45,6 +45,7 @@ void OpenCLWorkspace::GetAttr(
*rv = 1; *rv = 1;
break; break;
} }
case kComputeVersion: return;
case kExist: break; case kExist: break;
} }
} }
......
...@@ -44,6 +44,7 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -44,6 +44,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
value = 64; value = 64;
break; break;
} }
case kComputeVersion: return;
} }
*rv = value; *rv = value;
} }
......
...@@ -143,7 +143,7 @@ def split(ary, indices_or_sections, axis=0): ...@@ -143,7 +143,7 @@ def split(ary, indices_or_sections, axis=0):
begin_ids = [seg_size * i for i in range(indices_or_sections)] begin_ids = [seg_size * i for i in range(indices_or_sections)]
elif isinstance(indices_or_sections, (tuple, list)): elif isinstance(indices_or_sections, (tuple, list)):
assert tuple(indices_or_sections) == tuple(sorted(indices_or_sections)),\ assert tuple(indices_or_sections) == tuple(sorted(indices_or_sections)),\
"Should be sorted, recieved %s" %str(indices_or_sections) "Should be sorted, recieved %s" % str(indices_or_sections)
begin_ids = [0] + list(indices_or_sections) begin_ids = [0] + list(indices_or_sections)
else: else:
raise NotImplementedError raise NotImplementedError
......
...@@ -12,7 +12,7 @@ USE_MANUAL_CODE = False ...@@ -12,7 +12,7 @@ USE_MANUAL_CODE = False
@tvm.register_func @tvm.register_func
def tvm_callback_cuda_compile(code): def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"]) ptx = nvcc.compile_cuda(code, target="ptx")
return ptx return ptx
......
...@@ -13,7 +13,7 @@ USE_MANUAL_CODE = False ...@@ -13,7 +13,7 @@ USE_MANUAL_CODE = False
@tvm.register_func @tvm.register_func
def tvm_callback_cuda_compile(code): def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_37"]) # 37 for k80(ec2 instance) ptx = nvcc.compile_cuda(code, target="ptx")
return ptx return ptx
def write_code(code, fname): def write_code(code, fname):
......
...@@ -12,7 +12,7 @@ USE_MANUAL_CODE = False ...@@ -12,7 +12,7 @@ USE_MANUAL_CODE = False
@tvm.register_func @tvm.register_func
def tvm_callback_cuda_compile(code): def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_37"]) ptx = nvcc.compile_cuda(code, target="ptx")
return ptx return ptx
def write_code(code, fname): def write_code(code, fname):
......
...@@ -9,7 +9,7 @@ USE_MANUAL_CODE = False ...@@ -9,7 +9,7 @@ USE_MANUAL_CODE = False
@tvm.register_func @tvm.register_func
def tvm_callback_cuda_compile(code): def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"]) ptx = nvcc.compile_cuda(code, target="ptx")
return ptx return ptx
def write_code(code, fname): def write_code(code, fname):
......
...@@ -12,7 +12,7 @@ USE_MANUAL_CODE = False ...@@ -12,7 +12,7 @@ USE_MANUAL_CODE = False
@tvm.register_func @tvm.register_func
def tvm_callback_cuda_compile(code): def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"]) ptx = nvcc.compile_cuda(code, target="ptx")
return ptx return ptx
......
...@@ -17,7 +17,7 @@ UNROLL_WLOAD = True ...@@ -17,7 +17,7 @@ UNROLL_WLOAD = True
@tvm.register_func @tvm.register_func
def tvm_callback_cuda_compile(code): def tvm_callback_cuda_compile(code):
"""Use nvcc compiler for better perf.""" """Use nvcc compiler for better perf."""
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"]) ptx = nvcc.compile_cuda(code, target="ptx")
return ptx return ptx
def write_code(code, fname): def write_code(code, fname):
......
...@@ -24,7 +24,7 @@ SKIP_CHECK = False ...@@ -24,7 +24,7 @@ SKIP_CHECK = False
@tvm.register_func @tvm.register_func
def tvm_callback_cuda_compile(code): def tvm_callback_cuda_compile(code):
"""Use nvcc compiler for better perf.""" """Use nvcc compiler for better perf."""
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"]) ptx = nvcc.compile_cuda(code, target="ptx")
return ptx return ptx
def write_code(code, fname): def write_code(code, fname):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment