Unverified Commit 19d0d157 by Tianqi Chen Committed by GitHub

[CONTRIB][CC] Enhance cc.cross_compiler (#4817)

* [CONTRIB][CC] Enhance cc.cross_compiler

- Enhance cc.cross_compiler to take str argument.
- Remove cc.build_create_shared_func as it is dupilicated with cross_compiler
- Add examples to cc.cross_compiler

* address review comments
parent 5ea4f0d5
......@@ -87,38 +87,23 @@ create_shared.output_format = "so" if sys.platform != "win32" else "dll"
create_shared.get_target_triple = get_target_by_dump_machine(
"g++" if sys.platform == "darwin" or sys.platform.startswith("linux") else None)
def build_create_shared_func(options=None, compile_cmd="g++"):
"""Build create_shared function with particular default options and compile_cmd.
Parameters
----------
options : List[str]
The list of additional options string.
def cross_compiler(compile_func,
options=None,
output_format=None,
get_target_triple=None):
"""Create a cross compiler function by specializing compile_func with options.
compile_cmd : Optional[str]
The compiler command.
Returns
-------
create_shared_wrapper : Callable[[str, str, Optional[str]], None]
A compilation function that can be passed to export_library or to autotvm.LocalBuilder.
"""
def create_shared_wrapper(output, objects, options=options, compile_cmd=compile_cmd):
create_shared(output, objects, options, compile_cmd)
create_shared_wrapper.output_format = create_shared.output_format
create_shared_wrapper.get_target_triple = get_target_by_dump_machine(compile_cmd)
return create_shared_wrapper
This function can be used to construct compile functions that
can be passed to AutoTVM measure or export_library.
def cross_compiler(compile_func, base_options=None, output_format="so", get_target_triple=None):
"""Create a cross compiler function.
Parameters
----------
compile_func : Callable[[str, str, Optional[str]], None]
compile_func : Union[str, Callable[[str, str, Optional[str]], None]]
Function that performs the actual compilation
base_options : Optional[List[str]]
options : Optional[List[str]]
List of additional optional string.
output_format : Optional[str]
......@@ -131,14 +116,44 @@ def cross_compiler(compile_func, base_options=None, output_format="so", get_targ
-------
fcompile : Callable[[str, str, Optional[str]], None]
A compilation function that can be passed to export_library.
Examples
--------
.. code-block:: python
from tvm.contrib import cc, ndk
# export using arm gcc
mod = build_runtime_module()
mod.export_library(path_dso,
cc.cross_compiler("arm-linux-gnueabihf-gcc"))
# specialize ndk compilation options.
specialized_ndk = cc.cross_compiler(
ndk.create_shared,
["--sysroot=/path/to/sysroot", "-shared", "-fPIC", "-lm"])
mod.export_library(path_dso, specialized_ndk)
"""
if base_options is None:
base_options = []
base_options = [] if options is None else options
kwargs = {}
# handle case where compile_func is the name of the cc
if isinstance(compile_func, str):
kwargs = {"cc" : compile_func}
compile_func = create_shared
def _fcompile(outputs, objects, options=None):
all_options = base_options
if options is not None:
all_options += options
compile_func(outputs, objects, options=all_options)
compile_func(outputs, objects, options=all_options, **kwargs)
if not output_format and hasattr(compile_func, "output_format"):
output_format = compile_func.output_format
output_format = output_format if output_format else "so"
if not get_target_triple and hasattr(compile_func, "get_target_triple"):
get_target_triple = compile_func.get_target_triple
_fcompile.output_format = output_format
_fcompile.get_target_triple = get_target_triple
return _fcompile
......
......@@ -113,7 +113,8 @@ def test_device_module_dump():
raise ValueError("Unsupported platform")
path_dso = temp.relpath("dev_lib.so")
f.export_library(path_dso)
# test cross compiler function
f.export_library(path_dso, cc.cross_compiler("g++"))
f1 = tvm.module.load(path_dso)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
......@@ -134,8 +135,8 @@ def test_device_module_dump():
name = "myadd_%s" % device
f = tvm.build(s, [A, B], device, "stackvm", name=name)
path_dso = temp.relpath("dev_lib.stackvm")
#f.export_library(path_dso)
#f1 = tvm.module.load(path_dso)
f.export_library(path_dso)
f1 = tvm.module.load(path_dso)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f(a, b)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment