Unverified Commit 19d0d157 by Tianqi Chen Committed by GitHub

[CONTRIB][CC] Enhance cc.cross_compiler (#4817)

* [CONTRIB][CC] Enhance cc.cross_compiler

- Enhance cc.cross_compiler to take str argument.
- Remove cc.build_create_shared_func as it is dupilicated with cross_compiler
- Add examples to cc.cross_compiler

* address review comments
parent 5ea4f0d5
...@@ -87,38 +87,23 @@ create_shared.output_format = "so" if sys.platform != "win32" else "dll" ...@@ -87,38 +87,23 @@ create_shared.output_format = "so" if sys.platform != "win32" else "dll"
create_shared.get_target_triple = get_target_by_dump_machine( create_shared.get_target_triple = get_target_by_dump_machine(
"g++" if sys.platform == "darwin" or sys.platform.startswith("linux") else None) "g++" if sys.platform == "darwin" or sys.platform.startswith("linux") else None)
def build_create_shared_func(options=None, compile_cmd="g++"):
"""Build create_shared function with particular default options and compile_cmd.
Parameters def cross_compiler(compile_func,
---------- options=None,
options : List[str] output_format=None,
The list of additional options string. get_target_triple=None):
"""Create a cross compiler function by specializing compile_func with options.
compile_cmd : Optional[str] This function can be used to construct compile functions that
The compiler command. can be passed to AutoTVM measure or export_library.
Returns
-------
create_shared_wrapper : Callable[[str, str, Optional[str]], None]
A compilation function that can be passed to export_library or to autotvm.LocalBuilder.
"""
def create_shared_wrapper(output, objects, options=options, compile_cmd=compile_cmd):
create_shared(output, objects, options, compile_cmd)
create_shared_wrapper.output_format = create_shared.output_format
create_shared_wrapper.get_target_triple = get_target_by_dump_machine(compile_cmd)
return create_shared_wrapper
def cross_compiler(compile_func, base_options=None, output_format="so", get_target_triple=None):
"""Create a cross compiler function.
Parameters Parameters
---------- ----------
compile_func : Callable[[str, str, Optional[str]], None] compile_func : Union[str, Callable[[str, str, Optional[str]], None]]
Function that performs the actual compilation Function that performs the actual compilation
base_options : Optional[List[str]] options : Optional[List[str]]
List of additional optional string. List of additional optional string.
output_format : Optional[str] output_format : Optional[str]
...@@ -131,14 +116,44 @@ def cross_compiler(compile_func, base_options=None, output_format="so", get_targ ...@@ -131,14 +116,44 @@ def cross_compiler(compile_func, base_options=None, output_format="so", get_targ
------- -------
fcompile : Callable[[str, str, Optional[str]], None] fcompile : Callable[[str, str, Optional[str]], None]
A compilation function that can be passed to export_library. A compilation function that can be passed to export_library.
Examples
--------
.. code-block:: python
from tvm.contrib import cc, ndk
# export using arm gcc
mod = build_runtime_module()
mod.export_library(path_dso,
cc.cross_compiler("arm-linux-gnueabihf-gcc"))
# specialize ndk compilation options.
specialized_ndk = cc.cross_compiler(
ndk.create_shared,
["--sysroot=/path/to/sysroot", "-shared", "-fPIC", "-lm"])
mod.export_library(path_dso, specialized_ndk)
""" """
if base_options is None: base_options = [] if options is None else options
base_options = [] kwargs = {}
# handle case where compile_func is the name of the cc
if isinstance(compile_func, str):
kwargs = {"cc" : compile_func}
compile_func = create_shared
def _fcompile(outputs, objects, options=None): def _fcompile(outputs, objects, options=None):
all_options = base_options all_options = base_options
if options is not None: if options is not None:
all_options += options all_options += options
compile_func(outputs, objects, options=all_options) compile_func(outputs, objects, options=all_options, **kwargs)
if not output_format and hasattr(compile_func, "output_format"):
output_format = compile_func.output_format
output_format = output_format if output_format else "so"
if not get_target_triple and hasattr(compile_func, "get_target_triple"):
get_target_triple = compile_func.get_target_triple
_fcompile.output_format = output_format _fcompile.output_format = output_format
_fcompile.get_target_triple = get_target_triple _fcompile.get_target_triple = get_target_triple
return _fcompile return _fcompile
......
...@@ -113,7 +113,8 @@ def test_device_module_dump(): ...@@ -113,7 +113,8 @@ def test_device_module_dump():
raise ValueError("Unsupported platform") raise ValueError("Unsupported platform")
path_dso = temp.relpath("dev_lib.so") path_dso = temp.relpath("dev_lib.so")
f.export_library(path_dso) # test cross compiler function
f.export_library(path_dso, cc.cross_compiler("g++"))
f1 = tvm.module.load(path_dso) f1 = tvm.module.load(path_dso)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
...@@ -134,8 +135,8 @@ def test_device_module_dump(): ...@@ -134,8 +135,8 @@ def test_device_module_dump():
name = "myadd_%s" % device name = "myadd_%s" % device
f = tvm.build(s, [A, B], device, "stackvm", name=name) f = tvm.build(s, [A, B], device, "stackvm", name=name)
path_dso = temp.relpath("dev_lib.stackvm") path_dso = temp.relpath("dev_lib.stackvm")
#f.export_library(path_dso) f.export_library(path_dso)
#f1 = tvm.module.load(path_dso) f1 = tvm.module.load(path_dso)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f(a, b) f(a, b)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment