Commit f8f4ceb2 by Yizhi Liu Committed by Tianqi Chen

[nvcc] enable multiple arch in one fatbin (#4377)

parent 500ff051
......@@ -582,7 +582,13 @@ def check_remote(target, device_key, host=None, port=None, priority=100, timeout
@register_func
def tvm_callback_cuda_compile(code):
"""use nvcc to generate ptx code for better optimization"""
ptx = nvcc.compile_cuda(code, target="ptx", arch=AutotvmGlobalScope.current.cuda_target_arch)
curr_cuda_target_arch = AutotvmGlobalScope.current.cuda_target_arch
# e.g., target arch could be [
# "-gencode", "arch=compute_52,code=sm_52",
# "-gencode", "arch=compute_70,code=sm_70"
# ]
target = "fatbin" if isinstance(curr_cuda_target_arch, list) else "ptx"
ptx = nvcc.compile_cuda(code, target=target, arch=AutotvmGlobalScope.current.cuda_target_arch)
return ptx
......@@ -591,8 +597,10 @@ def set_cuda_target_arch(arch):
Parameters
----------
arch: str
arch: str or list
The argument of nvcc -arch. (e.g. "sm_51", "sm_62")
it can also be a count of gencode arguments pass to nvcc command line,
e.g., ["-gencode", "arch=compute_52,code=sm_52", "-gencode", "arch=compute_70,code=sm_70"]
"""
AutotvmGlobalScope.current.cuda_target_arch = arch
......
......@@ -74,7 +74,10 @@ def compile_cuda(code,
file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += ["--%s" % target, "-O3"]
cmd += ["-arch", arch]
if isinstance(arch, list):
cmd += arch
else:
cmd += ["-arch", arch]
if options:
if isinstance(options, str):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment