Commit 09236bf6 by Zhi Committed by Haichen Shen

[relay][pass] Annotation for heterogeneous compilation (#2361)

parent 30a5a600
...@@ -148,6 +148,8 @@ This level support backpropagation of broadcast operators. It is temporary. ...@@ -148,6 +148,8 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.broadcast_to_like tvm.relay.broadcast_to_like
tvm.relay.collapse_sum_like tvm.relay.collapse_sum_like
tvm.relay.slice_like tvm.relay.slice_like
tvm.relay.device_copy
tvm.relay.annotation.on_device
Level 1 Definitions Level 1 Definitions
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/annotation.h
* \brief Attribute for annotation operators.
*/
#ifndef TVM_RELAY_ATTRS_ANNOTATION_H_
#define TVM_RELAY_ATTRS_ANNOTATION_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Options for the device annotation operators.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
int device_type;
TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(device_type)
.describe(
"The virutal device/context type that an expression is annotated with.")
.set_default(0);
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/device_copy.h
* \brief Attribute for the device copy operator.
*/
#ifndef TVM_RELAY_ATTRS_DEVICE_COPY_H_
#define TVM_RELAY_ATTRS_DEVICE_COPY_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Options for the device copy operators.
*/
struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
int dst_dev_type;
int src_dev_type;
TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") {
TVM_ATTR_FIELD(src_dev_type)
.describe(
"The virutal device/context type where the op copies data from.")
.set_default(0);
TVM_ATTR_FIELD(dst_dev_type)
.describe(
"The virutal device/context type where the op copies data to.")
.set_default(0);
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_DEVICE_COPY_H_
...@@ -164,7 +164,6 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> { ...@@ -164,7 +164,6 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
} }
}; };
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> { struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
std::string src_layout; std::string src_layout;
std::string dst_layout; std::string dst_layout;
......
...@@ -188,6 +188,21 @@ Expr ForwardRewrite(const Expr& expr, ...@@ -188,6 +188,21 @@ Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext = nullptr, std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr); std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*!
* \brief Rewrite the annotated program.
* \param expr The expression.
* \param fallback_device The fallback device which is the default device for
* operators without annotation.
* \return The updated program.
*/
Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
/*!
* \brief Collect the device mapping information of each expression.
* \param expr The expression.
* \return The device mapping.
*/
Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
/*! \brief A hashing structure in the style of std::hash. */ /*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash { struct StructuralHash {
......
...@@ -25,6 +25,7 @@ from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl ...@@ -25,6 +25,7 @@ from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev from .ndarray import vpi, rocm, opengl, ext_dev
from ._ffi.runtime_ctypes import TypeCode from ._ffi.runtime_ctypes import TypeCode
from ._ffi.ndarray import TVMContext
from ._ffi.function import Function from ._ffi.function import Function
from ._ffi.base import TVMError, __version__ from ._ffi.base import TVMError, __version__
from .api import * from .api import *
......
...@@ -18,6 +18,7 @@ from .op.reduce import * ...@@ -18,6 +18,7 @@ from .op.reduce import *
from .op.tensor import * from .op.tensor import *
from .op.transform import * from .op.transform import *
from . import nn from . import nn
from . import annotation
from . import vision from . import vision
from . import image from . import image
from . import frontend from . import frontend
......
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Annotation related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.annotation import *
...@@ -52,8 +52,9 @@ def build(funcs, target, target_host=None): ...@@ -52,8 +52,9 @@ def build(funcs, target, target_host=None):
Parameters Parameters
---------- ----------
funcs : List[tvm.LoweredFunc] funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
The list of lowered functions. A list of lowered functions or dictionary mapping from targets to
lowered functions.
target : tvm.Target target : tvm.Target
......
...@@ -20,6 +20,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system. ...@@ -20,6 +20,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system.
from __future__ import absolute_import from __future__ import absolute_import
import json import json
from collections import defaultdict
import attr import attr
from . import _backend from . import _backend
from . import compile_engine from . import compile_engine
...@@ -27,6 +28,7 @@ from ..op import Op ...@@ -27,6 +28,7 @@ from ..op import Op
from ..expr import Function, GlobalVar from ..expr import Function, GlobalVar
from ..expr_functor import ExprFunctor from ..expr_functor import ExprFunctor
from ..ty import TupleType, TensorType from ..ty import TupleType, TensorType
from ... import target as _target
@attr.s @attr.s
...@@ -105,9 +107,9 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -105,9 +107,9 @@ class GraphRuntimeCodegen(ExprFunctor):
self.nodes = [] self.nodes = []
self.var_map = {} self.var_map = {}
self.params = {} self.params = {}
self.storage_map = None self.storage_device_map = None
self.compile_engine = compile_engine.get() self.compile_engine = compile_engine.get()
self.lowered_funcs = set() self.lowered_funcs = defaultdict(set)
self._name_map = {} self._name_map = {}
def add_node(self, node, expr): def add_node(self, node, expr):
...@@ -129,10 +131,20 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -129,10 +131,20 @@ class GraphRuntimeCodegen(ExprFunctor):
""" """
checked_type = expr.checked_type checked_type = expr.checked_type
# setup storage ids # setup storage ids
assert expr in self.storage_map assert expr in self.storage_device_map
node.attrs["storage_id"] = [ storage_device_info = self.storage_device_map[expr]
x.value for x in self.storage_map[expr] assert len(storage_device_info) == 2
] node.attrs["storage_id"] = [x.value for x in storage_device_info[0]]
device_types = [x.value for x in storage_device_info[1]]
num_unknown_devices = device_types.count(0)
if num_unknown_devices != 0 and num_unknown_devices != len(device_types):
raise RuntimeError("The graph contains not annotated nodes for "
"heterogeneous execution. All nodes must be "
"annotated.")
# Add the `device_index` attribute when the graph is annotated.
if num_unknown_devices == 0:
node.attrs["device_index"] = device_types
node_id = len(self.nodes) node_id = len(self.nodes)
self.nodes.append(node) self.nodes.append(node)
...@@ -232,9 +244,25 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -232,9 +244,25 @@ class GraphRuntimeCodegen(ExprFunctor):
"TVM only support calls to primitive functions " + "TVM only support calls to primitive functions " +
"(i.e functions composed of fusable operator invocations)") "(i.e functions composed of fusable operator invocations)")
assert call in self.storage_device_map
device_types = self.storage_device_map[call][1]
call_dev_type = device_types[0].value
if isinstance(self.target, (str, _target.Target)):
# homogeneous execution.
cached_func = self.compile_engine.lower(func, self.target) cached_func = self.compile_engine.lower(func, self.target)
self.target = {0: str(self.target)}
elif isinstance(self.target, dict):
# heterogeneous execution.
if call_dev_type not in self.target:
raise Exception("No target is provided for device " +
"{0}".format(call_dev_type))
cached_func = self.compile_engine.lower(func,
self.target[call_dev_type])
else:
raise ValueError("self.target must be the type of str," +
"tvm.target.Target, or dict of int to str")
for loweredf in cached_func.funcs: for loweredf in cached_func.funcs:
self.lowered_funcs.add(loweredf) self.lowered_funcs[self.target[call_dev_type]].add(loweredf)
inputs = [] inputs = []
# flatten tuple in the call. # flatten tuple in the call.
...@@ -284,6 +312,7 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -284,6 +312,7 @@ class GraphRuntimeCodegen(ExprFunctor):
num_entry = 0 num_entry = 0
shapes = [] shapes = []
storage_ids = [] storage_ids = []
device_types = []
dltypes = [] dltypes = []
node_row_ptr = [0] node_row_ptr = [0]
for node in self.nodes: for node in self.nodes:
...@@ -291,6 +320,8 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -291,6 +320,8 @@ class GraphRuntimeCodegen(ExprFunctor):
shapes += node.attrs["shape"] shapes += node.attrs["shape"]
dltypes += node.attrs["dtype"] dltypes += node.attrs["dtype"]
storage_ids += node.attrs["storage_id"] storage_ids += node.attrs["storage_id"]
if "device_index" in node.attrs:
device_types += node.attrs["device_index"]
num_entry += node.num_outputs num_entry += node.num_outputs
node_row_ptr.append(num_entry) node_row_ptr.append(num_entry)
...@@ -298,6 +329,8 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -298,6 +329,8 @@ class GraphRuntimeCodegen(ExprFunctor):
attrs = {} attrs = {}
attrs["shape"] = ["list_shape", shapes] attrs["shape"] = ["list_shape", shapes]
attrs["storage_id"] = ["list_int", storage_ids] attrs["storage_id"] = ["list_int", storage_ids]
if device_types:
attrs["device_index"] = ["list_int", device_types]
attrs["dltype"] = ["list_str", dltypes] attrs["dltype"] = ["list_str", dltypes]
json_dict = { json_dict = {
...@@ -313,11 +346,24 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -313,11 +346,24 @@ class GraphRuntimeCodegen(ExprFunctor):
def debug_dump_memory_plan(self, func): def debug_dump_memory_plan(self, func):
"""Debug function to dump memory plan.""" """Debug function to dump memory plan."""
def _annotate(expr): def _annotate(expr):
if expr in self.storage_map: if expr in self.storage_device_map:
return str(self.storage_map[expr]) storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
return str(storage_device_info[0])
return ""
return func.astext(show_meta_data=False, annotate=_annotate)
def debug_dump_device_annotation(self, func):
"""Debug function to dump device annotation result."""
def _annotate(expr):
if expr in self.storage_device_map:
storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
return str(storage_device_info[1])
return "" return ""
return func.astext(show_meta_data=False, annotate=_annotate) return func.astext(show_meta_data=False, annotate=_annotate)
def codegen(self, func): def codegen(self, func):
"""Compile a single function into a graph. """Compile a single function into a graph.
...@@ -331,24 +377,31 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -331,24 +377,31 @@ class GraphRuntimeCodegen(ExprFunctor):
graph_json : str graph_json : str
The graph json that can be consumed by runtime. The graph json that can be consumed by runtime.
lowered_funcs : List[tvm.LoweredFunc] lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
The lowered functions. The lowered functions.
params : Dict[str, tvm.nd.NDArray] params : Dict[str, tvm.nd.NDArray]
Additional constant parameters. Additional constant parameters.
""" """
self.storage_map = _backend.GraphPlanMemory(func) self.storage_device_map = _backend.GraphPlanMemory(func)
# First we convert all the parameters into input nodes. # First we convert all the parameters into input nodes.
for param in func.params: for param in func.params:
node = InputNode(param.name_hint, {}) node = InputNode(param.name_hint, {})
self.var_map[param] = self.add_node( self.var_map[param] = self.add_node(node, param)
node, param)
# Then we compile the body into a graph which can depend # Then we compile the body into a graph which can depend
# on input variables. # on input variables.
self.heads = self.visit(func.body) self.heads = self.visit(func.body)
graph_json = self._get_json() graph_json = self._get_json()
lowered_funcs = list(self.lowered_funcs)
# Return the lowered functions as a list for homogeneous compilation.
# Otherwise, for heterogeneous compilation, return a dictionary mapping
# each target to its list of lowered functions. Both forms are
# acceptable to tvm.build.
if not isinstance(self.target, dict):
lowered_funcs = list(list(self.lowered_funcs.values())[0])
else:
lowered_funcs = {k: list(v) for k, v in self.lowered_funcs.items()}
return graph_json, lowered_funcs, self.params return graph_json, lowered_funcs, self.params
def _get_unique_name(self, name): def _get_unique_name(self, name):
......
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
Construct the necessary state for the TVM graph runtime Construct the necessary state for the TVM graph runtime
from a Relay expression. from a Relay expression.
""" """
import warnings
from tvm._ffi.runtime_ctypes import TVMContext
from ..build_module import build as _tvm_build_module from ..build_module import build as _tvm_build_module
from .. import nd as _nd, target as _target, autotvm from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt from ..contrib import graph_runtime as _graph_rt
...@@ -20,6 +23,7 @@ OPT_PASS_LEVEL = { ...@@ -20,6 +23,7 @@ OPT_PASS_LEVEL = {
"AlterOpLayout": 3, "AlterOpLayout": 3,
} }
class BuildConfig(object): class BuildConfig(object):
"""Configuration scope to set a build config option. """Configuration scope to set a build config option.
...@@ -33,12 +37,13 @@ class BuildConfig(object): ...@@ -33,12 +37,13 @@ class BuildConfig(object):
"opt_level": 2, "opt_level": 2,
"add_pass": None, "add_pass": None,
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._old_scope = None self._old_scope = None
for k, _ in kwargs.items(): for k, _ in kwargs.items():
if k not in BuildConfig.defaults: if k not in BuildConfig.defaults:
raise ValueError( raise ValueError("invalid argument %s, candidates are %s" %
"invalid argument %s, candidates are %s" % (k, BuildConfig.defaults.keys())) (k, BuildConfig.defaults.keys()))
self._attr = kwargs self._attr = kwargs
def __getattr__(self, name): def __getattr__(self, name):
...@@ -127,8 +132,10 @@ def optimize(func, target, params=None): ...@@ -127,8 +132,10 @@ def optimize(func, target, params=None):
func : tvm.relay.Function func : tvm.relay.Function
The input to optimization. The input to optimization.
target: :any:`tvm.target.Target` target : Optional[:any:`tvm.target.Target`, Dict[int, tvm.target.Target]]
The optimization target. Some optimization passes are target specific. The optimization target. For heterogeneous compilation, it is a
dictionary mapping device type to compilation target. For homogeneous
compilation, it is a build target.
params : Optional[Dict[str, tvm.nd.NDArray]] params : Optional[Dict[str, tvm.nd.NDArray]]
Input parameters to the graph that do not change Input parameters to the graph that do not change
...@@ -165,12 +172,19 @@ def optimize(func, target, params=None): ...@@ -165,12 +172,19 @@ def optimize(func, target, params=None):
func = ir_pass.forward_fold_scale_axis(func) func = ir_pass.forward_fold_scale_axis(func)
func = ir_pass.fold_constant(func) func = ir_pass.fold_constant(func)
# FIXME(zhiics) Skip AlterOpLayout pass for heterogeneous compilation for
# now. We probably need to pass target to this pass as well. Fix it in
# a followup PR.
if cfg.pass_enabled("AlterOpLayout"): if cfg.pass_enabled("AlterOpLayout"):
if isinstance(target, _target.Target):
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.canonicalize_ops(func) func = ir_pass.canonicalize_ops(func)
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
with target: with target:
func = ir_pass.alter_op_layout(func) func = ir_pass.alter_op_layout(func)
elif isinstance(target, dict):
warnings.warn("AlterOpLayout pass is not enabled for heterogeneous"
" execution yet.")
if cfg.pass_enabled("FoldConstant"): if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func) func = ir_pass.fold_constant(func)
...@@ -178,10 +192,8 @@ def optimize(func, target, params=None): ...@@ -178,10 +192,8 @@ def optimize(func, target, params=None):
return func return func
def build(func, def build(func, target=None, target_host=None, params=None,
target=None, fallback_device=None):
target_host=None,
params=None):
"""Build a function to run on TVM graph runtime. """Build a function to run on TVM graph runtime.
Parameters Parameters
...@@ -189,10 +201,12 @@ def build(func, ...@@ -189,10 +201,12 @@ def build(func,
func: relay.Function func: relay.Function
The function to build. The function to build.
target : str or :any:`tvm.target.Target`, optional target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
The build target name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context to
target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target` optional target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device. Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA, When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver we also need host(CPU) side code to interact with the driver
...@@ -205,6 +219,10 @@ def build(func, ...@@ -205,6 +219,10 @@ def build(func,
Input parameters to the graph that do not change Input parameters to the graph that do not change
during inference time. Used for constant folding. during inference time. Used for constant folding.
fallback_device : str or tvm.TVMContext, optional.
The fallback device. It is also used as the default device for
operators with no specified device.
Returns Returns
------- -------
graph_json : str graph_json : str
...@@ -219,11 +237,22 @@ def build(func, ...@@ -219,11 +237,22 @@ def build(func,
target = target if target else _target.current_target() target = target if target else _target.current_target()
if target is None: if target is None:
raise ValueError("Target is not set in env or passed as argument.") raise ValueError("Target is not set in env or passed as argument.")
if isinstance(target, dict):
target, fallback_device = \
_update_heterogeneous_inputs(target, fallback_device)
elif isinstance(target, (str, _target.Target)):
target = _target.create(target) target = _target.create(target)
else:
raise ValueError("target must be the type of str, tvm.target.Target," +
"or dict of device name to target")
# If current dispatch context is fallback context (the default root context), # If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub # then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
if isinstance(target, dict):
tophub_context = autotvm.tophub.context(list(target.values()))
else:
tophub_context = autotvm.tophub.context(target) tophub_context = autotvm.tophub.context(target)
else: else:
tophub_context = autotvm.util.EmptyContext() tophub_context = autotvm.util.EmptyContext()
...@@ -232,6 +261,10 @@ def build(func, ...@@ -232,6 +261,10 @@ def build(func,
with tophub_context: with tophub_context:
func = optimize(func, target, params) func = optimize(func, target, params)
# Annotate the ops for heterogeneous execution.
if isinstance(target, dict):
func, target = _run_device_annotation_passes(func, target,
fallback_device)
# Fuse ops before running code gen # Fuse ops before running code gen
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, cfg.opt_level) func = ir_pass.fuse_ops(func, cfg.opt_level)
...@@ -239,10 +272,112 @@ def build(func, ...@@ -239,10 +272,112 @@ def build(func,
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target) graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
graph_json, lowered_funcs, params = graph_gen.codegen(func) graph_json, lowered_funcs, params = graph_gen.codegen(func)
mod = _tvm_build_module(lowered_funcs, target=target, target_host=target_host) mod = _tvm_build_module(
lowered_funcs, target=target, target_host=target_host)
return graph_json, mod, params return graph_json, mod, params
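A hedged end-to-end sketch of the heterogeneous build API added above. The target dict, the contexts, and the tiny graph are illustrative assumptions, and "cuda" requires a GPU-enabled TVM build. The function below carries no `on_device` annotations, so every operator compiles for the fallback (CPU) target; an annotated function would additionally receive device_copy nodes.

import tvm
from tvm import relay

x = relay.var("x", shape=(3,))
y = relay.var("y", shape=(3,))
func = relay.Function([x, y], relay.sqrt(relay.add(x, y)))

# Context names map to build targets; unannotated operators use the
# fallback_device target.
graph_json, mod, params = relay.build(
    func, target={"cpu": "llvm", "gpu": "cuda"}, fallback_device=tvm.cpu(0))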
def _update_heterogeneous_inputs(target, fallback_device=None):
"""Update the target and fallback device required for heterogeneous
compilation. CPU is used as the fallback device if it wasn't provided.
Meanwhile, a CPU device type and "llvm" pair will be added to the target
dictionary in this case.
Parameters
----------
target : dict of str (i.e. device/context name) to str/tvm.target.Target
A dict that contains context to target pairs.
fallback_device : str or tvm.TVMContext, optional.
The fallback device. It is also used as the default device for
operators with no specified device.
Returns
-------
device_target : dict of int to tvm.target.Target.
The updated device type to target dict.
fallback_device : int
The updated fallback device type.
"""
if not isinstance(target, dict):
raise ValueError("target must be dict of device name to target for " +
"heterogeneous execution, but received %s."
% type(target))
if fallback_device is None:
# cpu is used as the default fallback device when heterogeneous
# execution is needed, but no fallback device is provided.
fallback_device = _nd.cpu(0).device_type
target[fallback_device] = str(_target.create("llvm"))
elif isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, TVMContext):
fallback_device = fallback_device.device_type
else:
raise ValueError("fallback_device expects the type of str or" +
"TVMContext, but received %s." % type(fallback_device))
device_target = {}
for dev, tgt in target.items():
device_target[_nd.context(dev).device_type] = _target.create(tgt)
if fallback_device not in device_target:
raise ValueError("%s is used as the default device, but the target" +
"is not provided."
% _nd.context(fallback_device).device_name)
return device_target, fallback_device
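For illustration, a sketch of what this normalization yields, assuming the standard TVMContext codes (cpu = 1, gpu = 2); the input dict is a hypothetical example.

raw_target = {"cpu": "llvm", "gpu": "cuda"}
device_target, fallback = _update_heterogeneous_inputs(raw_target,
                                                       fallback_device="cpu")
# device_target -> {1: <llvm target>, 2: <cuda target>}; fallback -> 1 (cpu)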
def _run_device_annotation_passes(func, target, fallback_device):
"""Execute the device annotation passes to update the input program and
target information.
Parameters
----------
func: tvm.relay.Function
The function on which the annotation passes will be executed.
target : Dict[int, tvm.target.Target]
A dict that contains device type to target pairs.
fallback_device : int
The fallback device type.
Returns
-------
target : Dict[int, tvm.target.Target]
The updated device type to target dict.
func : tvm.relay.Function
The updated func.
"""
func = ir_pass.infer_type(func)
func = ir_pass.rewrite_annotated_ops(func, fallback_device)
device_map = ir_pass.collect_device_info(func)
# The expression to device type map will be empty if all or none of
# the expressions in `func` are annotated because this map is obtained
# by propagating the device information from the device copy operators.
# Neither of these cases requires a device copy operator.
if not device_map:
annotation_map = ir_pass.collect_device_annotation_ops(func)
# No annotation.
if not annotation_map:
target = {0: target[fallback_device]}
else:
dev_type = next(iter(annotation_map.values()))
# All annotated with the same device type.
if all(val == dev_type for val in annotation_map.values()):
target = {0: target[dev_type]}
else:
raise RuntimeError("Expressions in the function are "
"annotated with various device types,"
"but not device copy operators "
"found. Please check the "
"RewriteAnnotation pass.")
return func, target
class GraphExecutor(_interpreter.Executor): class GraphExecutor(_interpreter.Executor):
"""Wrapper around Executor interface. """Wrapper around Executor interface.
...@@ -259,6 +394,7 @@ class GraphExecutor(_interpreter.Executor): ...@@ -259,6 +394,7 @@ class GraphExecutor(_interpreter.Executor):
target : :py:class:`Target` target : :py:class:`Target`
The target option to build the function. The target option to build the function.
""" """
def __init__(self, mod, ctx, target): def __init__(self, mod, ctx, target):
self.mod = mod self.mod = mod
self.ctx = ctx self.ctx = ctx
......
...@@ -357,3 +357,60 @@ def alter_op_layout(expr): ...@@ -357,3 +357,60 @@ def alter_op_layout(expr):
Transformed expression with alternated layout. Transformed expression with alternated layout.
""" """
return _ir_pass.AlterOpLayout(expr) return _ir_pass.AlterOpLayout(expr)
def rewrite_annotated_ops(expr, fallback_device):
"""Rewrite the annotated program where annotation operators, e.g.
`on_device`, mark which device an expression should be scheduled on.
This pass helps heterogeneous execution where different operators may need
to be allocated on various devices.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
fallback_device : int
The fallback device type. It is also used as the default device for
operators with no annotated device.
Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression with cross device data copy operators.
"""
return _ir_pass.RewriteDeviceAnnotation(expr, fallback_device)
def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
ret : Dict[tvm.relay.expr, int]
A dictionary mapping tvm.relay.Expr to device type.
"""
return _ir_pass.CollectDeviceInfo(expr)
def collect_device_annotation_ops(expr):
"""Collect the device annotation ops for the given expression.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
ret : Dict[tvm.relay.expr, int]
A dictionary mapping tvm.relay.Expr to device type where the keys are
annotation expressions.
"""
return _ir_pass.CollectDeviceAnnotationOps(expr)
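A minimal sketch of chaining these passes, mirroring `_run_device_annotation_passes` in build_module. The tiny graph and device types (1 = CPU, 2 = GPU, following TVMContext) are assumptions; the annotation node is kept reachable through a tuple output, as the unit tests in this PR do.

import tvm
from tvm import relay

x = relay.var("x", shape=(3,))
y = relay.var("y", shape=(3,))
add = relay.add(x, y)
_add = relay.annotation.on_device(add, tvm.context(2))  # place `add` on device 2
sub = relay.subtract(add, y)
func = relay.Function([x, y], relay.Tuple(tvm.convert([_add, sub])))

func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func, tvm.cpu(0).device_type)
device_map = relay.ir_pass.collect_device_info(func)             # Expr -> device type
annotations = relay.ir_pass.collect_device_annotation_ops(func)  # on_device calls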
...@@ -10,6 +10,7 @@ from .reduce import * ...@@ -10,6 +10,7 @@ from .reduce import *
from .tensor import * from .tensor import *
from .transform import * from .transform import *
from . import nn from . import nn
from . import annotation
from . import image from . import image
from . import vision from . import vision
from . import op_attrs from . import op_attrs
......
# pylint: disable=wildcard-import
"""Annotation related operators."""
from __future__ import absolute_import as _abs
from .annotation import *
"""Constructor APIs"""
from ...._ffi.function import _init_api
_init_api("relay.op.annotation._make", __name__)
"""Annotation operations."""
from __future__ import absolute_import as _abs
from . import _make
from .... import nd as _nd
from .... import TVMContext as _TVMContext
def on_device(data, device):
"""Annotate an expression with a certain device type.
Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.
device : Union[:py:class:`TVMContext`, str]
The device type to annotate.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
if isinstance(device, _TVMContext):
device = device.device_type
elif isinstance(device, str):
device = _nd.context(device).device_type
else:
raise ValueError("device is expected to be the type of TVMContext or "
"str, but received %s" % (type(device)))
return _make.on_device(data, device)
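A usage sketch for the wrapper above; the GPU context is an illustrative choice and could equally be given as the string "cuda".

import tvm
from tvm import relay

x = relay.var("x", shape=(3,))
y = relay.var("y", shape=(3,))
add = relay.add(x, y)
# Mark `add` so the annotation passes schedule it on the GPU context.
gpu_add = relay.annotation.on_device(add, tvm.gpu(0))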
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from . import _make from . import _make
from ..expr import Tuple from ..expr import Tuple
from ... import nd as _nd
from ... import TVMContext as _TVMContext
# We create a wrapper function for each operator in the # We create a wrapper function for each operator in the
# python side to call into the positional _make.OpName function. # python side to call into the positional _make.OpName function.
...@@ -616,3 +618,42 @@ def copy(data): ...@@ -616,3 +618,42 @@ def copy(data):
The copied result. The copied result.
""" """
return _make.copy(data) return _make.copy(data)
def device_copy(data, src_dev, dst_dev):
"""Copy data from the source device to the destination device. This
operator helps data transferring between difference contexts for
heterogeneous execution.
Parameters
----------
data : tvm.relay.Expr
The tensor to be copied.
src_dev : Union[:py:class:`TVMContext`, str]
The source device where the data is copied from.
dst_dev : Union[:py:class:`TVMContext`, str]
The destination device where the data is copied to.
Returns
-------
result : tvm.relay.Expr
The copied result.
"""
if isinstance(src_dev, _TVMContext):
src_dev = src_dev.device_type
elif isinstance(src_dev, str):
src_dev = _nd.context(src_dev).device_type
else:
raise ValueError("src_dev is expected to be the type of TVMContext or "
"str, but received %s" % (type(src_dev)))
if isinstance(dst_dev, _TVMContext):
dst_dev = dst_dev.device_type
elif isinstance(dst_dev, str):
dst_dev = _nd.context(dst_dev).device_type
else:
raise ValueError("dst_dev is expected to be the type of TVMContext or "
"str, but received %s" % (type(dst_dev)))
return _make.device_copy(data, src_dev, dst_dev)
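A usage sketch for `device_copy`; in practice the RewriteAnnotation pass inserts these calls automatically, and the contexts below are illustrative.

import tvm
from tvm import relay

data = relay.var("data", shape=(4, 4))
# Explicitly move `data` from the CPU context to the GPU context.
copied = relay.device_copy(data, tvm.cpu(0), tvm.gpu(0))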
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
...@@ -82,11 +83,15 @@ class ScheduleGetter : ...@@ -82,11 +83,15 @@ class ScheduleGetter :
cache_node->func_name = readable_name_stream_.str(); cache_node->func_name = readable_name_stream_.str();
CachedFunc cfunc(cache_node); CachedFunc cfunc(cache_node);
CHECK(master_op_.defined()); CHECK(master_op_.defined());
Schedule schedule = fschedule[master_op_]( Schedule schedule;
master_attrs_, cache_node->outputs, target_); // No need to register schedule for device copy op.
if (master_attrs_.as<DeviceCopyAttrs>() == nullptr) {
schedule =
fschedule[master_op_](master_attrs_, cache_node->outputs, target_);
for (const auto& scalar : scalars_) { for (const auto& scalar : scalars_) {
schedule[scalar].compute_inline(); schedule[scalar].compute_inline();
} }
}
return std::make_pair(schedule, cfunc); return std::make_pair(schedule, cfunc);
} }
...@@ -153,11 +158,18 @@ class ScheduleGetter : ...@@ -153,11 +158,18 @@ class ScheduleGetter :
CHECK(call_node->op.as<OpNode>()) CHECK(call_node->op.as<OpNode>())
<< "Primitive function only allows call into primitive ops"; << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op); Op op = Downcast<Op>(call_node->op);
Array<Tensor> outputs = fcompute[op]( // Check if the op is a device copy op.
call_node->attrs, bool is_copy_op = op.same_as(Op::Get("device_copy"));
inputs, Array<Tensor> outputs;
call_node->checked_type(), // Skip fcompute for device copy operators as it is not registered.
target_); if (is_copy_op) {
const auto* copy_input = inputs[0].operator->();
outputs.push_back(TensorNode::make(copy_input->shape, copy_input->dtype,
Operation(), 0));
} else {
outputs = fcompute[op](call_node->attrs, inputs,
call_node->checked_type(), target_);
}
int op_pattern = fpattern[op]; int op_pattern = fpattern[op];
if (op_pattern >= kCommReduce) { if (op_pattern >= kCommReduce) {
...@@ -176,7 +188,14 @@ class ScheduleGetter : ...@@ -176,7 +188,14 @@ class ScheduleGetter :
CHECK(tuple_type) << "Expect output to be a tuple type"; CHECK(tuple_type) << "Expect output to be a tuple type";
CHECK_EQ(tuple_type->fields.size(), outputs.size()); CHECK_EQ(tuple_type->fields.size(), outputs.size());
} }
// Set the name to `__copy`. It will be detected by the graph runtime to
// perform cross-device data copy.
if (is_copy_op) {
readable_name_stream_.str(std::string());
readable_name_stream_ << "__copy";
} else {
readable_name_stream_ << '_' << op->name; readable_name_stream_ << '_' << op->name;
}
return outputs; return outputs;
} }
...@@ -291,6 +310,16 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -291,6 +310,16 @@ class CompileEngineImpl : public CompileEngineNode {
auto spair = CreateSchedule(key->source_func, key->target); auto spair = CreateSchedule(key->source_func, key->target);
auto cache_node = make_node<CachedFuncNode>( auto cache_node = make_node<CachedFuncNode>(
*(spair.second.operator->())); *(spair.second.operator->()));
// Skip lowering for device copy node.
const Expr body = (key->source_func)->body;
if (const CallNode* call_node = body.as<CallNode>()) {
if (call_node->attrs.as<DeviceCopyAttrs>()) {
value->cached_func = CachedFunc(cache_node);
return value;
}
}
cache_node->func_name = GetUniqueName(cache_node->func_name); cache_node->func_name = GetUniqueName(cache_node->func_name);
// NOTE: array will copy on write. // NOTE: array will copy on write.
Array<Tensor> all_args = cache_node->inputs; Array<Tensor> all_args = cache_node->inputs;
......
...@@ -6,11 +6,14 @@ ...@@ -6,11 +6,14 @@
*/ */
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h>
#include "../../common/arena.h" #include "../../common/arena.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
using IntegerArray = Array<Integer>;
struct StorageToken { struct StorageToken {
/*! \brief Reference counter */ /*! \brief Reference counter */
int ref_counter{0}; int ref_counter{0};
...@@ -18,8 +21,9 @@ struct StorageToken { ...@@ -18,8 +21,9 @@ struct StorageToken {
size_t max_bytes{0}; size_t max_bytes{0};
/*! \brief The corresponding tensor type node. */ /*! \brief The corresponding tensor type node. */
const TensorTypeNode* ttype{nullptr}; const TensorTypeNode* ttype{nullptr};
/*! \brief virtual device index */ /*! \brief virtual device index that corresponds to the device_type in
int device_id{0}; * DLContext. */
int device_type{0};
/*! \brief The storage id */ /*! \brief The storage id */
int64_t storage_id{-1}; int64_t storage_id{-1};
}; };
...@@ -106,33 +110,35 @@ class StorageAllocaBaseVisitor : public ExprVisitor { ...@@ -106,33 +110,35 @@ class StorageAllocaBaseVisitor : public ExprVisitor {
virtual void CreateToken(const ExprNode* op, bool can_realloc) = 0; virtual void CreateToken(const ExprNode* op, bool can_realloc) = 0;
}; };
class StorageAllocaInit : protected StorageAllocaBaseVisitor { class StorageAllocaInit : protected StorageAllocaBaseVisitor {
public: public:
explicit StorageAllocaInit(common::Arena* arena) explicit StorageAllocaInit(common::Arena* arena)
: arena_(arena) {} : arena_(arena) {}
/*! \return The internal token map */ /*! \return The internal token map */
std::unordered_map<const ExprNode*, std::vector<StorageToken*> > std::unordered_map<const ExprNode*, std::vector<StorageToken*> >
GetInitTokenMap(const Function& func) { GetInitTokenMap(const Function& func) {
node_device_map_ = CollectDeviceInfo(func);
this->Run(func); this->Run(func);
return std::move(token_map_); return std::move(token_map_);
} }
protected: protected:
using StorageAllocaBaseVisitor::VisitExpr_; using StorageAllocaBaseVisitor::VisitExpr_;
void CreateToken(const ExprNode* op, bool can_realloc) final { void CreateToken(const ExprNode* op, bool can_realloc) final {
CHECK(!token_map_.count(op)); CHECK(!token_map_.count(op));
std::vector<StorageToken*> tokens; std::vector<StorageToken*> tokens;
int device_type = node_device_map_.count(GetRef<Expr>(op))
? node_device_map_[GetRef<Expr>(op)]->value
: 0;
if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) { if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
for (Type t : tuple_type->fields) { for (Type t : tuple_type->fields) {
const auto* ttype = t.as<TensorTypeNode>(); const auto* ttype = t.as<TensorTypeNode>();
CHECK(ttype); CHECK(ttype);
StorageToken* token = arena_->make<StorageToken>(); StorageToken* token = arena_->make<StorageToken>();
token->ttype = ttype; token->ttype = ttype;
token->device_type = device_type;
tokens.push_back(token); tokens.push_back(token);
} }
} else { } else {
...@@ -140,6 +146,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { ...@@ -140,6 +146,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor {
CHECK(ttype); CHECK(ttype);
StorageToken* token = arena_->make<StorageToken>(); StorageToken* token = arena_->make<StorageToken>();
token->ttype = ttype; token->ttype = ttype;
token->device_type = device_type;
tokens.push_back(token); tokens.push_back(token);
} }
token_map_[op] = tokens; token_map_[op] = tokens;
...@@ -159,9 +166,9 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { ...@@ -159,9 +166,9 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor {
private: private:
// allocator // allocator
common::Arena* arena_; common::Arena* arena_;
Map<Expr, Integer> node_device_map_;
}; };
class StorageAllocator : public StorageAllocaBaseVisitor { class StorageAllocator : public StorageAllocaBaseVisitor {
public: public:
/*! /*!
...@@ -176,23 +183,39 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ...@@ -176,23 +183,39 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
} }
// Run storage allocation for a function. // Run storage allocation for a function.
Map<Expr, Array<Integer> > Plan(const Function& func) { Map<Expr, Array<IntegerArray> > Plan(const Function& func) {
prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func); prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func);
this->Run(func); this->Run(func);
Map<Expr, Array<Integer> > smap; // The value of smap contains two integer arrays where the first array
// contains the planned storage ids and the second holds the device types.
Map<Expr, Array<IntegerArray> > smap;
int num_annotated_nodes = 0;
int num_nodes = 0;
for (const auto& kv : token_map_) { for (const auto& kv : token_map_) {
Array<Integer> vec; std::vector<Integer> storage_ids;
std::vector<Integer> device_types;
for (StorageToken* tok : kv.second) { for (StorageToken* tok : kv.second) {
vec.push_back(tok->storage_id); if (tok->device_type) {
num_annotated_nodes++;
} }
smap.Set(GetRef<Expr>(kv.first), vec); num_nodes++;
storage_ids.push_back(tok->storage_id);
device_types.push_back(tok->device_type);
}
smap.Set(GetRef<Expr>(kv.first), Array<IntegerArray>({storage_ids, device_types}));
}
// Either all or none of the nodes should be annotated.
if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
LOG(FATAL)
<< num_annotated_nodes << " out of " << num_nodes
<< "expressions are assigned with virtual device types. Either all "
"or none of the expressions are expected to be annotated.";
} }
return smap; return smap;
} }
protected: protected:
using StorageAllocaBaseVisitor::VisitExpr_; using StorageAllocaBaseVisitor::VisitExpr_;
// override create token by getting token as prototype requirements. // override create token by getting token as prototype requirements.
...@@ -207,6 +230,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ...@@ -207,6 +230,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
} else { } else {
// Allocate a new token, // Allocate a new token,
StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok)); StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok));
allocated_tok->device_type = tok->device_type;
// ensure it never get de-allocated. // ensure it never get de-allocated.
allocated_tok->ref_counter += 1; allocated_tok->ref_counter += 1;
tokens.push_back(allocated_tok); tokens.push_back(allocated_tok);
...@@ -282,7 +306,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ...@@ -282,7 +306,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
// search for memory blocks larger than requested // search for memory blocks larger than requested
for (auto it = mid; it != end; ++it) { for (auto it = mid; it != end; ++it) {
StorageToken *tok = it->second; StorageToken *tok = it->second;
if (tok->device_id != prototype->device_id) continue; if (tok->device_type != prototype->device_type) continue;
CHECK_EQ(tok->ref_counter, 0); CHECK_EQ(tok->ref_counter, 0);
// Use exect matching strategy // Use exect matching strategy
tok->max_bytes = std::max(size, tok->max_bytes); tok->max_bytes = std::max(size, tok->max_bytes);
...@@ -295,7 +319,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ...@@ -295,7 +319,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
for (auto it = mid; it != begin;) { for (auto it = mid; it != begin;) {
--it; --it;
StorageToken *tok = it->second; StorageToken *tok = it->second;
if (tok->device_id != prototype->device_id) continue; if (tok->device_type != prototype->device_type) continue;
CHECK_EQ(tok->ref_counter, 0); CHECK_EQ(tok->ref_counter, 0);
// Use exect matching strategy // Use exect matching strategy
tok->max_bytes = std::max(size, tok->max_bytes); tok->max_bytes = std::max(size, tok->max_bytes);
...@@ -343,13 +367,12 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ...@@ -343,13 +367,12 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
std::unordered_map<const ExprNode*, std::vector<StorageToken*> > prototype_; std::unordered_map<const ExprNode*, std::vector<StorageToken*> > prototype_;
}; };
Map<Expr, Array<IntegerArray> > GraphPlanMemory(const Function& func) {
Map<Expr, Array<Integer> > GraphPlanMemory(const Function& func) {
return StorageAllocator().Plan(func); return StorageAllocator().Plan(func);
} }
TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory") TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory")
.set_body_typed<Map<Expr, Array<Integer> >(const Function&)>(GraphPlanMemory); .set_body_typed<Map<Expr, Array<IntegerArray> >(const Function&)>(GraphPlanMemory);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/relay/op/annotation/annotation.cc
* \brief Registration of annotation operators.
*/
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include "../type_relations.h"
#include "../../pass/alter_op_layout.h"
namespace tvm {
namespace relay {
// relay.annotation.on_device
TVM_REGISTER_NODE_TYPE(OnDeviceAttrs);
TVM_REGISTER_API("relay.op.annotation._make.on_device")
.set_body_typed<Expr(Expr, int)>([](Expr data, int device_type) {
auto attrs = make_node<OnDeviceAttrs>();
attrs->device_type = device_type;
static const Op& op = Op::Get("on_device");
return CallNode::make(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("on_device")
.describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/relay/op/device_copy.cc
* \brief Crossing device data copy operator.
*
* The pattern of this operator is registered as kOpaque. Hence, it can be
* used as a "barrier" to avoid fusing operators belonging to different devices.
*/
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include "type_relations.h"
#include "../pass/alter_op_layout.h"
namespace tvm {
namespace relay {
// relay.device_copy
TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs);
TVM_REGISTER_API("relay.op._make.device_copy")
.set_body_typed<Expr(Expr, int, int)>([](Expr data, int src_dev_type,
int dst_dev_type) {
auto attrs = make_node<DeviceCopyAttrs>();
attrs->src_dev_type = src_dev_type;
attrs->dst_dev_type = dst_dev_type;
static const Op& op = Op::Get("device_copy");
return CallNode::make(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("device_copy")
.describe(R"code(
Copy data from one tensor to another. The source and destination might be
on different devices.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
*
* \file device_annotation.cc
* \brief Passes to rewrite annotated program and retrieve the device allocation
* of expression.
*
* The following passes are performed:
* 1. Validate the annotations, checking for conflicting or redundant ones.
* 2. Rewrite the annotated program and insert data copy operators.
* 3. Collect the device allocation of each expression.
*/
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h>
#include <memory>
#include <unordered_map>
#include <unordered_set>
namespace tvm {
namespace relay {
namespace {
bool IsOnDeviceNode(const ExprNode* node) {
const auto* call_node = dynamic_cast<const CallNode*>(node);
return call_node != nullptr && call_node->attrs.as<OnDeviceAttrs>();
}
bool IsDeviceCopyNode(const ExprNode* node) {
const auto* call_node = dynamic_cast<const CallNode*>(node);
return call_node != nullptr && call_node->attrs.as<DeviceCopyAttrs>();
}
} // namespace
class ValidateAnnotation : private ExprVisitor {
public:
static std::unordered_map<const ExprNode*, int> Validate(const Expr& expr) {
ValidateAnnotation valid;
valid(expr);
return valid.annotation_map_;
}
private:
void VisitExpr_(const CallNode* call_node) final {
if (IsOnDeviceNode(call_node)) {
int device_type = GetDeviceId(call_node);
if (annotation_map_.count(call_node)) {
CHECK_EQ(annotation_map_.at(call_node), device_type)
<< "An expression node can only be annotated to one device.";
} else {
annotation_map_.insert({call_node, GetDeviceId(call_node)});
}
CHECK_EQ(call_node->args.size(), 1U);
const auto* node = call_node->args[0].operator->();
if (annotation_map_.count(node)) {
CHECK_EQ(annotation_map_.at(node), device_type)
<< "An expression node can only be annotated to one device.";
} else {
annotation_map_.insert({node, GetDeviceId(call_node)});
}
}
ExprVisitor::VisitExpr_(call_node);
}
/*
* \brief Get the device type of the annotation node.
* \param call_node The on_device annotation call node.
* \return The device type.
*/
int GetDeviceId(const CallNode* call_node) {
CHECK(IsOnDeviceNode(call_node))
<< "The input call node must be on_device node.";
const OnDeviceAttrs* on_device_attr = call_node->attrs.as<OnDeviceAttrs>();
return on_device_attr->device_type;
}
std::unordered_map<const ExprNode*, int> annotation_map_;
};
// Replace the use of an expression with the output of a `copy_device` operator
// if the `on_device` operator takes the annotated expr as an input.
//
// This actually replaces annotation ops with device copy ops and connects any
// two dependent expressions with a `device_copy` op when needed. Note that the
// device type of a `device_copy` op is identical to that of the destination op
// since it is where the data should be copied to.
class RewriteAnnotation : public ExprMutator {
public:
Expr Rewrite(const Expr& expr, int fallback_device) {
fallback_device_ = fallback_device;
annotation_map_ = ValidateAnnotation::Validate(expr);
return this->VisitExpr(expr);
}
Expr VisitExpr_(const LetNode* op) final {
Expr value = GetDeviceCopyExpr(op->value, op);
Expr body = GetDeviceCopyExpr(op->body, op);
if (value.same_as(op->value) && body.same_as(op->body)) {
return ExprMutator::VisitExpr_(op);
} else {
Expr new_let = LetNode::make(op->var, value, body);
UpdateAnnotationMap(op, new_let.operator->());
return this->VisitExpr(new_let);
}
}
Expr VisitExpr_(const TupleNode* op) final {
Array<Expr> fields;
bool annotated = false;
for (const auto& field : op->fields) {
annotated |= NeedDeviceCopy(field.operator->(), op);
fields.push_back(GetDeviceCopyExpr(field, op));
}
if (annotated) {
Expr new_tuple = TupleNode::make(fields);
UpdateAnnotationMap(op, new_tuple.operator->());
return this->VisitExpr(new_tuple);
} else {
return ExprMutator::VisitExpr_(op);
}
}
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr tuple = op->tuple;
if (NeedDeviceCopy(tuple.operator->(), op)) {
Expr new_expr =
TupleGetItemNode::make(GetDeviceCopyExpr(tuple, op), op->index);
UpdateAnnotationMap(op, new_expr.operator->());
return this->VisitExpr(new_expr);
} else {
return ExprMutator::VisitExpr_(op);
}
}
Expr VisitExpr_(const IfNode* if_node) final {
Expr cond = GetDeviceCopyExpr(if_node->cond, if_node);
Expr true_br = GetDeviceCopyExpr(if_node->true_branch, if_node);
Expr false_br = GetDeviceCopyExpr(if_node->false_branch, if_node);
if (if_node->cond.same_as(cond) && if_node->true_branch.same_as(true_br) &&
if_node->false_branch.same_as(false_br)) {
return ExprMutator::VisitExpr_(if_node);
} else {
Expr new_if = IfNode::make(cond, true_br, false_br);
UpdateAnnotationMap(if_node, new_if.operator->());
return this->VisitExpr(new_if);
}
}
Expr VisitExpr_(const CallNode* call_node) final {
if (IsOnDeviceNode(call_node) || IsDeviceCopyNode(call_node)) {
return ExprMutator::VisitExpr_(call_node);
}
Array<Expr> new_args;
bool annotated = false;
for (const auto& arg : call_node->args) {
annotated |= NeedDeviceCopy(arg.operator->(), call_node);
new_args.push_back(GetDeviceCopyExpr(arg, call_node));
}
if (annotated) {
Call new_call = CallNode::make(call_node->op, new_args, call_node->attrs,
call_node->type_args);
UpdateAnnotationMap(call_node, new_call.operator->());
return this->VisitExpr(new_call);
} else {
return ExprMutator::VisitExpr_(call_node);
}
}
private:
void UpdateAnnotationMap(const ExprNode* old_node, const ExprNode* new_node) {
const auto it = annotation_map_.find(old_node);
if (it == annotation_map_.end()) {
annotation_map_.insert({new_node, fallback_device_});
} else {
annotation_map_.insert({new_node, it->second});
}
this->memo_[GetRef<Expr>(old_node)] = GetRef<Expr>(new_node);
}
Expr GetDeviceCopyExpr(const Expr& src, const ExprNode* dst) {
const auto* src_node = src.operator->();
if (!NeedDeviceCopy(src_node, dst)) return src;
const auto sit = annotation_map_.find(src_node);
if (sit == annotation_map_.end()) {
const auto dit = annotation_map_.find(dst);
CHECK(dit != annotation_map_.end())
<< "Device copy op is not required when both src and dst ops are not "
"annotated.";
return CreateDeviceCopy(src, fallback_device_, dit->second);
} else {
const auto dit = annotation_map_.find(dst);
int dst_dev_type =
dit == annotation_map_.end() ? fallback_device_ : dit->second;
return CreateDeviceCopy(src, sit->second, dst_dev_type);
}
}
// Check if a device copy op is needed between two ops.
bool NeedDeviceCopy(const ExprNode* src, const ExprNode* dst) {
if (annotation_map_.count(src)) {
int src_dev_type = annotation_map_.at(src);
if (annotation_map_.count(dst)) {
return src_dev_type != annotation_map_.at(dst);
} else {
return src_dev_type != fallback_device_;
}
} else {
if (annotation_map_.count(dst)) {
// Although a data copy op could be inserted whenever the `src` and `dst`
// ops are annotated with different devices, doing so leads to high overhead.
//
// Here we only need cross-device data transfer when `src` is a
// CallNode or FunctionNode and `dst` is annotated with any device
// type other than fallback_device_.
if (src->is_type<CallNode>() || src->is_type<FunctionNode>()) {
return annotation_map_.at(dst) != fallback_device_;
} else {
return false;
}
} else {
return false;
}
}
}
/*
* \brief Create an operator to copy data from the source device to the
* destination device.
* \param src The source expression that produces data to be copied.
* \param src_dev_type The device type where the data is copied from.
* \param dst_dev_type The device type where the data is copied to.
* \return The created call node.
*/
Call CreateDeviceCopy(const Expr& src, int src_dev_type, int dst_dev_type) {
auto attrs = make_node<DeviceCopyAttrs>();
attrs->src_dev_type = src_dev_type;
attrs->dst_dev_type = dst_dev_type;
static const Op& op = Op::Get("device_copy");
Call device_copy = CallNode::make(op, {src}, Attrs(attrs), {});
annotation_map_.insert({device_copy.operator->(), dst_dev_type});
return device_copy;
}
std::unordered_map<const ExprNode*, int> annotation_map_;
int fallback_device_;
};
// Get all annotation expressions.
class AnnotationVisitor : private ExprVisitor {
public:
static Map<Expr, Integer> GetAnnotations(const Expr& expr) {
AnnotationVisitor visitor;
visitor(expr);
return visitor.annotations_;
}
private:
void VisitExpr_(const CallNode* call_node) {
if (IsOnDeviceNode(call_node)) {
const auto* attr = call_node->attrs.as<OnDeviceAttrs>();
annotations_.Set(GetRef<Expr>(call_node), attr->device_type);
}
ExprVisitor::VisitExpr_(call_node);
}
Map<Expr, Integer> annotations_;
};
/*
* \brief Return device allocation map based on the post order traversed graph.
* For the following program:
* .. code-block:: python
* x = relay.var("x")
* y = relay.var("y")
* add = relay.add(x, y)
* sqrt = relay.sqrt(add)
* log = relay.log(add)
* subtract = relay.subtract(sqrt, log)
* exp = relay.exp(subtract)
*
* Suppose we have annotated add, sqrt, and log with device 1, 2, and 3,
* respectively. The fallback/default device is 4. After Rewriting the
* program, we can have the following graph, where each copy op has both
* source and destination device type denoting which device the data should be
* copied from and to.
*
* x y
* \ /
* add/1
* / \
* copy1 copy2
* | |
* sqrt/2 log/3
* | |
* copy3 copy4
* \ /
* subtract
* |
* exp
*
* To get the device mapping of each expression, we need to propagate the
* device information from the copy ops. This can be done in two passes.
* - Pass 1: Propagate the source device type of each copy op bottom-up to its
* ancestors until another copy op is encountered. For example, this pass
* gives add, x, and y their device types from the copy operator `copy1`.
* - Pass 2: Propagate the destination device type of "the last" copy op
* top-down to the nodes on the output paths. For instance, this pass gives
* `subtract` and `exp` the same device type as `copy3`.
*/
class DeviceInfo {
public:
static Map<Expr, Integer> GetDeviceMap(const Expr& expr) {
DeviceInfo device_info;
device_info.post_visitor_ = PostDfsOrderVisitor();
device_info.post_visitor_.Visit(expr);
if (device_info.post_visitor_.num_device_copy_ops_ > 0) {
device_info.PropagateDeviceId();
return device_info.device_map_;
} else {
return Map<Expr, Integer>();
}
}
private:
class PostDfsOrderVisitor : private ExprVisitor {
public:
void Visit(const Expr& expr) { this->VisitExpr(expr); }
private:
// Post order traversal.
void VisitExpr_(const FunctionNode* fn) final {
ExprVisitor::VisitExpr_(fn);
// TODO(zhiics) Skip annotation of function node for now.
}
void VisitExpr_(const ConstantNode* cn) final {
post_dfs_order_.push_back(cn);
}
void VisitExpr_(const CallNode* call) final {
// Skip annotation nodes.
if (!IsOnDeviceNode(call)) {
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(call);
if (IsDeviceCopyNode(call)) {
num_device_copy_ops_++;
}
}
}
void VisitExpr_(const TupleNode* tn) final {
ExprVisitor::VisitExpr_(tn);
// TODO(zhiics) Skip annotation of tuple node for now.
}
void VisitExpr_(const TupleGetItemNode* op) final {
ExprVisitor::VisitExpr_(op);
post_dfs_order_.push_back(op);
}
void VisitExpr_(const VarNode* vn) final { post_dfs_order_.push_back(vn); }
void VisitExpr_(const LetNode* ln) final {
ExprVisitor::VisitExpr_(ln);
post_dfs_order_.push_back(ln);
}
void VisitExpr_(const IfNode* in) final {
ExprVisitor::VisitExpr_(in);
post_dfs_order_.push_back(in);
}
int num_device_copy_ops_{0};
std::vector<const ExprNode*> post_dfs_order_;
friend DeviceInfo;
};
void PropagateDeviceId() {
// Bottom-up propagation.
BottomUpPropagation();
// Top-down propagation.
TopDownPropagation();
}
void BottomUpPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (IsDeviceCopyNode(*it)) {
last_copy_node = dynamic_cast<const CallNode*>(*it);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type;
device_map_.Set(GetRef<Expr>(last_copy_node), attrs->dst_dev_type);
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(*it);
CHECK_EQ(device_map_.count(expr), 0U);
device_map_.Set(expr, cur_dev_type);
}
}
}
void TopDownPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
for (const auto& it : post_visitor_.post_dfs_order_) {
if (IsDeviceCopyNode(it)) {
last_copy_node = dynamic_cast<const CallNode*>(it);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->dst_dev_type;
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(it);
if (device_map_.count(expr) == 0) {
device_map_.Set(expr, cur_dev_type);
}
}
}
}
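  // Visitor used to collect the post-DFS order of the expression nodes.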
PostDfsOrderVisitor post_visitor_;
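  // Resulting map from each expression to its device type.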
Map<Expr, Integer> device_map_;
};
Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
RewriteAnnotation rewrote = RewriteAnnotation();
return rewrote.Rewrite(expr, fallback_device);
}
Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) {
return DeviceInfo::GetDeviceMap(expr);
}
Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
  return AnnotationVisitor::GetAnnotations(expr);
}
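// Expose the passes to the Python frontend via the FFI; for example, the
// rewrite pass is invoked as relay.ir_pass.rewrite_annotated_ops in the
// tests below.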
TVM_REGISTER_API("relay._ir_pass.CollectDeviceInfo")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CollectDeviceInfo(args[0]);
});
TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = RewriteAnnotatedOps(args[0], args[1]);
});
TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CollectDeviceAnnotationOps(args[0]);
});
} // namespace relay
} // namespace tvm
...@@ -112,14 +112,19 @@ def test_plan_memory(): ...@@ -112,14 +112,19 @@ def test_plan_memory():
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
smap = relay.backend._backend.GraphPlanMemory(func) smap = relay.backend._backend.GraphPlanMemory(func)
storage_ids = set() storage_ids = set()
device_types = set()
for k, v in smap.items(): for k, v in smap.items():
for x in v: assert len(v) == 2
for x in v[0]:
storage_ids.add(x.value) storage_ids.add(x.value)
for x in v[1]:
device_types.add(x.value)
# Current rule requires vars have unique storage id # Current rule requires vars have unique storage id
# because we don't do inplace, we will need another # because we don't do inplace, we will need another
# two alternating temporary space. # two alternating temporary space.
assert len(storage_ids) == 4 assert len(storage_ids) == 4
assert len(device_types) == 1
if __name__ == "__main__": if __name__ == "__main__":
......
"""Unit tests for heterogeneous compilation and execution."""
import numpy as np
import tvm
from tvm import relay
from tvm.relay import testing
from tvm.contrib import graph_runtime
def test_redundant_annotation():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
x = relay.var("x", shape=(3,))
y = relay.var("y", shape=(3,))
z = relay.var("z", shape=(3,))
def annotated():
add = relay.add(x, y)
_add1 = relay.annotation.on_device(add, ctx2)
_add2 = relay.annotation.on_device(add, ctx2)
sub = relay.subtract(add, z)
func = relay.Function([x, y, z],
relay.Tuple(tvm.convert([_add1, _add2,
sub])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[2]),
func.body[2])
def expected():
add = relay.add(x, y)
copy_add_sub = relay.device_copy(add, ctx2, ctx1)
sub = relay.subtract(copy_add_sub, z)
func = relay.Function([x, y, z], sub)
return func
annotated_func = relay.ir_pass.infer_type(annotated())
expected_func = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
def test_annotate_all():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
x = relay.var("x", shape=(3,))
y = relay.var("y", shape=(3,))
z = relay.var("z", shape=(3,))
def annotated():
add = relay.add(x, y)
_add = relay.annotation.on_device(add, ctx2)
sub = relay.subtract(add, z)
_sub = relay.annotation.on_device(sub, ctx2)
func = relay.Function([x, y, z],
relay.Tuple(tvm.convert([_add, _sub,
sub])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[2]),
func.body[2])
def expected():
add = relay.add(x, y)
sub = relay.subtract(add, z)
func = relay.Function([x, y, z], sub)
return func
annotated_func = relay.ir_pass.infer_type(annotated())
    expected_func = relay.ir_pass.infer_type(expected())
    assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
def test_annotate_none():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
x = relay.var("x", shape=(3,))
y = relay.var("y", shape=(3,))
z = relay.var("z", shape=(3,))
def annotated():
add = relay.add(x, y)
sub = relay.subtract(add, z)
func = relay.Function([x, y, z], sub)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
return func
def expected():
add = relay.add(x, y)
sub = relay.subtract(add, z)
func = relay.Function([x, y, z], sub)
return func
annotated_func = relay.ir_pass.infer_type(annotated())
expected_func = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
def check_annotated_graph(annotated_func, expected_func):
annotated_func = relay.ir_pass.infer_type(annotated_func)
expected_func = relay.ir_pass.infer_type(expected_func)
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
def test_conv_network():
R""" The network is as following:
data1 data2
| |
conv2d conv2d
\ /
add
|
conv2d
"""
batch_size = 1
dshape = (batch_size, 64, 56, 56)
weight = relay.var("weight", shape=(64, 64, 3, 3))
data1 = relay.var("data1", shape=dshape)
data2 = relay.var("data2", shape=dshape)
dev1 = tvm.context(1)
dev2 = tvm.context(2)
def annotated():
conv2d_1 = relay.nn.conv2d(
data1,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
_conv2d_1 = relay.annotation.on_device(conv2d_1, dev2)
conv2d_2 = relay.nn.conv2d(
data2,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
_conv2d_2 = relay.annotation.on_device(conv2d_2, dev2)
add = relay.add(conv2d_1, conv2d_2)
_add = relay.annotation.on_device(add, dev1)
conv2d_3 = relay.nn.conv2d(
add,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
_conv2d_3 = relay.annotation.on_device(conv2d_3, dev2)
func = relay.Function([data1, data2, weight],
relay.Tuple(tvm.convert([_conv2d_1, _conv2d_2,
_conv2d_3, _add,
conv2d_3])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
tvm.context(3).device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[4]),
func.body[4])
def expected():
conv2d_1 = relay.nn.conv2d(
data1,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
device_copy1 = relay.device_copy(conv2d_1, dev2, dev1)
conv2d_2 = relay.nn.conv2d(
data2,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
device_copy2 = relay.device_copy(conv2d_2, dev2, dev1)
add = relay.add(device_copy1, device_copy2)
device_copy3 = relay.device_copy(add, dev1, dev2)
conv2d_3 = relay.nn.conv2d(
device_copy3,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
func = relay.Function([data1, weight, data2], conv2d_3)
return func
def check_storage_and_device_types():
func = annotated()
func = relay.ir_pass.rewrite_annotated_ops(func, 3)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.fuse_ops(func, opt_level=2)
func = relay.ir_pass.infer_type(func)
smap = relay.backend._backend.GraphPlanMemory(func)
storage_ids = []
device_types = []
for _, storage_dev_type in smap.items():
assert len(storage_dev_type) == 2
for sid in storage_dev_type[0]:
storage_ids.append(sid.value)
for did in storage_dev_type[1]:
device_types.append(did.value)
assert len(storage_ids) == 10
assert len(set(storage_ids)) == 7
assert len(set(device_types)) == 2
assert set(device_types) == {1, 2}
annotated_func = annotated()
expected_func = expected()
check_annotated_graph(annotated_func, expected_func)
check_storage_and_device_types()
def test_fusible_network():
R""" The network is as following:
x y
\ /
add
/ \
sqrt log
\ /
subtract
|
exp
"""
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(10, 10))
x_data = np.random.rand(1, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32')
tmp_add = x_data + y_data
tmp_sqrt = np.sqrt(tmp_add)
tmp_log = np.log(tmp_add)
tmp_sub = np.subtract(tmp_sqrt, tmp_log)
ref_res = np.exp(tmp_sub)
def get_func():
add = relay.add(x, y)
sqrt = relay.sqrt(add)
log = relay.log(add)
subtract = relay.subtract(sqrt, log)
exp = relay.exp(subtract)
func = relay.Function([x, y], exp)
return func
def test_runtime(target, device, func, fallback_device=None):
params = {"x": x_data, "y": y_data}
with relay.build_config(opt_level=1):
graph, lib, params = relay.build(
func,
target,
params=params,
fallback_device=fallback_device)
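        # Heterogeneous execution: pass one context per device type used by the graph.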
contexts = [tvm.cpu(0), tvm.context(device)]
mod = graph_runtime.create(graph, lib, contexts)
mod.set_input(**params)
mod.run()
res = mod.get_output(0).asnumpy()
tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5)
def test_fuse_log_add(device, tgt):
""" Only log and add are fused."""
fallback_device = tvm.context("cpu")
target = {"cpu": "llvm", device: tgt}
cpu_ctx = fallback_device
dev_ctx = tvm.context(device)
def annotated():
add = relay.add(x, y)
sqrt = relay.sqrt(add)
_sqrt = relay.annotation.on_device(sqrt, dev_ctx)
log = relay.log(add)
subtract = relay.subtract(sqrt, log)
exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, dev_ctx)
func = relay.Function([x, y],
relay.Tuple(tvm.convert([_sqrt, _exp, exp])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[2]),
func.body[2])
def expected():
add = relay.add(x, y)
copy_add_sqrt = relay.device_copy(add, cpu_ctx, dev_ctx)
sqrt = relay.sqrt(copy_add_sqrt)
log = relay.log(add)
copy_sqrt_subtract = relay.device_copy(sqrt, dev_ctx, cpu_ctx)
subtract = relay.subtract(copy_sqrt_subtract, log)
copy_sub_exp = relay.device_copy(subtract, cpu_ctx, dev_ctx)
exp = relay.exp(copy_sub_exp)
func = relay.Function([x, y], exp)
return func
annotated_func = annotated()
expected_func = expected()
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func, fallback_device)
def test_fuse_all(device, tgt):
"""Fuse all operators."""
fallback_device = tvm.context("cpu")
target = {"cpu": "llvm", device: tgt}
cpu_ctx = fallback_device
dev_ctx = tvm.context(device)
def annotated():
add = relay.add(x, y)
_add = relay.annotation.on_device(add, dev_ctx)
sqrt = relay.sqrt(add)
_sqrt = relay.annotation.on_device(sqrt, dev_ctx)
log = relay.log(add)
_log = relay.annotation.on_device(log, dev_ctx)
subtract = relay.subtract(sqrt, log)
_subtract = relay.annotation.on_device(subtract, dev_ctx)
exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, dev_ctx)
func = relay.Function([x, y],
relay.Tuple(tvm.convert([_add, _sqrt, _log,
_subtract, _exp,
exp])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[5]),
func.body[5])
annotated_func = annotated()
expected_func = get_func()
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func, fallback_device)
def test_fallback_exp(device, tgt):
fallback_device = tvm.context("cpu")
target = {"cpu": "llvm", device: tgt}
cpu_ctx = fallback_device
def annotated():
add = relay.add(x, y)
sqrt = relay.sqrt(add)
log = relay.log(add)
subtract = relay.subtract(sqrt, log)
exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, cpu_ctx)
func = relay.Function([x, y],
relay.Tuple(tvm.convert([_exp, exp])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[1]),
func.body[1])
annotated_func = annotated()
expected_func = get_func()
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func, fallback_device)
def test_fallback_all_operators(device, tgt):
target = {"cpu": "llvm", device: tgt}
fallback_device = tvm.cpu(0)
annotated_func = get_func()
expected_func = get_func()
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func, fallback_device)
for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"),
("opencl", str(tvm.target.intel_graphics()))]:
if not tvm.module.enabled(dev):
print("Skip test because %s is not enabled." % dev)
continue
test_fuse_log_add(dev, tgt)
test_fuse_all(dev, tgt)
test_fallback_exp(dev, tgt)
test_fallback_all_operators(dev, tgt)
if __name__ == "__main__":
test_redundant_annotation()
test_annotate_all()
test_annotate_none()
test_conv_network()
test_fusible_network()