Commit 09236bf6 by Zhi Committed by Haichen Shen

[relay][pass] Annotation for heterogeneous compilation (#2361)

parent 30a5a600
......@@ -148,6 +148,8 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.broadcast_to_like
tvm.relay.collapse_sum_like
tvm.relay.slice_like
tvm.relay.device_copy
tvm.relay.annotation.on_device
Level 1 Definitions
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/annotation.h
* \brief Attribute for annotation operators.
*/
#ifndef TVM_RELAY_ATTRS_ANNOTATION_H_
#define TVM_RELAY_ATTRS_ANNOTATION_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Options for the device annotation operators.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
int device_type;
TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(device_type)
.describe(
"The virutal device/context type that an expression is annotated with.")
.set_default(0);
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/device_copy.h
* \brief Attribute for the device copy operator.
*/
#ifndef TVM_RELAY_ATTRS_DEVICE_COPY_H_
#define TVM_RELAY_ATTRS_DEVICE_COPY_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Options for the device copy operators.
*/
struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
int dst_dev_type;
int src_dev_type;
TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") {
TVM_ATTR_FIELD(src_dev_type)
.describe(
"The virutal device/context type where the op copies data from.")
.set_default(0);
TVM_ATTR_FIELD(dst_dev_type)
.describe(
"The virutal device/context type where the op copies data to.")
.set_default(0);
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_DEVICE_COPY_H_
......@@ -164,7 +164,6 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
}
};
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
std::string src_layout;
std::string dst_layout;
......
......@@ -188,6 +188,21 @@ Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*!
* \brief Rewrite the annotated program.
* \param expr The expression.
* \param fallback_device The fallback device which is the default device for
* operators without annotation.
* \return The updated program.
*/
Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
/*!
* \brief Collect the device mapping information of each expression.
* \param expr The expression.
* \return The device mapping.
*/
Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
......
......@@ -25,6 +25,7 @@ from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev
from ._ffi.runtime_ctypes import TypeCode
from ._ffi.ndarray import TVMContext
from ._ffi.function import Function
from ._ffi.base import TVMError, __version__
from .api import *
......
......@@ -18,6 +18,7 @@ from .op.reduce import *
from .op.tensor import *
from .op.transform import *
from . import nn
from . import annotation
from . import vision
from . import image
from . import frontend
......
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Annotation related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.annotation import *
......@@ -52,8 +52,9 @@ def build(funcs, target, target_host=None):
Parameters
----------
funcs : List[tvm.LoweredFunc]
The list of lowered functions.
funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
A list of lowered functions or dictionary mapping from targets to
lowered functions.
target : tvm.Target
......
......@@ -20,6 +20,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system.
from __future__ import absolute_import
import json
from collections import defaultdict
import attr
from . import _backend
from . import compile_engine
......@@ -27,6 +28,7 @@ from ..op import Op
from ..expr import Function, GlobalVar
from ..expr_functor import ExprFunctor
from ..ty import TupleType, TensorType
from ... import target as _target
@attr.s
......@@ -105,9 +107,9 @@ class GraphRuntimeCodegen(ExprFunctor):
self.nodes = []
self.var_map = {}
self.params = {}
self.storage_map = None
self.storage_device_map = None
self.compile_engine = compile_engine.get()
self.lowered_funcs = set()
self.lowered_funcs = defaultdict(set)
self._name_map = {}
def add_node(self, node, expr):
......@@ -129,10 +131,20 @@ class GraphRuntimeCodegen(ExprFunctor):
"""
checked_type = expr.checked_type
# setup storage ids
assert expr in self.storage_map
node.attrs["storage_id"] = [
x.value for x in self.storage_map[expr]
]
assert expr in self.storage_device_map
storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
node.attrs["storage_id"] = [x.value for x in storage_device_info[0]]
device_types = [x.value for x in storage_device_info[1]]
num_unknown_devices = device_types.count(0)
if num_unknown_devices != 0 and num_unknown_devices != len(device_types):
raise RuntimeError("The graph contains not annotated nodes for "
"heterogeneous execution. All nodes must be "
"annotated.")
# Add the `device_index` attribute when the graph is annotated.
if num_unknown_devices == 0:
node.attrs["device_index"] = device_types
node_id = len(self.nodes)
self.nodes.append(node)
......@@ -232,9 +244,25 @@ class GraphRuntimeCodegen(ExprFunctor):
"TVM only support calls to primitive functions " +
"(i.e functions composed of fusable operator invocations)")
assert call in self.storage_device_map
device_types = self.storage_device_map[call][1]
call_dev_type = device_types[0].value
if isinstance(self.target, (str, _target.Target)):
# homogeneous execution.
cached_func = self.compile_engine.lower(func, self.target)
self.target = {0: str(self.target)}
elif isinstance(self.target, dict):
# heterogeneous execution.
if call_dev_type not in self.target:
raise Exception("No target is provided for device " +
"{0}".format(call_dev_type))
cached_func = self.compile_engine.lower(func,
self.target[call_dev_type])
else:
raise ValueError("self.target must be the type of str," +
"tvm.target.Target, or dict of int to str")
for loweredf in cached_func.funcs:
self.lowered_funcs.add(loweredf)
self.lowered_funcs[self.target[call_dev_type]].add(loweredf)
inputs = []
# flatten tuple in the call.
......@@ -284,6 +312,7 @@ class GraphRuntimeCodegen(ExprFunctor):
num_entry = 0
shapes = []
storage_ids = []
device_types = []
dltypes = []
node_row_ptr = [0]
for node in self.nodes:
......@@ -291,6 +320,8 @@ class GraphRuntimeCodegen(ExprFunctor):
shapes += node.attrs["shape"]
dltypes += node.attrs["dtype"]
storage_ids += node.attrs["storage_id"]
if "device_index" in node.attrs:
device_types += node.attrs["device_index"]
num_entry += node.num_outputs
node_row_ptr.append(num_entry)
......@@ -298,6 +329,8 @@ class GraphRuntimeCodegen(ExprFunctor):
attrs = {}
attrs["shape"] = ["list_shape", shapes]
attrs["storage_id"] = ["list_int", storage_ids]
if device_types:
attrs["device_index"] = ["list_int", device_types]
attrs["dltype"] = ["list_str", dltypes]
json_dict = {
......@@ -313,11 +346,24 @@ class GraphRuntimeCodegen(ExprFunctor):
def debug_dump_memory_plan(self, func):
"""Debug function to dump memory plan."""
def _annotate(expr):
if expr in self.storage_map:
return str(self.storage_map[expr])
if expr in self.storage_device_map:
storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
return str(storage_device_info[0])
return ""
return func.astext(show_meta_data=False, annotate=_annotate)
def debug_dump_device_annotation(self, func):
"""Debug function to dump device annotation result."""
def _annotate(expr):
if expr in self.storage_device_map:
storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
return str(storage_device_info[1])
return ""
return func.astext(show_meta_data=False, annotate=_annotate)
def codegen(self, func):
"""Compile a single function into a graph.
......@@ -331,24 +377,31 @@ class GraphRuntimeCodegen(ExprFunctor):
graph_json : str
The graph json that can be consumed by runtime.
lowered_funcs : List[tvm.LoweredFunc]
lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
The lowered functions.
params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
"""
self.storage_map = _backend.GraphPlanMemory(func)
self.storage_device_map = _backend.GraphPlanMemory(func)
# First we convert all the parameters into input nodes.
for param in func.params:
node = InputNode(param.name_hint, {})
self.var_map[param] = self.add_node(
node, param)
self.var_map[param] = self.add_node(node, param)
# Then we compile the body into a graph which can depend
# on input variables.
self.heads = self.visit(func.body)
graph_json = self._get_json()
lowered_funcs = list(self.lowered_funcs)
# Return the lowered functions as a list for homogeneous compilation.
# Otherwise, for heterogeneous compilation, a dictionary containing
# the device id to a list of lowered functions is returned. Both forms
# are acceptable to tvm.build.
if not isinstance(self.target, dict):
lowered_funcs = list(list(self.lowered_funcs.values())[0])
else:
lowered_funcs = {k: list(v) for k, v in self.lowered_funcs.items()}
return graph_json, lowered_funcs, self.params
def _get_unique_name(self, name):
......
......@@ -2,6 +2,9 @@
Construct the necessary state for the TVM graph runtime
from a Relay expression.
"""
import warnings
from tvm._ffi.runtime_ctypes import TVMContext
from ..build_module import build as _tvm_build_module
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
......@@ -20,6 +23,7 @@ OPT_PASS_LEVEL = {
"AlterOpLayout": 3,
}
class BuildConfig(object):
"""Configuration scope to set a build config option.
......@@ -33,12 +37,13 @@ class BuildConfig(object):
"opt_level": 2,
"add_pass": None,
}
def __init__(self, **kwargs):
self._old_scope = None
for k, _ in kwargs.items():
if k not in BuildConfig.defaults:
raise ValueError(
"invalid argument %s, candidates are %s" % (k, BuildConfig.defaults.keys()))
raise ValueError("invalid argument %s, candidates are %s" %
(k, BuildConfig.defaults.keys()))
self._attr = kwargs
def __getattr__(self, name):
......@@ -127,8 +132,10 @@ def optimize(func, target, params=None):
func : tvm.relay.Function
The input to optimization.
target: :any:`tvm.target.Target`
The optimization target. Some optimization passes are target specific.
target : Optional[:any:`tvm.target.Target`, Dict[int, tvm.target.Target]]
The optimization target. For heterogeneous compilation, it is a
dictionary mapping device type to compilation target. For homogeneous
compilation, it is a build target.
params : Optional[Dict[str, tvm.nd.NDArray]]
Input parameters to the graph that do not change
......@@ -165,12 +172,19 @@ def optimize(func, target, params=None):
func = ir_pass.forward_fold_scale_axis(func)
func = ir_pass.fold_constant(func)
# FIXME(zhiics) Skip AlterOpLayout pass for heterogeneous compilation for
# now. We probably need to pass target to this pass as well. Fix it in
# a followup PR.
if cfg.pass_enabled("AlterOpLayout"):
if isinstance(target, _target.Target):
func = ir_pass.infer_type(func)
func = ir_pass.canonicalize_ops(func)
func = ir_pass.infer_type(func)
with target:
func = ir_pass.alter_op_layout(func)
elif isinstance(target, dict):
warnings.warn("AlterOpLayout pass is not enabled for heterogeneous"
" execution yet.")
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)
......@@ -178,10 +192,8 @@ def optimize(func, target, params=None):
return func
def build(func,
target=None,
target_host=None,
params=None):
def build(func, target=None, target_host=None, params=None,
fallback_device=None):
"""Build a function to run on TVM graph runtime.
Parameters
......@@ -189,10 +201,12 @@ def build(func,
func: relay.Function
The function to build.
target : str or :any:`tvm.target.Target`, optional
The build target
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context to
target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target` optional
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
......@@ -205,6 +219,10 @@ def build(func,
Input parameters to the graph that do not change
during inference time. Used for constant folding.
fallback_device : str or tvm.TVMContext, optional.
The fallback device. It is also used as the default device for
operators with no specified device.
Returns
-------
graph_json : str
......@@ -219,11 +237,22 @@ def build(func,
target = target if target else _target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
if isinstance(target, dict):
target, fallback_device = \
_update_heterogeneous_inputs(target, fallback_device)
elif isinstance(target, (str, _target.Target)):
target = _target.create(target)
else:
raise ValueError("target must be the type of str, tvm.target.Target," +
"or dict of device name to target")
# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
if isinstance(target, dict):
tophub_context = autotvm.tophub.context(list(target.values()))
else:
tophub_context = autotvm.tophub.context(target)
else:
tophub_context = autotvm.util.EmptyContext()
......@@ -232,6 +261,10 @@ def build(func,
with tophub_context:
func = optimize(func, target, params)
# Annotate the ops for heterogeneous execution.
if isinstance(target, dict):
func, target = _run_device_annotation_passes(func, target,
fallback_device)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, cfg.opt_level)
......@@ -239,10 +272,112 @@ def build(func,
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
graph_json, lowered_funcs, params = graph_gen.codegen(func)
mod = _tvm_build_module(lowered_funcs, target=target, target_host=target_host)
mod = _tvm_build_module(
lowered_funcs, target=target, target_host=target_host)
return graph_json, mod, params
def _update_heterogeneous_inputs(target, fallback_device=None):
"""Update the target and fallback device required for heterogeneous
compilation. CPU is used as the fallback device if it wasn't provided.
Meanwhile, a CPU device type and "llvm" pair will be added to the target
dictionary in this case.
Parameters
----------
target : dict of str(i.e. device/context name) to str/tvm.target.Target.
A dict contains context to target pairs.
fallback_device : str or tvm.TVMContext, optional.
The fallback device. It is also used as the default device for
operators with no specified device.
Returns
-------
device_target : dict of int to tvm.target.Target.
The updated device type to target dict.
fallback_device : int
The updated fallback device type.
"""
if not isinstance(target, dict):
raise ValueError("target must be dict of device name to target for " +
"heterogeneous execution, but received %s."
% type(target))
if fallback_device is None:
# cpu is used as the default fallback device when heterogeneous
# execution is needed, but no fallback device is provided.
fallback_device = _nd.cpu(0).device_type
target[fallback_device] = str(_target.create("llvm"))
elif isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, TVMContext):
fallback_device = fallback_device.device_type
else:
raise ValueError("fallback_device expects the type of str or" +
"TVMContext, but received %s." % type(fallback_device))
device_target = {}
for dev, tgt in target.items():
device_target[_nd.context(dev).device_type] = _target.create(tgt)
if fallback_device not in device_target:
raise ValueError("%s is used as the default device, but the target" +
"is not provided."
% _nd.context(fallback_device).device_name)
return device_target, fallback_device
def _run_device_annotation_passes(func, target, fallback_device):
"""Execute the device annotation passes to update the input program and
target information.
Parameters
----------
func: tvm.relay.Function
The function where annotation passes will be execute at.
target : Dict[int, tvm.target.Target]
A dict contains device type to target pairs.
fallback_device : int
The fallback device type.
Returns
-------
target : Dict[int, tvm.target.Target]
The updated device type to target dict.
func : tvm.relay.Function
The updated func.
"""
func = ir_pass.infer_type(func)
func = ir_pass.rewrite_annotated_ops(func, fallback_device)
device_map = ir_pass.collect_device_info(func)
# The expression to device type map will be empty if all or none of
# the expressions in the `func` are annotated because this map is
# obtained by propagating the device information in the device copy
# operator. None of the above cases needs device copy operator.
if not device_map:
annotation_map = ir_pass.collect_device_annotation_ops(func)
# No annotation.
if not annotation_map:
target = {0: target[fallback_device]}
else:
dev_type = next(iter(annotation_map.values()))
# All annotated with the same device type.
if all(val == dev_type for val in annotation_map.values()):
target = {0: target[dev_type]}
else:
raise RuntimeError("Expressions in the function are "
"annotated with various device types,"
"but not device copy operators "
"found. Please check the "
"RewriteAnnotation pass.")
return func, target
class GraphExecutor(_interpreter.Executor):
"""Wrapper around Executor interface.
......@@ -259,6 +394,7 @@ class GraphExecutor(_interpreter.Executor):
target : :py:class:`Target`
The target option to build the function.
"""
def __init__(self, mod, ctx, target):
self.mod = mod
self.ctx = ctx
......
......@@ -357,3 +357,60 @@ def alter_op_layout(expr):
Transformed expression with alternated layout.
"""
return _ir_pass.AlterOpLayout(expr)
def rewrite_annotated_ops(expr, fallback_device):
"""Rewrite the annotated program where annotation operators, e.g.
`on_deivce`, mark which device an expression should be scheduled to.
This pass helps heterogeneous execution where different operators may need
to be allocated on various devices.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
fallback_device : int
The fallback device type. It is also used as the default device for
operators with no annotated device.
Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression with cross device data copy operators.
"""
return _ir_pass.RewriteDeviceAnnotation(expr, fallback_device)
def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
ret : Dict[tvm.relay.expr, int]
A dictionary mapping tvm.relay.Expr to device type.
"""
return _ir_pass.CollectDeviceInfo(expr)
def collect_device_annotation_ops(expr):
"""Collect the device annotation ops for the given expression.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
ret : Dict[tvm.relay.expr, int]
A dictionary mapping tvm.relay.Expr to device type where the keys are
annotation expressions.
"""
return _ir_pass.CollectDeviceAnnotationOps(expr)
......@@ -10,6 +10,7 @@ from .reduce import *
from .tensor import *
from .transform import *
from . import nn
from . import annotation
from . import image
from . import vision
from . import op_attrs
......
# pylint: disable=wildcard-import
"""Annotation related operators."""
from __future__ import absolute_import as _abs
from .annotation import *
"""Constructor APIs"""
from ...._ffi.function import _init_api
_init_api("relay.op.annotation._make", __name__)
"""Annotation operations."""
from __future__ import absolute_import as _abs
from . import _make
from .... import nd as _nd
from .... import TVMContext as _TVMContext
def on_device(data, device):
"""Annotate an expression with a certain device type.
Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.
device : Union[:py:class:`TVMContext`, str]
The device type to annotate.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
if isinstance(device, _TVMContext):
device = device.device_type
elif isinstance(device, str):
device = _nd.context(device).device_type
else:
raise ValueError("device is expected to be the type of TVMContext or "
"str, but received %s" % (type(device)))
return _make.on_device(data, device)
......@@ -3,6 +3,8 @@
from __future__ import absolute_import as _abs
from . import _make
from ..expr import Tuple
from ... import nd as _nd
from ... import TVMContext as _TVMContext
# We create a wrapper function for each operator in the
# python side to call into the positional _make.OpName function.
......@@ -616,3 +618,42 @@ def copy(data):
The copied result.
"""
return _make.copy(data)
def device_copy(data, src_dev, dst_dev):
"""Copy data from the source device to the destination device. This
operator helps data transferring between difference contexts for
heterogeneous execution.
Parameters
----------
data : tvm.relay.Expr
The tensor to be copied.
src_dev : Union[:py:class:`TVMContext`, str]
The source device where the data is copied from.
dst_dev : Union[:py:class:`TVMContext`, str]
The destination device where the data is copied to.
Returns
-------
result : tvm.relay.Expr
The copied result.
"""
if isinstance(src_dev, _TVMContext):
src_dev = src_dev.device_type
elif isinstance(src_dev, str):
src_dev = _nd.context(src_dev).device_type
else:
raise ValueError("src_dev is expected to be the type of TVMContext or "
"str, but received %s" % (type(src_dev)))
if isinstance(dst_dev, _TVMContext):
dst_dev = dst_dev.device_type
elif isinstance(dst_dev, str):
dst_dev = _nd.context(dst_dev).device_type
else:
raise ValueError("dst_dev is expected to be the type of TVMContext or "
"str, but received %s" % (type(dst_dev)))
return _make.device_copy(data, src_dev, dst_dev)
......@@ -7,6 +7,7 @@
#include <tvm/packed_func_ext.h>
#include <tvm/operation.h>
#include <tvm/runtime/registry.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
......@@ -82,11 +83,15 @@ class ScheduleGetter :
cache_node->func_name = readable_name_stream_.str();
CachedFunc cfunc(cache_node);
CHECK(master_op_.defined());
Schedule schedule = fschedule[master_op_](
master_attrs_, cache_node->outputs, target_);
Schedule schedule;
// No need to register schedule for device copy op.
if (master_attrs_.as<DeviceCopyAttrs>() == nullptr) {
schedule =
fschedule[master_op_](master_attrs_, cache_node->outputs, target_);
for (const auto& scalar : scalars_) {
schedule[scalar].compute_inline();
}
}
return std::make_pair(schedule, cfunc);
}
......@@ -153,11 +158,18 @@ class ScheduleGetter :
CHECK(call_node->op.as<OpNode>())
<< "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
Array<Tensor> outputs = fcompute[op](
call_node->attrs,
inputs,
call_node->checked_type(),
target_);
// Check if the op is a device copy op.
bool is_copy_op = op.same_as(Op::Get("device_copy"));
Array<Tensor> outputs;
// Skip fcompute for device copy operators as it is not registered.
if (is_copy_op) {
const auto* copy_input = inputs[0].operator->();
outputs.push_back(TensorNode::make(copy_input->shape, copy_input->dtype,
Operation(), 0));
} else {
outputs = fcompute[op](call_node->attrs, inputs,
call_node->checked_type(), target_);
}
int op_pattern = fpattern[op];
if (op_pattern >= kCommReduce) {
......@@ -176,7 +188,14 @@ class ScheduleGetter :
CHECK(tuple_type) << "Expect output to be a tuple type";
CHECK_EQ(tuple_type->fields.size(), outputs.size());
}
// Set the name to `__copy`. It will be detected in graph runtime to perform
// data copy across devices.
if (is_copy_op) {
readable_name_stream_.str(std::string());
readable_name_stream_ << "__copy";
} else {
readable_name_stream_ << '_' << op->name;
}
return outputs;
}
......@@ -291,6 +310,16 @@ class CompileEngineImpl : public CompileEngineNode {
auto spair = CreateSchedule(key->source_func, key->target);
auto cache_node = make_node<CachedFuncNode>(
*(spair.second.operator->()));
// Skip lowering for device copy node.
const Expr body = (key->source_func)->body;
if (const CallNode* call_node = body.as<CallNode>()) {
if (call_node->attrs.as<DeviceCopyAttrs>()) {
value->cached_func = CachedFunc(cache_node);
return value;
}
}
cache_node->func_name = GetUniqueName(cache_node->func_name);
// NOTE: array will copy on write.
Array<Tensor> all_args = cache_node->inputs;
......
......@@ -6,11 +6,14 @@
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h>
#include "../../common/arena.h"
namespace tvm {
namespace relay {
using IntegerArray = Array<Integer>;
struct StorageToken {
/*! \brief Reference counter */
int ref_counter{0};
......@@ -18,8 +21,9 @@ struct StorageToken {
size_t max_bytes{0};
/*! \brief The corresponding tensor type node. */
const TensorTypeNode* ttype{nullptr};
/*! \brief virtual device index */
int device_id{0};
/*! \brief virtual device index that corresponds to the device_type in
* DLContext. */
int device_type{0};
/*! \brief The storage id */
int64_t storage_id{-1};
};
......@@ -106,33 +110,35 @@ class StorageAllocaBaseVisitor : public ExprVisitor {
virtual void CreateToken(const ExprNode* op, bool can_realloc) = 0;
};
class StorageAllocaInit : protected StorageAllocaBaseVisitor {
public:
explicit StorageAllocaInit(common::Arena* arena)
: arena_(arena) {}
/*! \return The internal token map */
std::unordered_map<const ExprNode*, std::vector<StorageToken*> >
GetInitTokenMap(const Function& func) {
node_device_map_ = CollectDeviceInfo(func);
this->Run(func);
return std::move(token_map_);
}
protected:
using StorageAllocaBaseVisitor::VisitExpr_;
void CreateToken(const ExprNode* op, bool can_realloc) final {
CHECK(!token_map_.count(op));
std::vector<StorageToken*> tokens;
int device_type = node_device_map_.count(GetRef<Expr>(op))
? node_device_map_[GetRef<Expr>(op)]->value
: 0;
if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
for (Type t : tuple_type->fields) {
const auto* ttype = t.as<TensorTypeNode>();
CHECK(ttype);
StorageToken* token = arena_->make<StorageToken>();
token->ttype = ttype;
token->device_type = device_type;
tokens.push_back(token);
}
} else {
......@@ -140,6 +146,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor {
CHECK(ttype);
StorageToken* token = arena_->make<StorageToken>();
token->ttype = ttype;
token->device_type = device_type;
tokens.push_back(token);
}
token_map_[op] = tokens;
......@@ -159,9 +166,9 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor {
private:
// allocator
common::Arena* arena_;
Map<Expr, Integer> node_device_map_;
};
class StorageAllocator : public StorageAllocaBaseVisitor {
public:
/*!
......@@ -176,23 +183,39 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
}
// Run storage allocation for a function.
Map<Expr, Array<Integer> > Plan(const Function& func) {
Map<Expr, Array<IntegerArray> > Plan(const Function& func) {
prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func);
this->Run(func);
Map<Expr, Array<Integer> > smap;
// The value of smap contains two integer arrays where the first array
// contains the planned storage ids and the second holds the device types.
Map<Expr, Array<IntegerArray> > smap;
int num_annotated_nodes = 0;
int num_nodes = 0;
for (const auto& kv : token_map_) {
Array<Integer> vec;
std::vector<Integer> storage_ids;
std::vector<Integer> device_types;
for (StorageToken* tok : kv.second) {
vec.push_back(tok->storage_id);
if (tok->device_type) {
num_annotated_nodes++;
}
smap.Set(GetRef<Expr>(kv.first), vec);
num_nodes++;
storage_ids.push_back(tok->storage_id);
device_types.push_back(tok->device_type);
}
smap.Set(GetRef<Expr>(kv.first), Array<IntegerArray>({storage_ids, device_types}));
}
// Either all or none of the nodes should be annotated.
if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
LOG(FATAL)
<< num_annotated_nodes << " out of " << num_nodes
<< "expressions are assigned with virtual device types. Either all "
"or none of the expressions are expected to be annotated.";
}
return smap;
}
protected:
using StorageAllocaBaseVisitor::VisitExpr_;
// override create token by getting token as prototype requirements.
......@@ -207,6 +230,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
} else {
// Allocate a new token,
StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok));
allocated_tok->device_type = tok->device_type;
// ensure it never get de-allocated.
allocated_tok->ref_counter += 1;
tokens.push_back(allocated_tok);
......@@ -282,7 +306,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
// search for memory blocks larger than requested
for (auto it = mid; it != end; ++it) {
StorageToken *tok = it->second;
if (tok->device_id != prototype->device_id) continue;
if (tok->device_type != prototype->device_type) continue;
CHECK_EQ(tok->ref_counter, 0);
// Use exect matching strategy
tok->max_bytes = std::max(size, tok->max_bytes);
......@@ -295,7 +319,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
for (auto it = mid; it != begin;) {
--it;
StorageToken *tok = it->second;
if (tok->device_id != prototype->device_id) continue;
if (tok->device_type != prototype->device_type) continue;
CHECK_EQ(tok->ref_counter, 0);
// Use exect matching strategy
tok->max_bytes = std::max(size, tok->max_bytes);
......@@ -343,13 +367,12 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
std::unordered_map<const ExprNode*, std::vector<StorageToken*> > prototype_;
};
Map<Expr, Array<Integer> > GraphPlanMemory(const Function& func) {
Map<Expr, Array<IntegerArray> > GraphPlanMemory(const Function& func) {
return StorageAllocator().Plan(func);
}
TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory")
.set_body_typed<Map<Expr, Array<Integer> >(const Function&)>(GraphPlanMemory);
.set_body_typed<Map<Expr, Array<IntegerArray> >(const Function&)>(GraphPlanMemory);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/relay/op/annotation/annotation.cc
* \brief Registration of annotation operators.
*/
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include "../type_relations.h"
#include "../../pass/alter_op_layout.h"
namespace tvm {
namespace relay {
// relay.annotation.on_device
TVM_REGISTER_NODE_TYPE(OnDeviceAttrs);
TVM_REGISTER_API("relay.op.annotation._make.on_device")
.set_body_typed<Expr(Expr, int)>([](Expr data, int device_type) {
auto attrs = make_node<OnDeviceAttrs>();
attrs->device_type = device_type;
static const Op& op = Op::Get("on_device");
return CallNode::make(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("on_device")
.describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/relay/op/device_copy.cc
* \brief Crossing device data copy operator.
*
* The pattern of this operator is registered as kOpaque. Hence, it could be
* used as "barrier" to avoid fusing operators belonging to differen devices.
*/
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include "type_relations.h"
#include "../pass/alter_op_layout.h"
namespace tvm {
namespace relay {
// relay.device_copy
TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs);
TVM_REGISTER_API("relay.op._make.device_copy")
.set_body_typed<Expr(Expr, int, int)>([](Expr data, int src_dev_type,
int dst_dev_type) {
auto attrs = make_node<DeviceCopyAttrs>();
attrs->src_dev_type = src_dev_type;
attrs->dst_dev_type = dst_dev_type;
static const Op& op = Op::Get("device_copy");
return CallNode::make(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("device_copy")
.describe(R"code(
Copy data from one tensor to another. The source and destination might be
on different devices.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
*
* \file deivce_annotation.cc
* \brief Passes to rewrite annotated program and retrieve the device allocation
* of expression.
*
* The following passes are performed:
* 1. Validate the unnecessary and redundant annotation.
* 2. Rewrite the annotated program and insert data copy operators.
* 3. Collect the device allocation of each expression.
*/
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h>
#include <memory>
#include <unordered_map>
#include <unordered_set>
namespace tvm {
namespace relay {
namespace {
bool IsOnDeviceNode(const ExprNode* node) {
const auto* call_node = dynamic_cast<const CallNode*>(node);
return call_node != nullptr && call_node->attrs.as<OnDeviceAttrs>();
}
bool IsDeviceCopyNode(const ExprNode* node) {
const auto* call_node = dynamic_cast<const CallNode*>(node);
return call_node != nullptr && call_node->attrs.as<DeviceCopyAttrs>();
}
} // namespace
class ValidateAnnotation : private ExprVisitor {
public:
static std::unordered_map<const ExprNode*, int> Validate(const Expr& expr) {
ValidateAnnotation valid;
valid(expr);
return valid.annotation_map_;
}
private:
void VisitExpr_(const CallNode* call_node) final {
if (IsOnDeviceNode(call_node)) {
int device_type = GetDeviceId(call_node);
if (annotation_map_.count(call_node)) {
CHECK_EQ(annotation_map_.at(call_node), device_type)
<< "An expression node can only be annotated to one device.";
} else {
annotation_map_.insert({call_node, GetDeviceId(call_node)});
}
CHECK_EQ(call_node->args.size(), 1U);
const auto* node = call_node->args[0].operator->();
if (annotation_map_.count(node)) {
CHECK_EQ(annotation_map_.at(node), device_type)
<< "An expression node can only be annotated to one device.";
} else {
annotation_map_.insert({node, GetDeviceId(call_node)});
}
}
ExprVisitor::VisitExpr_(call_node);
}
/*
* \brief Get the device type of the annotation node.
* \param call_node The on_device annotation call node.
* \return The device type.
*/
int GetDeviceId(const CallNode* call_node) {
CHECK(IsOnDeviceNode(call_node))
<< "The input call node must be on_device node.";
const OnDeviceAttrs* on_device_attr = call_node->attrs.as<OnDeviceAttrs>();
return on_device_attr->device_type;
}
std::unordered_map<const ExprNode*, int> annotation_map_;
};
// Replace the use of an expression with the output of a `copy_device` operator
// if the `on_device` operator takes the annotated expr as an input.
//
// This actually replaces annotation ops with device copy ops and connects any
// two dependent expressions with a `device_copy` op when needed. Note that the
// device type of a `device_copy` op is identical to that of the destination op
// since it is where the data should be copied to.
class RewriteAnnotation : public ExprMutator {
public:
Expr Rewrite(const Expr& expr, int fallback_device) {
fallback_device_ = fallback_device;
annotation_map_ = ValidateAnnotation::Validate(expr);
return this->VisitExpr(expr);
}
Expr VisitExpr_(const LetNode* op) final {
Expr value = GetDeviceCopyExpr(op->value, op);
Expr body = GetDeviceCopyExpr(op->body, op);
if (value.same_as(op->value) && body.same_as(op->body)) {
return ExprMutator::VisitExpr_(op);
} else {
Expr new_let = LetNode::make(op->var, value, body);
UpdateAnnotationMap(op, new_let.operator->());
return this->VisitExpr(new_let);
}
}
Expr VisitExp_(const TupleNode* op) {
Array<Expr> fields;
bool annotated = false;
for (const auto& field : fields) {
annotated |= NeedDeviceCopy(field.operator->(), op);
fields.push_back(GetDeviceCopyExpr(field, op));
}
if (annotated) {
Expr new_tuple = TupleNode::make(fields);
UpdateAnnotationMap(op, new_tuple.operator->());
return this->VisitExpr(new_tuple);
} else {
return ExprMutator::VisitExpr_(op);
}
}
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr tuple = op->tuple;
if (NeedDeviceCopy(tuple.operator->(), op)) {
Expr new_expr =
TupleGetItemNode::make(GetDeviceCopyExpr(tuple, op), op->index);
UpdateAnnotationMap(op, new_expr.operator->());
return this->VisitExpr(new_expr);
} else {
return ExprMutator::VisitExpr_(op);
}
}
Expr VisitExpr_(const IfNode* if_node) final {
Expr cond = GetDeviceCopyExpr(if_node->cond, if_node);
Expr true_br = GetDeviceCopyExpr(if_node->true_branch, if_node);
Expr false_br = GetDeviceCopyExpr(if_node->false_branch, if_node);
if (if_node->cond.same_as(cond) && if_node->true_branch.same_as(true_br) &&
if_node->false_branch.same_as(false_br)) {
return ExprMutator::VisitExpr_(if_node);
} else {
Expr new_if = IfNode::make(cond, true_br, false_br);
UpdateAnnotationMap(if_node, new_if.operator->());
return this->VisitExpr(new_if);
}
}
Expr VisitExpr_(const CallNode* call_node) final {
if (IsOnDeviceNode(call_node) || IsDeviceCopyNode(call_node)) {
return ExprMutator::VisitExpr_(call_node);
}
Array<Expr> new_args;
bool annotated = false;
for (const auto& arg : call_node->args) {
annotated |= NeedDeviceCopy(arg.operator->(), call_node);
new_args.push_back(GetDeviceCopyExpr(arg, call_node));
}
if (annotated) {
Call new_call = CallNode::make(call_node->op, new_args, call_node->attrs,
call_node->type_args);
UpdateAnnotationMap(call_node, new_call.operator->());
return this->VisitExpr(new_call);
} else {
return ExprMutator::VisitExpr_(call_node);
}
}
private:
void UpdateAnnotationMap(const ExprNode* old_node, const ExprNode* new_node) {
const auto it = annotation_map_.find(old_node);
if (it == annotation_map_.end()) {
annotation_map_.insert({new_node, fallback_device_});
} else {
annotation_map_.insert({new_node, it->second});
}
this->memo_[GetRef<Expr>(old_node)] = GetRef<Expr>(new_node);
}
Expr GetDeviceCopyExpr(const Expr& src, const ExprNode* dst) {
const auto* src_node = src.operator->();
if (!NeedDeviceCopy(src_node, dst)) return src;
const auto sit = annotation_map_.find(src_node);
if (sit == annotation_map_.end()) {
const auto dit = annotation_map_.find(dst);
CHECK(dit != annotation_map_.end())
<< "Device copy op is not required when both src and dst ops are not "
"annotated.";
return CreateDeviceCopy(src, fallback_device_, dit->second);
} else {
const auto dit = annotation_map_.find(dst);
int dst_dev_type =
dit == annotation_map_.end() ? fallback_device_ : dit->second;
return CreateDeviceCopy(src, sit->second, dst_dev_type);
}
}
// Check if a device copy op is need between two ops.
bool NeedDeviceCopy(const ExprNode* src, const ExprNode* dst) {
if (annotation_map_.count(src)) {
int src_dev_type = annotation_map_.at(src);
if (annotation_map_.count(dst)) {
return src_dev_type != annotation_map_.at(dst);
} else {
return src_dev_type != fallback_device_;
}
} else {
if (annotation_map_.count(dst)) {
// Though data copy op could be inserted whenever the `src` and `dst`
// ops are annotated to different devices, it leads to high overhead.
//
// Here we need across device data transferring only when `src` is a
// CallNode or FunctionNode and the `dst` is annotated with any device
// id other than fallback_device_.
if (src->is_type<CallNode>() || src->is_type<FunctionNode>()) {
return annotation_map_.at(dst) != fallback_device_;
} else {
return false;
}
} else {
return false;
}
}
}
/*
* \brief Create an operator to copy data from the source device to the
* destination device.
* \param src The source expression that produces data to be copied.
* \param src_dev_type The device type where the data is copied from.
* \param dst_dev_type The device type where the data is copied to.
* \return The created call node.
*/
Call CreateDeviceCopy(const Expr& src, int src_dev_type, int dst_dev_type) {
auto attrs = make_node<DeviceCopyAttrs>();
attrs->src_dev_type = src_dev_type;
attrs->dst_dev_type = dst_dev_type;
static const Op& op = Op::Get("device_copy");
Call device_copy = CallNode::make(op, {src}, Attrs(attrs), {});
annotation_map_.insert({device_copy.operator->(), dst_dev_type});
return device_copy;
}
std::unordered_map<const ExprNode*, int> annotation_map_;
int fallback_device_;
};
// Get all annotation expressions.
class AnnotatationVisitor : private ExprVisitor {
public:
static Map<Expr, Integer> GetAnnotations(const Expr& expr) {
AnnotatationVisitor visitor;
visitor(expr);
return visitor.annotations_;
}
private:
void VisitExpr_(const CallNode* call_node) {
if (IsOnDeviceNode(call_node)) {
const auto* attr = call_node->attrs.as<OnDeviceAttrs>();
annotations_.Set(GetRef<Expr>(call_node), attr->device_type);
}
ExprVisitor::VisitExpr_(call_node);
}
Map<Expr, Integer> annotations_;
};
/*
* \brief Return device allocation map based on the post order traversed graph.
* For the following program:
* .. code-block:: python
* x = relay.var("x")
* y = relay.var("y")
* add = relay.add(x, y)
* sqrt = relay.sqrt(add)
* log = relay.log(add)
* subtract = relay.subtract(sqrt, log)
* exp = relay.exp(subtract)
*
* Suppose we have annotated add, sqrt, and log with device 1, 2, and 3,
* respectively. The fallback/default device is 4. After Rewriting the
* program, we can have the following graph, where each copy op has both
* source and destination device type denoting which device the data should be
* copied from and to.
*
* x y
* \ /
* add/1
* / \
* copy1 copy2
* | |
* sqrt/2 log/3
* | |
* copy3 copy4
* \ /
* subtract
* |
* exp
*
* To Get the device mapping of each expression, we need to propagate the
* device information from the copy ops. This can be done in two passes.
* -Pass 1: Propagating the source device type to ops in a bottom-up way to the
* ancestors until encountering another copy op. For example, this way
* provides add, x, and y device types from the copy operator, `copy1`.
* -Pass 2: Propagating the destination device type of "the last" copy op in a
* top-down manner to the nodes on the output paths. For instance,
* this offers `subtract` and `exp` the same device type as `copy3`.
*/
class DeviceInfo {
public:
static Map<Expr, Integer> GetDeviceMap(const Expr& expr) {
DeviceInfo device_info;
device_info.post_visitor_ = PostDfsOrderVisitor();
device_info.post_visitor_.Visit(expr);
if (device_info.post_visitor_.num_device_copy_ops_ > 0) {
device_info.PropagateDeviceId();
return device_info.device_map_;
} else {
return Map<Expr, Integer>();
}
}
private:
class PostDfsOrderVisitor : private ExprVisitor {
public:
void Visit(const Expr& expr) { this->VisitExpr(expr); }
private:
// Post order traversal.
void VisitExpr_(const FunctionNode* fn) final {
ExprVisitor::VisitExpr_(fn);
// TODO(zhiics) Skip annotation of function node for now.
}
void VisitExpr_(const ConstantNode* cn) final {
post_dfs_order_.push_back(cn);
}
void VisitExpr_(const CallNode* call) final {
// Skip annotation nodes.
if (!IsOnDeviceNode(call)) {
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(call);
if (IsDeviceCopyNode(call)) {
num_device_copy_ops_++;
}
}
}
void VisitExpr_(const TupleNode* tn) final {
ExprVisitor::VisitExpr_(tn);
// TODO(zhiics) Skip annotation of tuple node for now.
}
void VisitExpr_(const TupleGetItemNode* op) final {
ExprVisitor::VisitExpr_(op);
post_dfs_order_.push_back(op);
}
void VisitExpr_(const VarNode* vn) final { post_dfs_order_.push_back(vn); }
void VisitExpr_(const LetNode* ln) final {
ExprVisitor::VisitExpr_(ln);
post_dfs_order_.push_back(ln);
}
void VisitExpr_(const IfNode* in) final {
ExprVisitor::VisitExpr_(in);
post_dfs_order_.push_back(in);
}
int num_device_copy_ops_{0};
std::vector<const ExprNode*> post_dfs_order_;
friend DeviceInfo;
};
void PropagateDeviceId() {
// Bottom-up propagation.
BottomUpPropagation();
// Top-down propagation.
TopDownPropagation();
}
void BottomUpPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (IsDeviceCopyNode(*it)) {
last_copy_node = dynamic_cast<const CallNode*>(*it);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type;
device_map_.Set(GetRef<Expr>(last_copy_node), attrs->dst_dev_type);
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(*it);
CHECK_EQ(device_map_.count(expr), 0U);
device_map_.Set(expr, cur_dev_type);
}
}
}
void TopDownPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
for (const auto& it : post_visitor_.post_dfs_order_) {
if (IsDeviceCopyNode(it)) {
last_copy_node = dynamic_cast<const CallNode*>(it);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->dst_dev_type;
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(it);
if (device_map_.count(expr) == 0) {
device_map_.Set(expr, cur_dev_type);
}
}
}
}
PostDfsOrderVisitor post_visitor_;
Map<Expr, Integer> device_map_;
};
Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
RewriteAnnotation rewrote = RewriteAnnotation();
return rewrote.Rewrite(expr, fallback_device);
}
Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) {
return DeviceInfo::GetDeviceMap(expr);
}
Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
return AnnotatationVisitor::GetAnnotations(expr);
}
TVM_REGISTER_API("relay._ir_pass.CollectDeviceInfo")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CollectDeviceInfo(args[0]);
});
TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = RewriteAnnotatedOps(args[0], args[1]);
});
TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CollectDeviceAnnotationOps(args[0]);
});
} // namespace relay
} // namespace tvm
......@@ -112,14 +112,19 @@ def test_plan_memory():
func = relay.ir_pass.infer_type(func)
smap = relay.backend._backend.GraphPlanMemory(func)
storage_ids = set()
device_types = set()
for k, v in smap.items():
for x in v:
assert len(v) == 2
for x in v[0]:
storage_ids.add(x.value)
for x in v[1]:
device_types.add(x.value)
# Current rule requires vars have unique storage id
# because we don't do inplace, we will need another
# two alternating temporary space.
assert len(storage_ids) == 4
assert len(device_types) == 1
if __name__ == "__main__":
......
"""Unit tests for heterogeneous compilation and execution."""
import numpy as np
import tvm
from tvm import relay
from tvm.relay import testing
from tvm.contrib import graph_runtime
def test_redundant_annotation():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
x = relay.var("x", shape=(3,))
y = relay.var("y", shape=(3,))
z = relay.var("z", shape=(3,))
def annotated():
add = relay.add(x, y)
_add1 = relay.annotation.on_device(add, ctx2)
_add2 = relay.annotation.on_device(add, ctx2)
sub = relay.subtract(add, z)
func = relay.Function([x, y, z],
relay.Tuple(tvm.convert([_add1, _add2,
sub])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[2]),
func.body[2])
def expected():
add = relay.add(x, y)
copy_add_sub = relay.device_copy(add, ctx2, ctx1)
sub = relay.subtract(copy_add_sub, z)
func = relay.Function([x, y, z], sub)
return func
annotated_func = relay.ir_pass.infer_type(annotated())
expected_func = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
def test_annotate_all():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
x = relay.var("x", shape=(3,))
y = relay.var("y", shape=(3,))
z = relay.var("z", shape=(3,))
def annotated():
add = relay.add(x, y)
_add = relay.annotation.on_device(add, ctx2)
sub = relay.subtract(add, z)
_sub = relay.annotation.on_device(sub, ctx2)
func = relay.Function([x, y, z],
relay.Tuple(tvm.convert([_add, _sub,
sub])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[2]),
func.body[2])
def expected():
add = relay.add(x, y)
sub = relay.subtract(add, z)
func = relay.Function([x, y, z], sub)
return func
annotated_func = relay.ir_pass.infer_type(annotated())
expected_func = relay.ir_pass.infer_type(expected())
def test_annotate_none():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
x = relay.var("x", shape=(3,))
y = relay.var("y", shape=(3,))
z = relay.var("z", shape=(3,))
def annotated():
add = relay.add(x, y)
sub = relay.subtract(add, z)
func = relay.Function([x, y, z], sub)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
return func
def expected():
add = relay.add(x, y)
sub = relay.subtract(add, z)
func = relay.Function([x, y, z], sub)
return func
annotated_func = relay.ir_pass.infer_type(annotated())
expected_func = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
def check_annotated_graph(annotated_func, expected_func):
annotated_func = relay.ir_pass.infer_type(annotated_func)
expected_func = relay.ir_pass.infer_type(expected_func)
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
def test_conv_network():
R""" The network is as following:
data1 data2
| |
conv2d conv2d
\ /
add
|
conv2d
"""
batch_size = 1
dshape = (batch_size, 64, 56, 56)
weight = relay.var("weight", shape=(64, 64, 3, 3))
data1 = relay.var("data1", shape=dshape)
data2 = relay.var("data2", shape=dshape)
dev1 = tvm.context(1)
dev2 = tvm.context(2)
def annotated():
conv2d_1 = relay.nn.conv2d(
data1,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
_conv2d_1 = relay.annotation.on_device(conv2d_1, dev2)
conv2d_2 = relay.nn.conv2d(
data2,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
_conv2d_2 = relay.annotation.on_device(conv2d_2, dev2)
add = relay.add(conv2d_1, conv2d_2)
_add = relay.annotation.on_device(add, dev1)
conv2d_3 = relay.nn.conv2d(
add,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
_conv2d_3 = relay.annotation.on_device(conv2d_3, dev2)
func = relay.Function([data1, data2, weight],
relay.Tuple(tvm.convert([_conv2d_1, _conv2d_2,
_conv2d_3, _add,
conv2d_3])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
tvm.context(3).device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[4]),
func.body[4])
def expected():
conv2d_1 = relay.nn.conv2d(
data1,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
device_copy1 = relay.device_copy(conv2d_1, dev2, dev1)
conv2d_2 = relay.nn.conv2d(
data2,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
device_copy2 = relay.device_copy(conv2d_2, dev2, dev1)
add = relay.add(device_copy1, device_copy2)
device_copy3 = relay.device_copy(add, dev1, dev2)
conv2d_3 = relay.nn.conv2d(
device_copy3,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
func = relay.Function([data1, weight, data2], conv2d_3)
return func
def check_storage_and_device_types():
func = annotated()
func = relay.ir_pass.rewrite_annotated_ops(func, 3)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.fuse_ops(func, opt_level=2)
func = relay.ir_pass.infer_type(func)
smap = relay.backend._backend.GraphPlanMemory(func)
storage_ids = []
device_types = []
for _, storage_dev_type in smap.items():
assert len(storage_dev_type) == 2
for sid in storage_dev_type[0]:
storage_ids.append(sid.value)
for did in storage_dev_type[1]:
device_types.append(did.value)
assert len(storage_ids) == 10
assert len(set(storage_ids)) == 7
assert len(set(device_types)) == 2
assert set(device_types) == {1, 2}
annotated_func = annotated()
expected_func = expected()
check_annotated_graph(annotated_func, expected_func)
check_storage_and_device_types()
def test_fusible_network():
R""" The network is as following:
x y
\ /
add
/ \
sqrt log
\ /
subtract
|
exp
"""
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(10, 10))
x_data = np.random.rand(1, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32')
tmp_add = x_data + y_data
tmp_sqrt = np.sqrt(tmp_add)
tmp_log = np.log(tmp_add)
tmp_sub = np.subtract(tmp_sqrt, tmp_log)
ref_res = np.exp(tmp_sub)
def get_func():
add = relay.add(x, y)
sqrt = relay.sqrt(add)
log = relay.log(add)
subtract = relay.subtract(sqrt, log)
exp = relay.exp(subtract)
func = relay.Function([x, y], exp)
return func
def test_runtime(target, device, func, fallback_device=None):
params = {"x": x_data, "y": y_data}
with relay.build_config(opt_level=1):
graph, lib, params = relay.build(
func,
target,
params=params,
fallback_device=fallback_device)
contexts = [tvm.cpu(0), tvm.context(device)]
mod = graph_runtime.create(graph, lib, contexts)
mod.set_input(**params)
mod.run()
res = mod.get_output(0).asnumpy()
tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5)
def test_fuse_log_add(device, tgt):
""" Only log and add are fused."""
fallback_device = tvm.context("cpu")
target = {"cpu": "llvm", device: tgt}
cpu_ctx = fallback_device
dev_ctx = tvm.context(device)
def annotated():
add = relay.add(x, y)
sqrt = relay.sqrt(add)
_sqrt = relay.annotation.on_device(sqrt, dev_ctx)
log = relay.log(add)
subtract = relay.subtract(sqrt, log)
exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, dev_ctx)
func = relay.Function([x, y],
relay.Tuple(tvm.convert([_sqrt, _exp, exp])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[2]),
func.body[2])
def expected():
add = relay.add(x, y)
copy_add_sqrt = relay.device_copy(add, cpu_ctx, dev_ctx)
sqrt = relay.sqrt(copy_add_sqrt)
log = relay.log(add)
copy_sqrt_subtract = relay.device_copy(sqrt, dev_ctx, cpu_ctx)
subtract = relay.subtract(copy_sqrt_subtract, log)
copy_sub_exp = relay.device_copy(subtract, cpu_ctx, dev_ctx)
exp = relay.exp(copy_sub_exp)
func = relay.Function([x, y], exp)
return func
annotated_func = annotated()
expected_func = expected()
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func, fallback_device)
def test_fuse_all(device, tgt):
"""Fuse all operators."""
fallback_device = tvm.context("cpu")
target = {"cpu": "llvm", device: tgt}
cpu_ctx = fallback_device
dev_ctx = tvm.context(device)
def annotated():
add = relay.add(x, y)
_add = relay.annotation.on_device(add, dev_ctx)
sqrt = relay.sqrt(add)
_sqrt = relay.annotation.on_device(sqrt, dev_ctx)
log = relay.log(add)
_log = relay.annotation.on_device(log, dev_ctx)
subtract = relay.subtract(sqrt, log)
_subtract = relay.annotation.on_device(subtract, dev_ctx)
exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, dev_ctx)
func = relay.Function([x, y],
relay.Tuple(tvm.convert([_add, _sqrt, _log,
_subtract, _exp,
exp])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[5]),
func.body[5])
annotated_func = annotated()
expected_func = get_func()
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func, fallback_device)
def test_fallback_exp(device, tgt):
fallback_device = tvm.context("cpu")
target = {"cpu": "llvm", device: tgt}
cpu_ctx = fallback_device
def annotated():
add = relay.add(x, y)
sqrt = relay.sqrt(add)
log = relay.log(add)
subtract = relay.subtract(sqrt, log)
exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, cpu_ctx)
func = relay.Function([x, y],
relay.Tuple(tvm.convert([_exp, exp])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[1]),
func.body[1])
annotated_func = annotated()
expected_func = get_func()
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func, fallback_device)
def test_fallback_all_operators(device, tgt):
target = {"cpu": "llvm", device: tgt}
fallback_device = tvm.cpu(0)
annotated_func = get_func()
expected_func = get_func()
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func, fallback_device)
for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"),
("opencl", str(tvm.target.intel_graphics()))]:
if not tvm.module.enabled(dev):
print("Skip test because %s is not enabled." % dev)
continue
test_fuse_log_add(dev, tgt)
test_fuse_all(dev, tgt)
test_fallback_exp(dev, tgt)
test_fallback_all_operators(dev, tgt)
if __name__ == "__main__":
test_redundant_annotation()
test_annotate_all()
test_annotate_none()
test_conv_network()
test_fusible_network()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment