Commit c468558e by Tianqi Chen Committed by GitHub

[CUDA] auto detect compatibility when arch is not passed (#490)

parent c6a20452
......@@ -18,7 +18,8 @@ namespace runtime {
enum DeviceAttrKind : int {
kExist = 0,
kMaxThreadsPerBlock = 1,
kWarpSize = 2
kWarpSize = 2,
kComputeVersion = 3
};
/*! \brief Number of bytes each allocation must align to */
......
......@@ -131,6 +131,20 @@ class TVMContext(ctypes.Structure):
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 2)
@property
def compute_version(self):
"""Get compute verison number in string.
Currently used to get compute capability of CUDA device.
Returns
-------
version : str
The version string in `major.minor` format.
"""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 3)
def sync(self):
"""Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
......
......@@ -39,11 +39,8 @@ def create_shared(output,
if options:
cmd += options
args = ' '.join(cmd)
proc = subprocess.Popen(
args, shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
......
# pylint: disable=invalid-name
"""Utility to invoke nvcc compiler in the system"""
from __future__ import absolute_import as _abs
import sys
import subprocess
from . import util
from .. import ndarray as nd
def compile_cuda(code, target="ptx", arch=None,
options=None, path_target=None):
def compile_cuda(code,
target="ptx",
arch=None,
options=None,
path_target=None):
"""Compile cuda code with NVCC from env.
Parameters
......@@ -39,32 +43,32 @@ def compile_cuda(code, target="ptx", arch=None,
with open(temp_code, "w") as out_file:
out_file.write(code)
if target == "cubin" and arch is None:
raise ValueError("arch(sm_xy) must be passed for generating cubin")
if arch is None:
if nd.gpu(0).exist:
# auto detect the compute arch argument
arch = "sm_" + "".join(nd.gpu(0).compute_version.split('.'))
else:
raise ValueError("arch(sm_xy) is not passed, and we cannot detect it from env")
file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += ["--%s" % target, "-O3"]
if arch:
cmd += ["-arch", arch]
cmd += ["-o", file_target]
if options:
cmd += options
cmd += [temp_code]
args = ' '.join(cmd)
proc = subprocess.Popen(
args, shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
sys.stderr.write("Compilation error:\n")
sys.stderr.write(str(out))
sys.stderr.flush()
cubin = None
else:
cubin = bytearray(open(file_target, "rb").read())
return cubin
msg = "Compilation error:\n"
msg += out
raise RuntimeError(msg)
return bytearray(open(file_target, "rb").read())
......@@ -40,6 +40,17 @@ class CUDADeviceAPI final : public DeviceAPI {
&value, cudaDevAttrWarpSize, ctx.device_id));
break;
}
case kComputeVersion: {
std::ostringstream os;
CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrComputeCapabilityMajor, ctx.device_id));
os << value << ".";
CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrComputeCapabilityMinor, ctx.device_id));
os << value;
*rv = os.str();
return;
}
}
*rv = value;
}
......
......@@ -39,6 +39,7 @@ void MetalWorkspace::GetAttr(
*rv = 1;
break;
}
case kComputeVersion: return;
case kExist: break;
}
}
......
......@@ -45,6 +45,7 @@ void OpenCLWorkspace::GetAttr(
*rv = 1;
break;
}
case kComputeVersion: return;
case kExist: break;
}
}
......
......@@ -44,6 +44,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
value = 64;
break;
}
case kComputeVersion: return;
}
*rv = value;
}
......
......@@ -143,7 +143,7 @@ def split(ary, indices_or_sections, axis=0):
begin_ids = [seg_size * i for i in range(indices_or_sections)]
elif isinstance(indices_or_sections, (tuple, list)):
assert tuple(indices_or_sections) == tuple(sorted(indices_or_sections)),\
"Should be sorted, recieved %s" %str(indices_or_sections)
"Should be sorted, recieved %s" % str(indices_or_sections)
begin_ids = [0] + list(indices_or_sections)
else:
raise NotImplementedError
......
......@@ -12,7 +12,7 @@ USE_MANUAL_CODE = False
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
ptx = nvcc.compile_cuda(code, target="ptx")
return ptx
......
......@@ -13,7 +13,7 @@ USE_MANUAL_CODE = False
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_37"]) # 37 for k80(ec2 instance)
ptx = nvcc.compile_cuda(code, target="ptx")
return ptx
def write_code(code, fname):
......
......@@ -12,7 +12,7 @@ USE_MANUAL_CODE = False
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_37"])
ptx = nvcc.compile_cuda(code, target="ptx")
return ptx
def write_code(code, fname):
......
......@@ -9,7 +9,7 @@ USE_MANUAL_CODE = False
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
ptx = nvcc.compile_cuda(code, target="ptx")
return ptx
def write_code(code, fname):
......
......@@ -12,7 +12,7 @@ USE_MANUAL_CODE = False
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
ptx = nvcc.compile_cuda(code, target="ptx")
return ptx
......
......@@ -17,7 +17,7 @@ UNROLL_WLOAD = True
@tvm.register_func
def tvm_callback_cuda_compile(code):
"""Use nvcc compiler for better perf."""
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
ptx = nvcc.compile_cuda(code, target="ptx")
return ptx
def write_code(code, fname):
......
......@@ -24,7 +24,7 @@ SKIP_CHECK = False
@tvm.register_func
def tvm_callback_cuda_compile(code):
"""Use nvcc compiler for better perf."""
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
ptx = nvcc.compile_cuda(code, target="ptx")
return ptx
def write_code(code, fname):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment