Commit 7c3ec7df by Zhi Committed by Tianqi Chen

Heterogeneous Runtime (#1695)

parent 7beafddd
...@@ -384,8 +384,14 @@ def build(sch, ...@@ -384,8 +384,14 @@ def build(sch,
target=None, target=None,
target_host=None, target_host=None,
name="default_function", name="default_function",
binds=None): binds=None,
"""Build a function with arguments as signiture. postpone_host_codegen=False):
"""Build a function with arguments as signature. Code will be generated
for a device specified by the target. For homogeneous execution, a module
that contains both host and device code is returned. For heterogeneous
execution, a list of lowered functions for the host and a module containing
device code are returned, but actual code generation for the host module is
postponed after code generation is finished for all devices.
Parameters Parameters
---------- ----------
...@@ -414,10 +420,18 @@ def build(sch, ...@@ -414,10 +420,18 @@ def build(sch,
Dictionary that maps the binding of symbolic buffer to Tensor. Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument. By default, a new buffer is created for each tensor in the argument.
postpone_host_codegen : bool, optional
A bool value that indicates if code generation for the host module
should be postponed. This variable is set to be true for heterogeneous
execution. Otherwise, it is defaulted to false.
Returns Returns
------- -------
f : Function, or pair of functions ret : tvm.module, or (list of LoweredFunc, tvm.module) tuple
The result function. A module that combines both host and device code is returned when
postpone_host_codegen is not set. Otherwise, a list of lowered
functions for the host and a module contains only device code are
returned.
Note Note
---- ----
...@@ -498,9 +512,15 @@ def build(sch, ...@@ -498,9 +512,15 @@ def build(sch,
fdevice = [ir_pass.LowerIntrin(x, target_device.target_name) for x in fdevice] fdevice = [ir_pass.LowerIntrin(x, target_device.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mhost = codegen.build_module(fhost, str(target_host))
# Append fhost to the device module and return the updated module. All
# device modules will be imported to the host module after all of them are
# collected.
mdev = codegen.build_module(fdevice, str(target_device)) if fdevice else None
if postpone_host_codegen:
return fhost, mdev
mhost = codegen.build_module(fhost, str(target_host))
if fdevice: if fdevice:
mdev = codegen.build_module(fdevice, str(target_device))
mhost.import_module(mdev) mhost.import_module(mdev)
return mhost return mhost
...@@ -3,26 +3,24 @@ import numpy as np ...@@ -3,26 +3,24 @@ import numpy as np
from .._ffi.base import string_types from .._ffi.base import string_types
from .._ffi.function import get_global_func from .._ffi.function import get_global_func
from .._ffi.runtime_ctypes import TVMContext
from ..rpc import base as rpc_base from ..rpc import base as rpc_base
from .. import ndarray as nd
def create(graph_json_str, libmod, ctx): def create(graph_json_str, libmod, ctx):
"""Create a runtime executor module given a graph and module. """Create a runtime executor module given a graph and module.
Parameters Parameters
---------- ----------
graph_json_str : str or graph class graph_json_str : str or graph class
The graph to be deployed in json format output by nnvm graph. The graph to be deployed in json format output by nnvm graph.
The graph can only contain one operator(tvm_op) that The graph can only contain one operator(tvm_op) that
points to the name of PackedFunc in the libmod. points to the name of PackedFunc in the libmod.
libmod : tvm.Module libmod : tvm.Module
The module of the corresponding function The module of the corresponding function
ctx : TVMContext or list of TVMContext
ctx : TVMContext The context to deploy the module. It can be local or remote when there
The context to deploy the module, can be local or remote. is only one TVMContext. Otherwise, the first context in the list will
be used as this purpose. All context should be given for heterogeneous
execution.
Returns Returns
------- -------
graph_module : GraphModule graph_module : GraphModule
...@@ -33,17 +31,42 @@ def create(graph_json_str, libmod, ctx): ...@@ -33,17 +31,42 @@ def create(graph_json_str, libmod, ctx):
graph_json_str = graph_json_str._tvm_graph_json() graph_json_str = graph_json_str._tvm_graph_json()
except AttributeError: except AttributeError:
raise ValueError("Type %s is not supported" % type(graph_json_str)) raise ValueError("Type %s is not supported" % type(graph_json_str))
device_type = ctx.device_type if isinstance(ctx, TVMContext):
device_id = ctx.device_id ctx = [ctx]
elif not isinstance(ctx, (list, tuple)):
raise ValueError("ctx has to be the type of TVMContext or a list of "
"TVMCTVMContext")
for cur_ctx in ctx:
if not isinstance(cur_ctx, TVMContext):
raise ValueError("ctx has to be the type of TVMContext or a list "
"of TVMContext")
# device_type_id[0], device_type_id[1] are used as the primary/fallback
# context type and id. All other ones are used as device context for
# heterogeneous execution.
num_rpc_ctx = 0
device_type_id = []
for cur_ctx in ctx:
device_type = cur_ctx.device_type
if device_type >= rpc_base.RPC_SESS_MASK: if device_type >= rpc_base.RPC_SESS_MASK:
assert libmod.type_key == "rpc" assert libmod.type_key == "rpc"
assert rpc_base._SessTableIndex(libmod) == ctx._rpc_sess._tbl_index assert rpc_base._SessTableIndex(
libmod) == cur_ctx._rpc_sess._tbl_index
num_rpc_ctx += 1
device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK
device_type_id.append(device_type)
device_type_id.append(cur_ctx.device_id)
if 0 < num_rpc_ctx < len(ctx):
raise ValueError("Either all or none of the contexts should be rpc.")
if num_rpc_ctx == len(ctx):
hmod = rpc_base._ModuleHandle(libmod) hmod = rpc_base._ModuleHandle(libmod)
fcreate = ctx._rpc_sess.get_function("tvm.graph_runtime.remote_create") fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.remote_create")
device_type = device_type % rpc_base.RPC_SESS_MASK return GraphModule(fcreate(graph_json_str, hmod, *device_type_id))
return GraphModule(fcreate(graph_json_str, hmod, device_type, device_id), ctx)
fcreate = get_global_func("tvm.graph_runtime.create") fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, device_type, device_id), ctx) return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
class GraphModule(object): class GraphModule(object):
...@@ -58,18 +81,13 @@ class GraphModule(object): ...@@ -58,18 +81,13 @@ class GraphModule(object):
module : Module module : Module
The interal tvm module that holds the actual graph functions. The interal tvm module that holds the actual graph functions.
ctx : TVMContext
The context this module is under
Attributes Attributes
---------- ----------
module : Module module : Module
The interal tvm module that holds the actual graph functions. The interal tvm module that holds the actual graph functions.
ctx : TVMContext
The context this module is under
""" """
def __init__(self, module, ctx):
def __init__(self, module):
self.module = module self.module = module
self._set_input = module["set_input"] self._set_input = module["set_input"]
self._run = module["run"] self._run = module["run"]
...@@ -81,7 +99,6 @@ class GraphModule(object): ...@@ -81,7 +99,6 @@ class GraphModule(object):
except AttributeError: except AttributeError:
pass pass
self._load_params = module["load_params"] self._load_params = module["load_params"]
self.ctx = ctx
def set_input(self, key=None, value=None, **params): def set_input(self, key=None, value=None, **params):
"""Set inputs to the module via kwargs """Set inputs to the module via kwargs
...@@ -98,14 +115,14 @@ class GraphModule(object): ...@@ -98,14 +115,14 @@ class GraphModule(object):
Additonal arguments Additonal arguments
""" """
if key: if key:
self._set_input(key, nd.array(value, ctx=self.ctx)) self._get_input(key).copyfrom(value)
if params: if params:
# upload big arrays first to avoid memory issue in rpc mode # upload big arrays first to avoid memory issue in rpc mode
keys = list(params.keys()) keys = list(params.keys())
keys.sort(key=lambda x: -np.prod(params[x].shape)) keys.sort(key=lambda x: -np.prod(params[x].shape))
for k in keys: for k in keys:
self._set_input(k, nd.array(params[k], ctx=self.ctx)) self._get_input(k).copyfrom(params[k])
def run(self, **input_dict): def run(self, **input_dict):
"""Run forward execution of the graph """Run forward execution of the graph
...@@ -177,7 +194,8 @@ class GraphModule(object): ...@@ -177,7 +194,8 @@ class GraphModule(object):
if hasattr(self, '_debug_get_output'): if hasattr(self, '_debug_get_output'):
self._debug_get_output(node, out) self._debug_get_output(node, out)
else: else:
raise RuntimeError("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0") raise RuntimeError(
"Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0")
return out return out
def load_params(self, params_bytes): def load_params(self, params_bytes):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment