Commit 47e57be4 by Zhi, committed by Tianqi Chen

support of multiple devices for tvm.build (#1773)

parent bea0b00f
@@ -379,24 +379,94 @@ def lower(sch,
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
def build(sch,
def _build_for_device(flist, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    flist : list of LoweredFunc
        The list of lowered functions to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : list of LoweredFunc
        A list of lowered functions for the host.

    mdev : tvm.module
        A module that contains device code.
    """
    target = _target.create(target)
    device_type = ndarray.context(target.target_name, 0).device_type
    fhost = []
    fdevice = []
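    # Sort each lowered function into host and device buckets; a MixedFunc
    # still interleaves host and device code, so it is split only after the
    # thread-sync and allreduce passes have run.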
    for func in flist:
        if not ir_pass.VerifyMemory(func, device_type):
            raise ValueError(
                "Direct host side access to device memory is detected in %s. "
                "Did you forget to bind?" % func.name)
        if func.func_type == container.LoweredFunc.MixedFunc:
            if current_build_config().detect_global_barrier:
                func = ir_pass.ThreadSync(func, "global")
            func = ir_pass.ThreadSync(func, "shared")
            func = ir_pass.ThreadSync(func, "warp")
            warp_size = target.thread_warp_size
            func = ir_pass.LowerThreadAllreduce(func, warp_size)
            fsplits = [s for s in ir_pass.SplitHostDevice(func)]
            fhost.append(fsplits[0])
            for x in fsplits[1:]:
                fdevice.append(x)
        elif func.func_type == container.LoweredFunc.HostFunc:
            fhost.append(func)
        elif func.func_type == container.LoweredFunc.DeviceFunc:
            fdevice.append(func)
        else:
            raise ValueError("unknown function type %d" % func.func_type)
    for i, func in enumerate(fdevice):
        warp_size = target.thread_warp_size
        fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)

    if "gpu" in target.keys and not fdevice:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do "
            "bind?" % target)

    fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
    fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert not fdevice

    target_host = _target.create(target_host)
    fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
    fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
    fhost = [ir_pass.CombineContextCall(x) for x in fhost]
    mdev = codegen.build_module(fdevice, str(target)) if fdevice else None

    return fhost, mdev
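
# A minimal usage sketch, assuming a CUDA-enabled TVM build. Note that
# _build_for_device is internal; user code should go through tvm.build,
# so this only illustrates the host/device split the helper performs.
import tvm

A = tvm.placeholder((1024,), name='A')
B = tvm.compute(A.shape, lambda i: A[i] + 1.0, name='B')
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))   # bind blocks for the GPU
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))  # bind threads for the GPU
flist = [tvm.lower(s, [A, B], name="add_one")]
# fhost: host-side LoweredFuncs; mdev: a module holding the CUDA kernel
# (mdev is None when no device code is produced).
fhost, mdev = _build_for_device(flist, "cuda", "llvm")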
def build(inputs,
args=None,
target=None,
target_host=None,
name="default_function",
binds=None,
postpone_host_codegen=False):
binds=None):
"""Build a function with arguments as signature. Code will be generated
for a device specified by the target. For homogeneous execution, a module
that contains both host and device code is returned. For heterogeneous
execution, a list of lowered functions for the host and a module containing
device code are returned, but actual code generation for the host module is
postponed until code generation has finished for all devices.
for devices coupled with target information.
Parameters
----------
sch : tvm.Schedule, or LoweredFunc
The schedule to be built
inputs : tvm.Schedule, LoweredFunc, or dict of target to LoweredFunc list
The inputs to be built: a schedule, a lowered function, or a dict mapping
each compilation target to its list of lowered functions
args : list of Buffer or Tensor or Var, optional
The argument lists to the function.
@@ -420,107 +490,108 @@ def build(sch,
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
postpone_host_codegen : bool, optional
Whether code generation for the host module should be postponed. This is
set to True for heterogeneous execution; otherwise it defaults to False.
Returns
-------
ret : tvm.module, or (list of LoweredFunc, tvm.module) tuple
A module that combines both host and device code is returned when
postpone_host_codegen is not set. Otherwise, a list of lowered
functions for the host and a module containing only device code are
returned.
ret : tvm.module
A module that combines both host and device code.
Examples
--------
There are two typical uses of this function depending on the type of the
argument `inputs`:

1. `inputs` is a lowered function, or a list of them:

.. code-block:: python

    n = 2
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    s = tvm.create_schedule(C.op)
    f = tvm.lower(s, [A, B, C], name="test_add")
    m = tvm.build(f, target="llvm")

2. `inputs` is a dict mapping each compilation target to a list of lowered
functions:

.. code-block:: python

    n = 2
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    s1 = tvm.create_schedule(C.op)
    s2 = topi.cpp.cuda.schedule_injective("cuda", [C])
    f1 = tvm.lower(s1, [A, B, C], name="test_add1")
    f2 = tvm.lower(s2, [A, B, C], name="test_add2")
    m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")
Note
----
See the note on :any:`tvm.target` on target string format.
"""
if isinstance(sch, schedule.Schedule):
if isinstance(inputs, schedule.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
flist = lower(sch, args,
flist = lower(inputs, args,
name=name,
binds=binds)
if isinstance(flist, container.LoweredFunc):
flist = [flist]
elif isinstance(sch, container.LoweredFunc):
elif isinstance(inputs, container.LoweredFunc):
if args:
raise ValueError("args must not be given when building from LoweredFunc")
flist = [sch]
elif isinstance(sch, (list, tuple, container.Array)):
flist = sch
raise ValueError("args must not be given when building from LoweredFunc.")
flist = [inputs]
elif isinstance(inputs, (list, tuple, container.Array)):
flist = inputs
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError("inputs must be Schedule, LoweredFunc, list of "
"LoweredFunc, or dict of target to list of "
"LoweredFunc.")
if not isinstance(inputs, (dict, container.Map)):
target = _target.current_target() if target is None else target
target = target if target else "llvm"
target_flist = {target: flist}
else:
raise ValueError("sch has to be Schedule, LoweredFunc, or list of LoweredFunc")
target_flist = inputs
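# Every accepted input form is now normalized into target_flist, a dict
# mapping each compilation target to its list of LoweredFuncs; the loop
# below validates the keys and rejects duplicate function names.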
for tar, flist in target_flist.items():
if not isinstance(tar, (str, _target.Target)):
raise ValueError("The key of inputs must be str or "
"_target.Target when inputs is dict.")
fname_set = set()
for x in flist:
if not isinstance(x, container.LoweredFunc):
raise ValueError("sch has to be Schedule, LoweredFunc, or list of LoweredFunc")
raise ValueError("inputs must be Schedule, LoweredFunc, list "
"of LoweredFunc, or dict of str to list of "
"LoweredFunc.")
if x.name in fname_set:
raise ValueError("Duplicate function name %s" % x.name)
fname_set.add(x.name)
target = _target.current_target() if target is None else target
target = _target.create(target) if target else _target.create("llvm")
device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
fdevice = []
for func in flist:
if not ir_pass.VerifyMemory(func, device_type):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
if func.func_type == container.LoweredFunc.MixedFunc:
if current_build_config().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
func = ir_pass.ThreadSync(func, "warp")
warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(func)]
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == container.LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == container.LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)
for i, func in enumerate(fdevice):
warp_size = target.thread_warp_size
fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do bind?" % target)
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
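# Infer target_host when it is not given: prefer the first target in the
# map whose device context is a CPU; otherwise fall back to LLVM (or to
# stackvm when LLVM is not enabled).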
if not target_host:
for tar, _ in target_flist.items():
tar = _target.create(tar)
device_type = ndarray.context(tar.target_name, 0).device_type
if device_type == ndarray.cpu(0).device_type:
target_host = target
assert not fdevice
else:
target_host = tar
break
if not target_host:
target_host = "llvm" if module.enabled("llvm") else "stackvm"
target_host = _target.create(target_host)
target_device = target
fdevice = [ir_pass.LowerIntrin(x, target_device.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
# Build the device module and return it together with the host functions.
# All device modules will be imported into the host module after all of
# them are collected.
mdev = codegen.build_module(fdevice, str(target_device)) if fdevice else None
if postpone_host_codegen:
return fhost, mdev
fhost_all = []
device_modules = []
for tar, flist in target_flist.items():
fhost, mdev = _build_for_device(flist, tar, target_host)
# Save the current lowered functions of the host and the device module.
fhost_all += fhost
device_modules.append(mdev)
# Generate a unified host module.
mhost = codegen.build_module(fhost_all, str(target_host))
mhost = codegen.build_module(fhost, str(target_host))
if fdevice:
# Import all modules.
for mdev in device_modules:
if mdev:
mhost.import_module(mdev)
return mhost
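
# A quick illustration (reusing f1 and f2 from the docstring example above;
# assumes both LLVM and CUDA are enabled): the host module returned by a
# heterogeneous build carries the device modules it imported.
m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")
print([mod.type_key for mod in m.imported_modules])  # e.g. ['cuda']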
@@ -124,9 +124,6 @@ def test_simplex_data_transferring():
schedule_add = topi.cpp.cuda.schedule_injective(target, [elemwise_add])
lower_add = tvm.lower(schedule_add, [tensor_a, tensor_b, elemwise_add],
name="elemwise_add")
host_funcs_add, lib_add = tvm.build(lower_add, target=target_device,
name="elemwise_add",
postpone_host_codegen=True)
# Insert copy. Neither compute nor schedule is required for the copy
# node. The compute will be performed at runtime which is just data
@@ -142,16 +139,8 @@ def test_simplex_data_transferring():
elemwise_sub],
name="elemwise_sub")
host_funcs_sub, lib_sub = tvm.build(lower_sub, target=target_host,
name="elemwise_sub",
postpone_host_codegen=True)
host_funcs = host_funcs_add + host_funcs_sub
mhost = tvm.codegen.build_module(host_funcs, target_host)
if lib_add:
mhost.import_module(lib_add)
if lib_sub:
mhost.import_module(lib_sub)
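# The dict-based tvm.build performs in one call what the removed lines did
# by hand: it builds the host functions for both targets into a single
# module and imports each device module into it.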
target_flist = {target_device: [lower_add], target_host: [lower_sub]}
mhost = tvm.build(target_flist, target_host=target_host)
ctx = [host_ctx, device_ctx]
mod = graph_runtime.create(graph, mhost, ctx)
params = {}
@@ -338,10 +327,6 @@ def test_duplex_data_transferring():
lower_add1 = tvm.lower(
add_schedule1, [tensor_d, copy_sub_add, elemwise_add1],
name="elemwise_add1")
host_funcs_add, lib_add = tvm.build([lower_add0, lower_add1],
target=target_device,
postpone_host_codegen=True)
# Create module for sub whose target is the host.
tensor_c = tvm.placeholder(shape, name="C")
elemwise_sub = tvm.compute(shape, lambda *i: copy_add_sub(*i)
@@ -350,15 +335,10 @@ def test_duplex_data_transferring():
lower_sub = tvm.lower(sub_schedule, [copy_add_sub, tensor_c,
elemwise_sub],
name="elemwise_sub")
host_funcs_sub, lib_sub = tvm.build(lower_sub, target=target_host,
postpone_host_codegen=True)
host_funcs = host_funcs_add + host_funcs_sub
mhost = tvm.codegen.build_module(host_funcs, target_host)
if lib_add:
mhost.import_module(lib_add)
if lib_sub:
mhost.import_module(lib_sub)
target_flist = {target_device: [lower_add0, lower_add1], target_host:
[lower_sub]}
mhost = tvm.build(target_flist, target_host=target_host)
ctx = [host_ctx, device_ctx]
params = {}
params["A"] = tensor_a = np.random.uniform(