Commit 09236bf6 by Zhi Committed by Haichen Shen

[relay][pass] Annotation for heterogeneous compilation (#2361)

parent 30a5a600
...@@ -148,6 +148,8 @@ This level support backpropagation of broadcast operators. It is temporary. ...@@ -148,6 +148,8 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.broadcast_to_like tvm.relay.broadcast_to_like
tvm.relay.collapse_sum_like tvm.relay.collapse_sum_like
tvm.relay.slice_like tvm.relay.slice_like
tvm.relay.device_copy
tvm.relay.annotation.on_device
Level 1 Definitions Level 1 Definitions
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/annotation.h
* \brief Attribute for annotation operators.
*/
#ifndef TVM_RELAY_ATTRS_ANNOTATION_H_
#define TVM_RELAY_ATTRS_ANNOTATION_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
/*!
 * \brief Attributes of the device annotation operator (`on_device`).
 *
 * Holds the device type that an expression is annotated with.
 */
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
  /*! \brief Device type the annotated expression is tagged with (default 0). */
  int device_type;

  TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
    TVM_ATTR_FIELD(device_type)
      .describe(
         "The virtual device/context type that an expression is annotated with.")
      .set_default(0);
  }
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/device_copy.h
* \brief Attribute for the device copy operator.
*/
#ifndef TVM_RELAY_ATTRS_DEVICE_COPY_H_
#define TVM_RELAY_ATTRS_DEVICE_COPY_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
/*!
 * \brief Attributes of the device copy operator (`device_copy`).
 *
 * Records the device types on both ends of the copy.
 */
struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
  /*! \brief Device type the data is copied to (default 0). */
  int dst_dev_type;
  /*! \brief Device type the data is copied from (default 0). */
  int src_dev_type;

  TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") {
    TVM_ATTR_FIELD(src_dev_type)
      .describe(
         "The virtual device/context type where the op copies data from.")
      .set_default(0);
    TVM_ATTR_FIELD(dst_dev_type)
      .describe(
         "The virtual device/context type where the op copies data to.")
      .set_default(0);
  }
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_DEVICE_COPY_H_
...@@ -164,7 +164,6 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> { ...@@ -164,7 +164,6 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
} }
}; };
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> { struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
std::string src_layout; std::string src_layout;
std::string dst_layout; std::string dst_layout;
......
...@@ -188,6 +188,21 @@ Expr ForwardRewrite(const Expr& expr, ...@@ -188,6 +188,21 @@ Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext = nullptr, std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr); std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*!
 * \brief Rewrite the annotated program.
 *
 * The annotations are the device annotation operators (e.g. `on_device`)
 * that mark which device an expression should be scheduled to.
 *
 * \param expr The expression.
 * \param fallback_device The fallback device type, which is the default
 *  device for operators without annotation.
 * \return The updated program.
 */
Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
/*!
 * \brief Collect the device mapping information of each expression.
 *
 * \param expr The expression.
 * \return The device mapping: a map from expression to the integer device
 *  type assigned to it.
 */
Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
/*! \brief A hashing structure in the style of std::hash. */ /*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash { struct StructuralHash {
......
...@@ -25,6 +25,7 @@ from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl ...@@ -25,6 +25,7 @@ from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev from .ndarray import vpi, rocm, opengl, ext_dev
from ._ffi.runtime_ctypes import TypeCode from ._ffi.runtime_ctypes import TypeCode
from ._ffi.ndarray import TVMContext
from ._ffi.function import Function from ._ffi.function import Function
from ._ffi.base import TVMError, __version__ from ._ffi.base import TVMError, __version__
from .api import * from .api import *
......
...@@ -18,6 +18,7 @@ from .op.reduce import * ...@@ -18,6 +18,7 @@ from .op.reduce import *
from .op.tensor import * from .op.tensor import *
from .op.transform import * from .op.transform import *
from . import nn from . import nn
from . import annotation
from . import vision from . import vision
from . import image from . import image
from . import frontend from . import frontend
......
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Annotation related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.annotation import *
...@@ -52,20 +52,21 @@ def build(funcs, target, target_host=None): ...@@ -52,20 +52,21 @@ def build(funcs, target, target_host=None):
Parameters Parameters
---------- ----------
funcs : List[tvm.LoweredFunc] funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
The list of lowered functions. A list of lowered functions or dictionary mapping from targets to
lowered functions.
target : tvm.Target target : tvm.Target
The target to run the code on. The target to run the code on.
target_host : tvm.Target target_host : tvm.Target
The host target. The host target.
Returns Returns
------- -------
module : tvm.Module module : tvm.Module
The runtime module. The runtime module.
""" """
if target_host == "": if target_host == "":
target_host = None target_host = None
......
...@@ -20,6 +20,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system. ...@@ -20,6 +20,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system.
from __future__ import absolute_import from __future__ import absolute_import
import json import json
from collections import defaultdict
import attr import attr
from . import _backend from . import _backend
from . import compile_engine from . import compile_engine
...@@ -27,6 +28,7 @@ from ..op import Op ...@@ -27,6 +28,7 @@ from ..op import Op
from ..expr import Function, GlobalVar from ..expr import Function, GlobalVar
from ..expr_functor import ExprFunctor from ..expr_functor import ExprFunctor
from ..ty import TupleType, TensorType from ..ty import TupleType, TensorType
from ... import target as _target
@attr.s @attr.s
...@@ -105,9 +107,9 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -105,9 +107,9 @@ class GraphRuntimeCodegen(ExprFunctor):
self.nodes = [] self.nodes = []
self.var_map = {} self.var_map = {}
self.params = {} self.params = {}
self.storage_map = None self.storage_device_map = None
self.compile_engine = compile_engine.get() self.compile_engine = compile_engine.get()
self.lowered_funcs = set() self.lowered_funcs = defaultdict(set)
self._name_map = {} self._name_map = {}
def add_node(self, node, expr): def add_node(self, node, expr):
...@@ -129,10 +131,20 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -129,10 +131,20 @@ class GraphRuntimeCodegen(ExprFunctor):
""" """
checked_type = expr.checked_type checked_type = expr.checked_type
# setup storage ids # setup storage ids
assert expr in self.storage_map assert expr in self.storage_device_map
node.attrs["storage_id"] = [ storage_device_info = self.storage_device_map[expr]
x.value for x in self.storage_map[expr] assert len(storage_device_info) == 2
] node.attrs["storage_id"] = [x.value for x in storage_device_info[0]]
device_types = [x.value for x in storage_device_info[1]]
num_unknown_devices = device_types.count(0)
if num_unknown_devices != 0 and num_unknown_devices != len(device_types):
raise RuntimeError("The graph contains not annotated nodes for "
"heterogeneous execution. All nodes must be "
"annotated.")
# Add the `device_index` attribute when the graph is annotated.
if num_unknown_devices == 0:
node.attrs["device_index"] = device_types
node_id = len(self.nodes) node_id = len(self.nodes)
self.nodes.append(node) self.nodes.append(node)
...@@ -232,9 +244,25 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -232,9 +244,25 @@ class GraphRuntimeCodegen(ExprFunctor):
"TVM only support calls to primitive functions " + "TVM only support calls to primitive functions " +
"(i.e functions composed of fusable operator invocations)") "(i.e functions composed of fusable operator invocations)")
cached_func = self.compile_engine.lower(func, self.target) assert call in self.storage_device_map
device_types = self.storage_device_map[call][1]
call_dev_type = device_types[0].value
if isinstance(self.target, (str, _target.Target)):
# homogeneous execution.
cached_func = self.compile_engine.lower(func, self.target)
self.target = {0: str(self.target)}
elif isinstance(self.target, dict):
# heterogeneous execution.
if call_dev_type not in self.target:
raise Exception("No target is provided for device " +
"{0}".format(call_dev_type))
cached_func = self.compile_engine.lower(func,
self.target[call_dev_type])
else:
raise ValueError("self.target must be the type of str," +
"tvm.target.Target, or dict of int to str")
for loweredf in cached_func.funcs: for loweredf in cached_func.funcs:
self.lowered_funcs.add(loweredf) self.lowered_funcs[self.target[call_dev_type]].add(loweredf)
inputs = [] inputs = []
# flatten tuple in the call. # flatten tuple in the call.
...@@ -284,6 +312,7 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -284,6 +312,7 @@ class GraphRuntimeCodegen(ExprFunctor):
num_entry = 0 num_entry = 0
shapes = [] shapes = []
storage_ids = [] storage_ids = []
device_types = []
dltypes = [] dltypes = []
node_row_ptr = [0] node_row_ptr = [0]
for node in self.nodes: for node in self.nodes:
...@@ -291,6 +320,8 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -291,6 +320,8 @@ class GraphRuntimeCodegen(ExprFunctor):
shapes += node.attrs["shape"] shapes += node.attrs["shape"]
dltypes += node.attrs["dtype"] dltypes += node.attrs["dtype"]
storage_ids += node.attrs["storage_id"] storage_ids += node.attrs["storage_id"]
if "device_index" in node.attrs:
device_types += node.attrs["device_index"]
num_entry += node.num_outputs num_entry += node.num_outputs
node_row_ptr.append(num_entry) node_row_ptr.append(num_entry)
...@@ -298,6 +329,8 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -298,6 +329,8 @@ class GraphRuntimeCodegen(ExprFunctor):
attrs = {} attrs = {}
attrs["shape"] = ["list_shape", shapes] attrs["shape"] = ["list_shape", shapes]
attrs["storage_id"] = ["list_int", storage_ids] attrs["storage_id"] = ["list_int", storage_ids]
if device_types:
attrs["device_index"] = ["list_int", device_types]
attrs["dltype"] = ["list_str", dltypes] attrs["dltype"] = ["list_str", dltypes]
json_dict = { json_dict = {
...@@ -313,11 +346,24 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -313,11 +346,24 @@ class GraphRuntimeCodegen(ExprFunctor):
def debug_dump_memory_plan(self, func): def debug_dump_memory_plan(self, func):
"""Debug function to dump memory plan.""" """Debug function to dump memory plan."""
def _annotate(expr): def _annotate(expr):
if expr in self.storage_map: if expr in self.storage_device_map:
return str(self.storage_map[expr]) storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
return str(storage_device_info[0])
return ""
return func.astext(show_meta_data=False, annotate=_annotate)
def debug_dump_device_annotation(self, func):
"""Debug function to dump device annotation result."""
def _annotate(expr):
if expr in self.storage_device_map:
storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
return str(storage_device_info[1])
return "" return ""
return func.astext(show_meta_data=False, annotate=_annotate) return func.astext(show_meta_data=False, annotate=_annotate)
def codegen(self, func): def codegen(self, func):
"""Compile a single function into a graph. """Compile a single function into a graph.
...@@ -331,24 +377,31 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -331,24 +377,31 @@ class GraphRuntimeCodegen(ExprFunctor):
graph_json : str graph_json : str
The graph json that can be consumed by runtime. The graph json that can be consumed by runtime.
lowered_funcs : List[tvm.LoweredFunc] lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
The lowered functions. The lowered functions.
params : Dict[str, tvm.nd.NDArray] params : Dict[str, tvm.nd.NDArray]
Additional constant parameters. Additional constant parameters.
""" """
self.storage_map = _backend.GraphPlanMemory(func) self.storage_device_map = _backend.GraphPlanMemory(func)
# First we convert all the parameters into input nodes. # First we convert all the parameters into input nodes.
for param in func.params: for param in func.params:
node = InputNode(param.name_hint, {}) node = InputNode(param.name_hint, {})
self.var_map[param] = self.add_node( self.var_map[param] = self.add_node(node, param)
node, param)
# Then we compile the body into a graph which can depend # Then we compile the body into a graph which can depend
# on input variables. # on input variables.
self.heads = self.visit(func.body) self.heads = self.visit(func.body)
graph_json = self._get_json() graph_json = self._get_json()
lowered_funcs = list(self.lowered_funcs)
# Return the lowered functions as a list for homogeneous compilation.
# Otherwise, for heterogeneous compilation, a dictionary containing
# the device id to a list of lowered functions is returned. Both forms
# are acceptable to tvm.build.
if not isinstance(self.target, dict):
lowered_funcs = list(list(self.lowered_funcs.values())[0])
else:
lowered_funcs = {k: list(v) for k, v in self.lowered_funcs.items()}
return graph_json, lowered_funcs, self.params return graph_json, lowered_funcs, self.params
def _get_unique_name(self, name): def _get_unique_name(self, name):
......
...@@ -357,3 +357,60 @@ def alter_op_layout(expr): ...@@ -357,3 +357,60 @@ def alter_op_layout(expr):
Transformed expression with alternated layout. Transformed expression with alternated layout.
""" """
return _ir_pass.AlterOpLayout(expr) return _ir_pass.AlterOpLayout(expr)
def rewrite_annotated_ops(expr, fallback_device):
    """Rewrite the annotated program where annotation operators, e.g.
    `on_device`, mark which device an expression should be scheduled to.
    This pass helps heterogeneous execution where different operators may need
    to be allocated on various devices.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    fallback_device : int
        The fallback device type. It is also used as the default device for
        operators with no annotated device.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        Transformed expression with cross device data copy operators.
    """
    # Thin wrapper: the rewrite is performed by the registered
    # RewriteDeviceAnnotation pass function.
    return _ir_pass.RewriteDeviceAnnotation(expr, fallback_device)
def collect_device_info(expr):
    """Collect the device allocation map for the given expression. The device
    ids are propagated from the `device_copy` operators.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    ret : Dict[tvm.relay.Expr, int]
        A dictionary mapping tvm.relay.Expr to device type.
    """
    # Thin wrapper around the registered CollectDeviceInfo pass function.
    return _ir_pass.CollectDeviceInfo(expr)
def collect_device_annotation_ops(expr):
    """Collect the device annotation ops for the given expression.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    ret : Dict[tvm.relay.Expr, int]
        A dictionary mapping tvm.relay.Expr to device type where the keys are
        annotation expressions.
    """
    # Thin wrapper around the registered CollectDeviceAnnotationOps pass
    # function.
    return _ir_pass.CollectDeviceAnnotationOps(expr)
...@@ -10,6 +10,7 @@ from .reduce import * ...@@ -10,6 +10,7 @@ from .reduce import *
from .tensor import * from .tensor import *
from .transform import * from .transform import *
from . import nn from . import nn
from . import annotation
from . import image from . import image
from . import vision from . import vision
from . import op_attrs from . import op_attrs
......
# pylint: disable=wildcard-import
"""Annotation related operators."""
from __future__ import absolute_import as _abs
from .annotation import *
"""Constructor APIs"""
from ...._ffi.function import _init_api
_init_api("relay.op.annotation._make", __name__)
"""Annotation operations."""
from __future__ import absolute_import as _abs
from . import _make
from .... import nd as _nd
from .... import TVMContext as _TVMContext
def on_device(data, device):
    """Wrap an expression in an `on_device` annotation.

    Parameters
    ----------
    data : tvm.relay.Expr
        The expression to be annotated.

    device : Union[:py:class:`TVMContext`, str]
        The device to annotate with, given either as a context object or as
        a device name string.

    Returns
    -------
    result : tvm.relay.Expr
        The annotated expression.

    Raises
    ------
    ValueError
        If `device` is neither a TVMContext nor a str.
    """
    # Reject unsupported argument types up front.
    if not isinstance(device, (_TVMContext, str)):
        raise ValueError("device is expected to be the type of TVMContext or "
                         "str, but received %s" % (type(device)))

    # A context object is used as is; a device name is resolved to a context.
    ctx = device if isinstance(device, _TVMContext) else _nd.context(device)
    return _make.on_device(data, ctx.device_type)
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from . import _make from . import _make
from ..expr import Tuple from ..expr import Tuple
from ... import nd as _nd
from ... import TVMContext as _TVMContext
# We create a wrapper function for each operator in the # We create a wrapper function for each operator in the
# python side to call into the positional _make.OpName function. # python side to call into the positional _make.OpName function.
...@@ -616,3 +618,42 @@ def copy(data): ...@@ -616,3 +618,42 @@ def copy(data):
The copied result. The copied result.
""" """
return _make.copy(data) return _make.copy(data)
def device_copy(data, src_dev, dst_dev):
    """Copy data from the source device to the destination device. This
    operator helps data transferring between different contexts for
    heterogeneous execution.

    Parameters
    ----------
    data : tvm.relay.Expr
        The tensor to be copied.

    src_dev : Union[:py:class:`TVMContext`, str]
        The source device where the data is copied from.

    dst_dev : Union[:py:class:`TVMContext`, str]
        The destination device where the data is copied to.

    Returns
    -------
    result : tvm.relay.Expr
        The copied result.

    Raises
    ------
    ValueError
        If `src_dev` or `dst_dev` is neither a TVMContext nor a str.
    """
    # Arguments are evaluated left to right, so an invalid `src_dev` is
    # reported before `dst_dev` is examined.
    return _make.device_copy(data,
                             _device_type_of(src_dev, "src_dev"),
                             _device_type_of(dst_dev, "dst_dev"))


def _device_type_of(dev, arg_name):
    """Resolve a TVMContext or a device name string to its device type.

    `arg_name` is the caller's parameter name and is only used in the error
    message.
    """
    if isinstance(dev, _TVMContext):
        return dev.device_type
    if isinstance(dev, str):
        return _nd.context(dev).device_type
    raise ValueError("%s is expected to be the type of TVMContext or "
                     "str, but received %s" % (arg_name, type(dev)))
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
...@@ -82,10 +83,14 @@ class ScheduleGetter : ...@@ -82,10 +83,14 @@ class ScheduleGetter :
cache_node->func_name = readable_name_stream_.str(); cache_node->func_name = readable_name_stream_.str();
CachedFunc cfunc(cache_node); CachedFunc cfunc(cache_node);
CHECK(master_op_.defined()); CHECK(master_op_.defined());
Schedule schedule = fschedule[master_op_]( Schedule schedule;
master_attrs_, cache_node->outputs, target_); // No need to register schedule for device copy op.
for (const auto& scalar : scalars_) { if (master_attrs_.as<DeviceCopyAttrs>() == nullptr) {
schedule[scalar].compute_inline(); schedule =
fschedule[master_op_](master_attrs_, cache_node->outputs, target_);
for (const auto& scalar : scalars_) {
schedule[scalar].compute_inline();
}
} }
return std::make_pair(schedule, cfunc); return std::make_pair(schedule, cfunc);
} }
...@@ -153,11 +158,18 @@ class ScheduleGetter : ...@@ -153,11 +158,18 @@ class ScheduleGetter :
CHECK(call_node->op.as<OpNode>()) CHECK(call_node->op.as<OpNode>())
<< "Primitive function only allows call into primitive ops"; << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op); Op op = Downcast<Op>(call_node->op);
Array<Tensor> outputs = fcompute[op]( // Check if the op is a device copy op.
call_node->attrs, bool is_copy_op = op.same_as(Op::Get("device_copy"));
inputs, Array<Tensor> outputs;
call_node->checked_type(), // Skip fcompute for device copy operators as it is not registered.
target_); if (is_copy_op) {
const auto* copy_input = inputs[0].operator->();
outputs.push_back(TensorNode::make(copy_input->shape, copy_input->dtype,
Operation(), 0));
} else {
outputs = fcompute[op](call_node->attrs, inputs,
call_node->checked_type(), target_);
}
int op_pattern = fpattern[op]; int op_pattern = fpattern[op];
if (op_pattern >= kCommReduce) { if (op_pattern >= kCommReduce) {
...@@ -176,7 +188,14 @@ class ScheduleGetter : ...@@ -176,7 +188,14 @@ class ScheduleGetter :
CHECK(tuple_type) << "Expect output to be a tuple type"; CHECK(tuple_type) << "Expect output to be a tuple type";
CHECK_EQ(tuple_type->fields.size(), outputs.size()); CHECK_EQ(tuple_type->fields.size(), outputs.size());
} }
readable_name_stream_ << '_' << op->name; // Set the name to `__copy`. It will be detected in graph runtime to perform
// data copy across devices.
if (is_copy_op) {
readable_name_stream_.str(std::string());
readable_name_stream_ << "__copy";
} else {
readable_name_stream_ << '_' << op->name;
}
return outputs; return outputs;
} }
...@@ -291,6 +310,16 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -291,6 +310,16 @@ class CompileEngineImpl : public CompileEngineNode {
auto spair = CreateSchedule(key->source_func, key->target); auto spair = CreateSchedule(key->source_func, key->target);
auto cache_node = make_node<CachedFuncNode>( auto cache_node = make_node<CachedFuncNode>(
*(spair.second.operator->())); *(spair.second.operator->()));
// Skip lowering for device copy node.
const Expr body = (key->source_func)->body;
if (const CallNode* call_node = body.as<CallNode>()) {
if (call_node->attrs.as<DeviceCopyAttrs>()) {
value->cached_func = CachedFunc(cache_node);
return value;
}
}
cache_node->func_name = GetUniqueName(cache_node->func_name); cache_node->func_name = GetUniqueName(cache_node->func_name);
// NOTE: array will copy on write. // NOTE: array will copy on write.
Array<Tensor> all_args = cache_node->inputs; Array<Tensor> all_args = cache_node->inputs;
......
...@@ -6,11 +6,14 @@ ...@@ -6,11 +6,14 @@
*/ */
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h>
#include "../../common/arena.h" #include "../../common/arena.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
using IntegerArray = Array<Integer>;
struct StorageToken { struct StorageToken {
/*! \brief Reference counter */ /*! \brief Reference counter */
int ref_counter{0}; int ref_counter{0};
...@@ -18,8 +21,9 @@ struct StorageToken { ...@@ -18,8 +21,9 @@ struct StorageToken {
size_t max_bytes{0}; size_t max_bytes{0};
/*! \brief The corresponding tensor type node. */ /*! \brief The corresponding tensor type node. */
const TensorTypeNode* ttype{nullptr}; const TensorTypeNode* ttype{nullptr};
/*! \brief virtual device index */ /*! \brief virtual device index that corresponds to the device_type in
int device_id{0}; * DLContext. */
int device_type{0};
/*! \brief The storage id */ /*! \brief The storage id */
int64_t storage_id{-1}; int64_t storage_id{-1};
}; };
...@@ -106,33 +110,35 @@ class StorageAllocaBaseVisitor : public ExprVisitor { ...@@ -106,33 +110,35 @@ class StorageAllocaBaseVisitor : public ExprVisitor {
virtual void CreateToken(const ExprNode* op, bool can_realloc) = 0; virtual void CreateToken(const ExprNode* op, bool can_realloc) = 0;
}; };
class StorageAllocaInit : protected StorageAllocaBaseVisitor { class StorageAllocaInit : protected StorageAllocaBaseVisitor {
public: public:
explicit StorageAllocaInit(common::Arena* arena) explicit StorageAllocaInit(common::Arena* arena)
: arena_(arena) {} : arena_(arena) {}
/*! \return The internal token map */ /*! \return The internal token map */
std::unordered_map<const ExprNode*, std::vector<StorageToken*> > std::unordered_map<const ExprNode*, std::vector<StorageToken*> >
GetInitTokenMap(const Function& func) { GetInitTokenMap(const Function& func) {
node_device_map_ = CollectDeviceInfo(func);
this->Run(func); this->Run(func);
return std::move(token_map_); return std::move(token_map_);
} }
protected: protected:
using StorageAllocaBaseVisitor::VisitExpr_; using StorageAllocaBaseVisitor::VisitExpr_;
void CreateToken(const ExprNode* op, bool can_realloc) final { void CreateToken(const ExprNode* op, bool can_realloc) final {
CHECK(!token_map_.count(op)); CHECK(!token_map_.count(op));
std::vector<StorageToken*> tokens; std::vector<StorageToken*> tokens;
int device_type = node_device_map_.count(GetRef<Expr>(op))
? node_device_map_[GetRef<Expr>(op)]->value
: 0;
if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) { if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
for (Type t : tuple_type->fields) { for (Type t : tuple_type->fields) {
const auto* ttype = t.as<TensorTypeNode>(); const auto* ttype = t.as<TensorTypeNode>();
CHECK(ttype); CHECK(ttype);
StorageToken* token = arena_->make<StorageToken>(); StorageToken* token = arena_->make<StorageToken>();
token->ttype = ttype; token->ttype = ttype;
token->device_type = device_type;
tokens.push_back(token); tokens.push_back(token);
} }
} else { } else {
...@@ -140,6 +146,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { ...@@ -140,6 +146,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor {
CHECK(ttype); CHECK(ttype);
StorageToken* token = arena_->make<StorageToken>(); StorageToken* token = arena_->make<StorageToken>();
token->ttype = ttype; token->ttype = ttype;
token->device_type = device_type;
tokens.push_back(token); tokens.push_back(token);
} }
token_map_[op] = tokens; token_map_[op] = tokens;
...@@ -159,9 +166,9 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { ...@@ -159,9 +166,9 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor {
private: private:
// allocator // allocator
common::Arena* arena_; common::Arena* arena_;
Map<Expr, Integer> node_device_map_;
}; };
class StorageAllocator : public StorageAllocaBaseVisitor { class StorageAllocator : public StorageAllocaBaseVisitor {
public: public:
/*! /*!
...@@ -176,23 +183,39 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ...@@ -176,23 +183,39 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
} }
// Run storage allocation for a function. // Run storage allocation for a function.
Map<Expr, Array<Integer> > Plan(const Function& func) { Map<Expr, Array<IntegerArray> > Plan(const Function& func) {
prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func); prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func);
this->Run(func); this->Run(func);
Map<Expr, Array<Integer> > smap; // The value of smap contains two integer arrays where the first array
// contains the planned storage ids and the second holds the device types.
Map<Expr, Array<IntegerArray> > smap;
int num_annotated_nodes = 0;
int num_nodes = 0;
for (const auto& kv : token_map_) { for (const auto& kv : token_map_) {
Array<Integer> vec; std::vector<Integer> storage_ids;
std::vector<Integer> device_types;
for (StorageToken* tok : kv.second) { for (StorageToken* tok : kv.second) {
vec.push_back(tok->storage_id); if (tok->device_type) {
num_annotated_nodes++;
}
num_nodes++;
storage_ids.push_back(tok->storage_id);
device_types.push_back(tok->device_type);
} }
smap.Set(GetRef<Expr>(kv.first), vec); smap.Set(GetRef<Expr>(kv.first), Array<IntegerArray>({storage_ids, device_types}));
}
// Either all or none of the nodes should be annotated: a token whose
// device_type is 0 counts as not annotated, so a partial count means the
// graph mixes annotated and unannotated expressions.
if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
  // The literal starts with a space so that the streamed count and the word
  // "expressions" do not run together in the message.
  LOG(FATAL)
      << num_annotated_nodes << " out of " << num_nodes
      << " expressions are assigned with virtual device types. Either all "
         "or none of the expressions are expected to be annotated.";
}
return smap; return smap;
} }
protected: protected:
using StorageAllocaBaseVisitor::VisitExpr_; using StorageAllocaBaseVisitor::VisitExpr_;
// override create token by getting token as prototype requirements. // override create token by getting token as prototype requirements.
...@@ -207,6 +230,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ...@@ -207,6 +230,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
} else { } else {
// Allocate a new token, // Allocate a new token,
StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok)); StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok));
allocated_tok->device_type = tok->device_type;
// ensure it never get de-allocated. // ensure it never get de-allocated.
allocated_tok->ref_counter += 1; allocated_tok->ref_counter += 1;
tokens.push_back(allocated_tok); tokens.push_back(allocated_tok);
...@@ -282,7 +306,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ...@@ -282,7 +306,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
// search for memory blocks larger than requested // search for memory blocks larger than requested
for (auto it = mid; it != end; ++it) { for (auto it = mid; it != end; ++it) {
StorageToken *tok = it->second; StorageToken *tok = it->second;
if (tok->device_id != prototype->device_id) continue; if (tok->device_type != prototype->device_type) continue;
CHECK_EQ(tok->ref_counter, 0); CHECK_EQ(tok->ref_counter, 0);
// Use exect matching strategy // Use exect matching strategy
tok->max_bytes = std::max(size, tok->max_bytes); tok->max_bytes = std::max(size, tok->max_bytes);
...@@ -295,7 +319,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ...@@ -295,7 +319,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
for (auto it = mid; it != begin;) { for (auto it = mid; it != begin;) {
--it; --it;
StorageToken *tok = it->second; StorageToken *tok = it->second;
if (tok->device_id != prototype->device_id) continue; if (tok->device_type != prototype->device_type) continue;
CHECK_EQ(tok->ref_counter, 0); CHECK_EQ(tok->ref_counter, 0);
// Use exect matching strategy // Use exect matching strategy
tok->max_bytes = std::max(size, tok->max_bytes); tok->max_bytes = std::max(size, tok->max_bytes);
...@@ -343,13 +367,12 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ...@@ -343,13 +367,12 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
std::unordered_map<const ExprNode*, std::vector<StorageToken*> > prototype_; std::unordered_map<const ExprNode*, std::vector<StorageToken*> > prototype_;
}; };
Map<Expr, Array<IntegerArray> > GraphPlanMemory(const Function& func) {
Map<Expr, Array<Integer> > GraphPlanMemory(const Function& func) {
return StorageAllocator().Plan(func); return StorageAllocator().Plan(func);
} }
TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory") TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory")
.set_body_typed<Map<Expr, Array<Integer> >(const Function&)>(GraphPlanMemory); .set_body_typed<Map<Expr, Array<IntegerArray> >(const Function&)>(GraphPlanMemory);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/relay/op/annotation/annotation.cc
* \brief Registration of annotation operators.
*/
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include "../type_relations.h"
#include "../../pass/alter_op_layout.h"
namespace tvm {
namespace relay {
// relay.annotation.on_device
TVM_REGISTER_NODE_TYPE(OnDeviceAttrs);

// Constructor exposed under "relay.op.annotation._make.on_device": builds a
// call to the `on_device` op on `data`, carrying `device_type` in its attrs.
TVM_REGISTER_API("relay.op.annotation._make.on_device")
.set_body_typed<Expr(Expr, int)>([](Expr data, int device_type) {
  auto attrs = make_node<OnDeviceAttrs>();
  attrs->device_type = device_type;
  // Look the op up once and reuse it across calls.
  static const Op& op = Op::Get("on_device");
  return CallNode::make(op, {data}, Attrs(attrs), {});
});

// Operator registration: one input, "Identity" type relation, opaque
// pattern, not stateful.
RELAY_REGISTER_OP("on_device")
.describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                               ElemwiseArbitraryLayout);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/relay/op/device_copy.cc
* \brief Crossing device data copy operator.
*
* The pattern of this operator is registered as kOpaque. Hence, it could be
* used as "barrier" to avoid fusing operators belonging to different devices.
*/
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include "type_relations.h"
#include "../pass/alter_op_layout.h"
namespace tvm {
namespace relay {
// relay.device_copy
TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs);

// Constructor exposed under "relay.op._make.device_copy": builds a call to
// the `device_copy` op on `data`, recording the source and destination
// device types in its attrs.
TVM_REGISTER_API("relay.op._make.device_copy")
.set_body_typed<Expr(Expr, int, int)>([](Expr data, int src_dev_type,
                                        int dst_dev_type) {
  auto attrs = make_node<DeviceCopyAttrs>();
  attrs->src_dev_type = src_dev_type;
  attrs->dst_dev_type = dst_dev_type;
  // Look the op up once and reuse it across calls.
  static const Op& op = Op::Get("device_copy");
  return CallNode::make(op, {data}, Attrs(attrs), {});
});

// Operator registration: one input, "Identity" type relation, not stateful.
// The pattern is kOpaque so the op can act as a fusion barrier between
// operators that belong to different devices.
RELAY_REGISTER_OP("device_copy")
.describe(R"code(
Copy data from one tensor to another. The source and destination might be
on different devices.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                               ElemwiseArbitraryLayout);
} // namespace relay
} // namespace tvm
...@@ -112,14 +112,19 @@ def test_plan_memory(): ...@@ -112,14 +112,19 @@ def test_plan_memory():
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
smap = relay.backend._backend.GraphPlanMemory(func) smap = relay.backend._backend.GraphPlanMemory(func)
storage_ids = set() storage_ids = set()
device_types = set()
for k, v in smap.items(): for k, v in smap.items():
for x in v: assert len(v) == 2
for x in v[0]:
storage_ids.add(x.value) storage_ids.add(x.value)
for x in v[1]:
device_types.add(x.value)
# Current rule requires vars have unique storage id # Current rule requires vars have unique storage id
# because we don't do inplace, we will need another # because we don't do inplace, we will need another
# two alternating temporary space. # two alternating temporary space.
assert len(storage_ids) == 4 assert len(storage_ids) == 4
assert len(device_types) == 1
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment