Unverified Commit fb7fa8e4 by Tianqi Chen Committed by GitHub

[AUTOTVM] Refactor measure build func (#2927)

parent c6339730
......@@ -19,7 +19,7 @@ import numpy as np
from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
rpc as _rpc, target as _target
from ...contrib import nvcc, ndk
from ...contrib import nvcc, ndk, tar
from ..util import get_const_tuple
from ..env import AutotvmGlobalScope
......@@ -58,20 +58,20 @@ class LocalBuilder(Builder):
build_func: callable or str
If is 'default', use default build function
If is 'ndk', use function for android ndk
If is callable, use it as custom build function
If is callable, use it as custom build function, expect lib_format field.
"""
def __init__(self, timeout=10, n_parallel=None, build_func='default'):
super(LocalBuilder, self).__init__(timeout, n_parallel)
if isinstance(build_func, str):
if build_func == 'default':
build_func = default_build_func
build_func = tar.tar
elif build_func == 'ndk':
build_func = android_ndk_build_func
build_func = ndk.create_shared
else:
raise ValueError("Invalid build_func" + build_func)
self.build_func = build_func
self.build_func = _wrap_build_func(build_func)
self.executor = LocalExecutor(timeout=timeout)
self.tmp_dir = tempfile.mkdtemp()
......@@ -349,46 +349,47 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
def default_build_func(measure_input, tmp_dir, **kwargs):
def _wrap_build_func(build_func):
"""
Default build func. This can work for cuda, opencl, llvm backend
Wrap build_func to a function that can be used in measure.
Parameters
----------
measure_input: MeasureInput
The input of measurement
tmp_dir: str
The path of temporary directory to export generated library
"""
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
build_func : The compilation function
We expect fcompile to contain an attr "output_format"
Returns
-------
wrapped_build_func : function
The wrapped build function
"""
if not hasattr(build_func, "output_format"):
raise AttributeError("Expect build_func to have the attribute output_format.")
output_format = build_func.output_format
def android_ndk_build_func(measure_input, tmp_dir, **kwargs):
def _wrapped(measure_input, tmp_dir, **kwargs):
"""
Build function for android device using ndk.
Wrapped build func.
Parameters
----------
measure_input: MeasureInput
The input of measurement
tmp_dir: str
The path of temporary directory to export generated library
"""
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
getrandbits(64), output_format))
# TODO(tvm-team) consider linline _build_func_common
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename, ndk.create_shared)
func.export_library(filename, build_func)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
return _wrapped
def run_through_rpc(measure_input, build_result,
......
......@@ -29,7 +29,7 @@ def create_shared(output,
cc : str, optional
The compile string.
"""
if sys.platform == "darwin" or sys.platform.startswith('linux'):
if sys.platform == "darwin" or sys.platform.startswith("linux"):
_linux_shared(output, objects, options, cc)
elif sys.platform == "win32":
_windows_shared(output, objects, options)
......@@ -37,6 +37,38 @@ def create_shared(output,
raise ValueError("Unsupported platform")
# assign so as default output format
create_shared.output_format = "so" if sys.platform != "win32" else "dll"
def cross_compiler(cc, options=None, output_format="so"):
"""Create a cross compiler function.
Parameters
----------
cc : str
The cross compiler name.
options : list, optional
List of additional optional string.
output_format : str, optional
Library output format.
Returns
-------
fcompile : function
A compilation function that can be passed to export_library.
"""
def _fcompile(outputs, objects, opts=None):
opts = opts if opts else []
if options:
opts += options
_linux_shared(outputs, objects, opts, cc=cc)
_fcompile.output_format = output_format
return _fcompile
def _linux_shared(output, objects, options, cc="g++"):
cmd = [cc]
cmd += ["-shared", "-fPIC"]
......
......@@ -42,6 +42,9 @@ def tar(output, files):
msg += py_str(out)
raise RuntimeError(msg)
# assign output format
tar.output_format = "tar"
def untar(tar_file, directory):
"""Unpack all tar files into the directory
......
......@@ -98,6 +98,9 @@ def create_dylib(output, objects, arch, sdk="macosx"):
raise RuntimeError(msg)
# assign so as default output format
create_dylib.output_format = "dylib"
def compile_metal(code, path_target=None, sdk="macosx"):
"""Compile metal with CLI tool from env.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment