Unverified Commit fb7fa8e4 by Tianqi Chen Committed by GitHub

[AUTOTVM] Refactor measure build func (#2927)

parent c6339730
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
from ... import ir_pass, build, build_config, nd, TVMError, register_func, \ from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
rpc as _rpc, target as _target rpc as _rpc, target as _target
from ...contrib import nvcc, ndk from ...contrib import nvcc, ndk, tar
from ..util import get_const_tuple from ..util import get_const_tuple
from ..env import AutotvmGlobalScope from ..env import AutotvmGlobalScope
...@@ -58,20 +58,20 @@ class LocalBuilder(Builder): ...@@ -58,20 +58,20 @@ class LocalBuilder(Builder):
build_func: callable or str build_func: callable or str
If is 'default', use default build function If is 'default', use default build function
If is 'ndk', use function for android ndk If is 'ndk', use function for android ndk
If is callable, use it as custom build function If is callable, use it as custom build function, expect lib_format field.
""" """
def __init__(self, timeout=10, n_parallel=None, build_func='default'): def __init__(self, timeout=10, n_parallel=None, build_func='default'):
super(LocalBuilder, self).__init__(timeout, n_parallel) super(LocalBuilder, self).__init__(timeout, n_parallel)
if isinstance(build_func, str): if isinstance(build_func, str):
if build_func == 'default': if build_func == 'default':
build_func = default_build_func build_func = tar.tar
elif build_func == 'ndk': elif build_func == 'ndk':
build_func = android_ndk_build_func build_func = ndk.create_shared
else: else:
raise ValueError("Invalid build_func" + build_func) raise ValueError("Invalid build_func" + build_func)
self.build_func = build_func self.build_func = _wrap_build_func(build_func)
self.executor = LocalExecutor(timeout=timeout) self.executor = LocalExecutor(timeout=timeout)
self.tmp_dir = tempfile.mkdtemp() self.tmp_dir = tempfile.mkdtemp()
...@@ -349,46 +349,47 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti ...@@ -349,46 +349,47 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args) return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
def default_build_func(measure_input, tmp_dir, **kwargs): def _wrap_build_func(build_func):
""" """
Default build func. This can work for cuda, opencl, llvm backend Wrap build_func to a function that can be used in measure.
Parameters Parameters
---------- ----------
measure_input: MeasureInput build_func : The compilation function
The input of measurement We expect fcompile to contain an attr "output_format"
tmp_dir: str
The path of temporary directory to export generated library
"""
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
def android_ndk_build_func(measure_input, tmp_dir, **kwargs):
"""
Build function for android device using ndk.
Parameters Returns
---------- -------
measure_input: MeasureInput wrapped_build_func : function
The input of measurement The wrapped build function
tmp_dir: str
The path of temporary directory to export generated library
""" """
tic = time.time() if not hasattr(build_func, "output_format"):
try: raise AttributeError("Expect build_func to have the attribute output_format.")
filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64)) output_format = build_func.output_format
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename, ndk.create_shared) def _wrapped(measure_input, tmp_dir, **kwargs):
except Exception as e: # pylint: disable=broad-except """
return BuildResult(None, None, e, time.time() - tic) Wrapped build func.
return BuildResult(filename, arg_info, None, time.time() - tic)
Parameters
----------
measure_input: MeasureInput
The input of measurement
tmp_dir: str
The path of temporary directory to export generated library
"""
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
getrandbits(64), output_format))
# TODO(tvm-team) consider linline _build_func_common
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename, build_func)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
return _wrapped
def run_through_rpc(measure_input, build_result, def run_through_rpc(measure_input, build_result,
......
...@@ -29,7 +29,7 @@ def create_shared(output, ...@@ -29,7 +29,7 @@ def create_shared(output,
cc : str, optional cc : str, optional
The compile string. The compile string.
""" """
if sys.platform == "darwin" or sys.platform.startswith('linux'): if sys.platform == "darwin" or sys.platform.startswith("linux"):
_linux_shared(output, objects, options, cc) _linux_shared(output, objects, options, cc)
elif sys.platform == "win32": elif sys.platform == "win32":
_windows_shared(output, objects, options) _windows_shared(output, objects, options)
...@@ -37,6 +37,38 @@ def create_shared(output, ...@@ -37,6 +37,38 @@ def create_shared(output,
raise ValueError("Unsupported platform") raise ValueError("Unsupported platform")
# assign so as default output format
create_shared.output_format = "so" if sys.platform != "win32" else "dll"
def cross_compiler(cc, options=None, output_format="so"):
"""Create a cross compiler function.
Parameters
----------
cc : str
The cross compiler name.
options : list, optional
List of additional optional string.
output_format : str, optional
Library output format.
Returns
-------
fcompile : function
A compilation function that can be passed to export_library.
"""
def _fcompile(outputs, objects, opts=None):
opts = opts if opts else []
if options:
opts += options
_linux_shared(outputs, objects, opts, cc=cc)
_fcompile.output_format = output_format
return _fcompile
def _linux_shared(output, objects, options, cc="g++"): def _linux_shared(output, objects, options, cc="g++"):
cmd = [cc] cmd = [cc]
cmd += ["-shared", "-fPIC"] cmd += ["-shared", "-fPIC"]
......
...@@ -42,6 +42,9 @@ def tar(output, files): ...@@ -42,6 +42,9 @@ def tar(output, files):
msg += py_str(out) msg += py_str(out)
raise RuntimeError(msg) raise RuntimeError(msg)
# assign output format
tar.output_format = "tar"
def untar(tar_file, directory): def untar(tar_file, directory):
"""Unpack all tar files into the directory """Unpack all tar files into the directory
......
...@@ -98,6 +98,9 @@ def create_dylib(output, objects, arch, sdk="macosx"): ...@@ -98,6 +98,9 @@ def create_dylib(output, objects, arch, sdk="macosx"):
raise RuntimeError(msg) raise RuntimeError(msg)
# assign so as default output format
create_dylib.output_format = "dylib"
def compile_metal(code, path_target=None, sdk="macosx"): def compile_metal(code, path_target=None, sdk="macosx"):
"""Compile metal with CLI tool from env. """Compile metal with CLI tool from env.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment