Unverified commit 9a8ed5b7 by Jared Roesch, committed by GitHub

[Runtime][Relay][Cleanup] Clean up for memory pass to enable heterogeneous execution support. (#5324)

* Clean up type packing and unpacking for tuples.

* Clean up the memory pass using common helpers

* Clean up memory.cc

* Refactor pass

* Add doc strings

* Fix CPPlint

* Fix PyLint

* Fix

* Apply suggestions from code review

Co-Authored-By: Zhi <5145158+zhiics@users.noreply.github.com>

* Fix typo

Co-authored-by: Zhi <5145158+zhiics@users.noreply.github.com>
parent 92c78266
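The most visible API change in this diff is that `relay.op.memory.alloc_storage` now takes a device context, which the manifest-allocation pass fills with a default CPU context and records in the new `AllocStorageAttrs`. A minimal usage sketch follows (illustrative only, not part of the commit; it assumes a TVM build that includes this change):

```python
# Illustrative sketch of the new alloc_storage signature; not part of the diff.
# Assumes a TVM build containing this commit.
import tvm
from tvm import relay
from tvm.relay.op import memory

size = relay.const(1024, dtype="int64")      # bytes to allocate
alignment = relay.const(64, dtype="int64")   # required alignment

# Before this change: memory.alloc_storage(size, alignment, dtype_hint='float32')
# After this change the allocation is tied to a concrete device context:
storage = memory.alloc_storage(size, alignment, tvm.cpu(0), dtype_hint="float32")
print(storage)
```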
@@ -40,11 +40,11 @@ struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
   TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") {
     TVM_ATTR_FIELD(src_dev_type)
         .describe(
-            "The virutal device/context type where the op copies data from.")
+            "The virtual device/context type where the op copies data from.")
         .set_default(0);
     TVM_ATTR_FIELD(dst_dev_type)
         .describe(
-            "The virutal device/context type where the op copies data to.")
+            "The virtual device/context type where the op copies data to.")
         .set_default(0);
   }
 };
...
@@ -27,10 +27,37 @@
 #include <tvm/ir/attrs.h>
 #include <tvm/relay/expr.h>
 #include <string>
+#include <vector>
 
 namespace tvm {
 namespace relay {
 
+std::vector<TensorType> FlattenTupleType(const Type& type);
+
+std::vector<Expr> FromTupleType(const Type& type, const Expr& expr);
+
+Expr ToTupleType(const Type& t, const Array<Expr>& exprs);
+
+/*!
+ * \brief Options for allocating storage.
+ */
+struct AllocStorageAttrs : public tvm::AttrsNode<AllocStorageAttrs> {
+  DataType dtype;
+  int device_id;
+  int device_type;
+
+  TVM_DECLARE_ATTRS(AllocStorageAttrs, "relay.attrs.AllocStorageAttrs") {
+    TVM_ATTR_FIELD(dtype)
+        .describe(
+            "The dtype of the tensor to allocate.")
+        .set_default(DataType::Float(32, 1));
+    TVM_ATTR_FIELD(device_id)
+        .describe(
+            "The device id on which to allocate memory.");
+    TVM_ATTR_FIELD(device_type)
+        .describe(
+            "The device type on which to allocate memory.");
+  }
+};
+
 /*!
  * \brief Options for allocating tensors.
  */
...
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
 """Operators for manipulating low-level memory."""
 from __future__ import absolute_import as _abs
 from . import _make
@@ -23,6 +24,9 @@ def invoke_tvm_op(func, inputs, outputs):
     Parameters
     ----------
+    func : tvm.relay.Expr
+        The input expr.
+
     inputs : tvm.relay.Expr
         A tuple of the inputs to pass to the TVM function.
@@ -59,7 +63,7 @@ def alloc_tensor(storage, shape, dtype='float32', assert_shape=None):
     """
     return _make.alloc_tensor(storage, shape, dtype, assert_shape)
 
-def alloc_storage(size, alignment, dtype_hint='float32'):
+def alloc_storage(size, alignment, ctx, dtype_hint='float32'):
     """Allocate a piece of tensor storage.
 
     Parameters
@@ -76,7 +80,7 @@ def alloc_storage(size, alignment, dtype_hint='float32'):
     result : tvm.relay.Expr
         The alloc_storage expression.
     """
-    return _make.alloc_storage(size, alignment, dtype_hint)
+    return _make.alloc_storage(size, alignment, ctx, dtype_hint)
 
 def shape_func(func, inputs, outputs, dependent=False):
     """Invoke the shape function of the passed function.
@@ -96,3 +100,56 @@ def shape_func(func, inputs, outputs, dependent=False):
         The shape function expression.
     """
     return _make.shape_func(func, inputs, outputs, dependent)
+
+def flatten_tuple_type(ty):
+    """Return a sequence of the types contained in the tuple type in order.
+
+    Parameters
+    ----------
+    ty: tvm.Type
+        The type to flatten.
+
+    Returns
+    -------
+    result: List[tvm.Type]
+        The types in their linear order.
+    """
+    return _make.FlattenTupleType(ty)
+
+def from_tuple_type(ty, expr):
+    """Convert an expression with the given type into a sequence of expressions.
+    Each expression maps to a field of the tuple or nested tuples in linear
+    order.
+
+    Parameters
+    ----------
+    ty: tvm.Type
+        The type to unpack.
+
+    expr: tvm.relay.Expr
+        The expression from which to extract each sub-field.
+
+    Returns
+    -------
+    result: List[tvm.relay.Expr]
+        The list of sub-expressions.
+    """
+    return _make.FromTupleType(ty, expr)
+
+def to_tuple_type(ty, exprs):
+    """Pack the sequence of expressions into the nested tuple type.
+
+    Parameters
+    ----------
+    ty: tvm.Type
+        The type to pack with.
+
+    exprs: List[tvm.relay.Expr]
+        The expressions to pack back into the nested tuple type.
+
+    Returns
+    -------
+    result: tvm.relay.Expr
+        The packed tuple expression.
+    """
+    return _make.ToTupleType(ty, exprs)
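To make the intent of the three helpers above concrete, here is a small usage sketch (illustrative, not part of the commit) showing a nested tuple type being flattened, unpacked against an expression, and repacked; it assumes the `_make` bindings registered in memory.cc further down:

```python
# Illustrative sketch of flatten_tuple_type / from_tuple_type / to_tuple_type.
from tvm import relay
from tvm.relay.op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type

t1 = relay.TensorType((1, 4), "float32")
t2 = relay.TensorType((8,), "float32")
nested = relay.TupleType([t1, relay.TupleType([t2])])

# Linear view of the nested tuple type: [t1, t2].
print(flatten_tuple_type(nested))

# Unpack a value of that type into its leaf expressions (TupleGetItem chains).
x = relay.var("x", type_annotation=nested)
leaves = from_tuple_type(nested, x)

# Repack the leaves into an expression with the original nested structure.
print(to_tuple_type(nested, leaves))
```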
@@ -26,60 +26,14 @@ from .. import op
 from ... import DataType, register_func
 from .. import ty, expr
 from ..backend import compile_engine
+from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type
+from ... import cpu
 
 
 def is_primitive(call):
     return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \
         hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1
 
-# TODO(@jroesch): port to c++ and unify with existing code
-class LinearizeRetType:
-    """A linear view of a Relay type, handles a linear order
-    for nested tuples, and tensor types.
-    """
-
-    def __init__(self, typ):
-        """Initialize the linearizer."""
-        self.typ = typ
-
-    def unpack(self):
-        """Return the linear representation of the type."""
-        def _unpack(typ, out):
-            # TODO(@jroesch): replace with new flattening pass
-            if isinstance(typ, ty.TensorType):
-                out.append(typ)
-            elif isinstance(typ, ty.TupleType):
-                for field_ty in typ.fields:
-                    _unpack(field_ty, out)
-            else:
-                raise Exception("unsupported Relay type: {0}".format(typ))
-
-        output = []
-        _unpack(self.typ, output)
-        return output
-
-    def pack(self, seq):
-        """Repack a linear type as a nested type."""
-        def _pack(value, typ, out):
-            if isinstance(typ, ty.TensorType):
-                out.append(value)
-            elif isinstance(typ, ty.TupleType):
-                tuple_out = []
-                for i, field_ty in enumerate(typ.fields):
-                    _pack(value[i], field_ty, tuple_out)
-                out.append(expr.Tuple(tuple_out))
-            else:
-                raise Exception("unsupported Relay type: {0}".format(typ))
-
-        if len(seq) == 1:
-            return seq[0]
-        else:
-            out = []
-            _pack(seq, self.typ, out)
-            assert len(out) == 1, "must return fully packed type"
-            return out[0]
-
 
 class ManifestAllocPass(ExprMutator):
     """A pass for explictly manifesting all memory allocations in Relay."""
@@ -90,6 +44,7 @@ class ManifestAllocPass(ExprMutator):
         self.shape_func = op.memory.shape_func
         self.scopes = [ScopeBuilder()]
         self.target_host = target_host
+        self.default_context = cpu(0)
        self.compute_dtype = "int64"
        super().__init__()
@@ -147,7 +102,7 @@ class ManifestAllocPass(ExprMutator):
         alignment = self.compute_alignment(tensor_type.dtype)
         dtype = tensor_type.dtype
         sto = scope.let("storage_{0}".format(i), self.alloc_storage(
-            size, alignment, dtype))
+            size, alignment, self.default_context, dtype))
         # TODO(@jroesch): There is a bug with typing based on the constant shape.
         tensor = self.alloc_tensor(sto, shape, dtype, tensor_type.shape)
         return scope.let("tensor_{0}".format(i), tensor)
@@ -167,6 +122,83 @@ class ManifestAllocPass(ExprMutator):
 
         return scope.get()
 
+    def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
+        """Generate the code for invoking a TVM op with a dynamic shape."""
+        shape_func_ins = []
+        engine = compile_engine.get()
+        cfunc = engine.lower_shape_func(func, self.target_host)
+        input_states = cfunc.shape_func_param_states
+
+        is_inputs = []
+        input_pos = 0
+        for i, (arg, state) in enumerate(zip(new_args, input_states)):
+            state = int(state)
+            # Pass Shapes
+            if state == 2:
+                for j, subexp in enumerate(from_tuple_type(arg.type_annotation, arg)):
+                    let_in_arg = scope.let("in_arg_{0}".format(input_pos + j), subexp)
+                    sh_of = self.visit(self.shape_of(let_in_arg))
+                    shape_func_ins.append(
+                        scope.let("in_shape_{0}".format(input_pos + j), sh_of))
+                    input_pos += 1
+                is_inputs.append(0)
+            # Pass Inputs
+            elif state == 1:
+                new_arg = self.visit(arg)
+                shape_func_ins.append(
+                    scope.let("in_shape_{0}".format(input_pos), new_arg))
+                input_pos += 1
+                is_inputs.append(1)
+            else:
+                # TODO(@jroesch): handle 3rd case
+                raise Exception("unsupported shape function input state")
+
+        out_shapes = []
+        for i, out in enumerate(cfunc.outputs):
+            tt = ty.TensorType(out.shape, out.dtype)
+            alloc = self.make_static_allocation(scope, tt, i)
+            alloc = scope.let("shape_func_out_{0}".format(i), alloc)
+            out_shapes.append(alloc)
+
+        shape_call = self.shape_func(
+            func,
+            expr.Tuple(shape_func_ins),
+            expr.Tuple(out_shapes), is_inputs)
+
+        scope.let("shape_func", shape_call)
+
+        storages = []
+        for out_shape, out_type in zip(out_shapes, out_types):
+            size = self.compute_storage_in_relay(
+                out_shape, out_type.dtype)
+            alignment = self.compute_alignment(out_type.dtype)
+            sto = scope.let("storage_{i}".format(i=i), self.alloc_storage(
+                size, alignment, self.default_context, out_type.dtype))
+            storages.append(sto)
+
+        outs = []
+        sh_ty_storage = zip(out_shapes, out_types, storages)
+        for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage):
+            alloc = self.alloc_tensor(
+                storage,
+                out_shape,
+                out_type.dtype,
+                out_type.shape)
+            alloc = scope.let("out_{i}".format(i=i), alloc)
+            outs.append(alloc)
+
+        tuple_outs = expr.Tuple(outs)
+        invoke = self.invoke_tvm(func, ins, tuple_outs)
+        scope.let("", invoke)
+        return to_tuple_type(ret_type, tuple_outs.fields)
+
+    def is_dynamic(self, ret_type):
+        is_dynamic = ty.type_has_any(ret_type)
+        # TODO(@jroesch): restore this code, more complex then it seems
+        # for arg in call.args:
+        #     is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
+        return is_dynamic
+
     def visit_call(self, call):
         if is_primitive(call):
             # Because we are in ANF we do not need to visit the arguments.
@@ -174,90 +206,13 @@ class ManifestAllocPass(ExprMutator):
             new_args = [self.visit(arg) for arg in call.args]
             ins = expr.Tuple(new_args)
             ret_type = call.checked_type
-            view = LinearizeRetType(ret_type)
-            out_types = view.unpack()
-
-            is_dynamic = ty.type_has_any(ret_type)
-            # TODO(@jroesch): restore this code, more complex then it seems
-            # for arg in call.args:
-            #     is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
-
-            if is_dynamic:
-                shape_func_ins = []
-                engine = compile_engine.get()
-                cfunc = engine.lower_shape_func(call.op, self.target_host)
-                input_states = cfunc.shape_func_param_states
-
-                is_inputs = []
-                input_pos = 0
-                for i, (arg, state) in enumerate(zip(new_args, input_states)):
-                    state = int(state)
-                    # Pass Shapes
-                    if state == 2:
-                        if isinstance(arg.type_annotation, ty.TupleType):
-                            for j in range(len(arg.type_annotation.fields)):
-                                let_in_arg = scope.let("in_arg_{0}".format(input_pos + j),
-                                                       expr.TupleGetItem(arg, j))
-                                sh_of = self.visit(self.shape_of(let_in_arg))
-                                shape_func_ins.append(
-                                    scope.let("in_shape_{0}".format(input_pos + j), sh_of))
-                            input_pos += len(arg.type_annotation.fields)
-                        else:
-                            sh_of = self.visit(self.shape_of(arg))
-                            shape_func_ins.append(
-                                scope.let("in_shape_{0}".format(input_pos), sh_of))
-                            input_pos += 1
-                        is_inputs.append(0)
-                    # Pass Inputs
-                    elif state == 1:
-                        new_arg = self.visit(arg)
-                        shape_func_ins.append(
-                            scope.let("in_shape_{0}".format(input_pos), new_arg))
-                        input_pos += 1
-                        is_inputs.append(1)
-                    # TODO(@jroesch): handle 3rd case
-                    else:
-                        raise Exception("unsupported shape function input state")
-
-                out_shapes = []
-                for i, out in enumerate(cfunc.outputs):
-                    tt = ty.TensorType(out.shape, out.dtype)
-                    alloc = self.make_static_allocation(scope, tt, i)
-                    alloc = scope.let("shape_func_out_{0}".format(i), alloc)
-                    out_shapes.append(alloc)
-
-                shape_call = self.shape_func(
-                    call.op,
-                    expr.Tuple(shape_func_ins),
-                    expr.Tuple(out_shapes), is_inputs)
-
-                scope.let("shape_func", shape_call)
-
-                storages = []
-                for out_shape, out_type in zip(out_shapes, out_types):
-                    size = self.compute_storage_in_relay(
-                        out_shape, out_type.dtype)
-                    alignment = self.compute_alignment(out_type.dtype)
-                    sto = scope.let("storage_{i}".format(i=i), self.alloc_storage(
-                        size, alignment, out_type.dtype))
-                    storages.append(sto)
-
-                outs = []
-                sh_ty_storage = zip(out_shapes, out_types, storages)
-                for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage):
-                    alloc = self.alloc_tensor(
-                        storage,
-                        out_shape,
-                        out_type.dtype,
-                        out_type.shape)
-                    alloc = scope.let("out_{i}".format(i=i), alloc)
-                    outs.append(alloc)
-
-                tuple_outs = expr.Tuple(outs)
-                invoke = self.invoke_tvm(call.op, ins, tuple_outs)
-                scope.let("", invoke)
-                return outs[0] if len(outs) == 1 else tuple_outs
+            out_types = flatten_tuple_type(ret_type)
+
+            if self.is_dynamic(ret_type):
+                # Handle dynamic case.
+                return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type)
             else:
+                # Handle static case.
                 outs = []
                 for i, out_ty in enumerate(out_types):
                     out = self.make_static_allocation(scope, out_ty, i)
@@ -266,7 +221,7 @@ class ManifestAllocPass(ExprMutator):
                 output = expr.Tuple(outs)
                 invoke = self.invoke_tvm(call.op, ins, output)
                 scope.let("", invoke)
-                return view.pack(output)
+                return to_tuple_type(ret_type, output.fields)
             else:
                 return super().visit_call(call)
...
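The refactored `visit_call` now dispatches on the new `is_dynamic` helper, which only checks whether the return type contains `Any`. A small sketch of the check it wraps (illustrative only, using the same `relay.ty.type_has_any` the pass relies on):

```python
# Illustrative sketch of the dynamic-shape check behind is_dynamic().
from tvm import relay
from tvm.relay import ty

static_t = relay.TensorType((2, 2), "float32")
dynamic_t = relay.TensorType((relay.Any(), 2), "float32")

print(ty.type_has_any(static_t))   # False: handled by the static allocation path
print(ty.type_has_any(dynamic_t))  # True: routed through dynamic_invoke and the shape function
```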
@@ -579,7 +579,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       auto alignment_register = last_register_;
 
       // Get the dtype hint from the attributes.
-      auto alloc_attrs = attrs.as<AllocTensorAttrs>();
+      auto alloc_attrs = attrs.as<AllocStorageAttrs>();
       CHECK(alloc_attrs != nullptr)
           << "must be the alloc tensor attrs";
       auto dtype = alloc_attrs->dtype;
...
@@ -23,18 +23,19 @@
  */
 #include <topi/elemwise.h>
+#include <tvm/relay/attrs/memory.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
-#include <tvm/relay/attrs/memory.h>
-#include "../op_common.h"
 #include "../../transforms/infer_layout_util.h"
+#include "../op_common.h"
 #include "../type_relations.h"
 
 namespace tvm {
 namespace relay {
 
+TVM_REGISTER_NODE_TYPE(AllocStorageAttrs);
 TVM_REGISTER_NODE_TYPE(AllocTensorAttrs);
 TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);
@@ -42,9 +43,11 @@ TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);
 // We should consider a better solution, i.e the type relation
 // being able to see the arguments as well?
 TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage")
-    .set_body_typed([](Expr size, Expr alignment, DataType dtype) {
-      auto attrs = make_object<AllocTensorAttrs>();
-      attrs->dtype = dtype;
+    .set_body_typed([](Expr size, Expr alignment, TVMContext ctx, DataType dtype_hint) {
+      auto attrs = make_object<AllocStorageAttrs>();
+      attrs->dtype = dtype_hint;
+      attrs->device_id = ctx.device_id;
+      attrs->device_type = ctx.device_type;
       static const Op& op = Op::Get("memory.alloc_storage");
       return Call(op, {size, alignment}, Attrs(attrs), {});
     });
@@ -88,29 +91,28 @@ RELAY_REGISTER_OP("memory.alloc_storage")
     });
 
 TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor")
-    .set_body_typed(
-        [](Expr storage, tvm::relay::Expr shape, DataType dtype, Array<IndexExpr> assert_shape) {
+    .set_body_typed([](Expr storage, tvm::relay::Expr shape, DataType dtype,
+                       Array<IndexExpr> assert_shape) {
       auto attrs = make_object<AllocTensorAttrs>();
       attrs->dtype = dtype;
       if (assert_shape.defined()) {
         attrs->assert_shape = assert_shape;
       } else {
         attrs->const_shape = Downcast<Constant>(shape);
       }
       static const Op& op = Op::Get("memory.alloc_tensor");
       return Call(op, {storage, shape}, Attrs(attrs), {});
     });
 
 std::vector<int64_t> FromConstShape(Constant konst) {
   runtime::NDArray shape = konst->data;
   std::vector<int64_t> raw_shape;
   DLTensor tensor = shape.ToDLPack()->dl_tensor;
   CHECK_EQ(tensor.ndim, 1u);
-  CHECK_EQ(tensor.dtype.code, 0U)
-      << "found " << tensor.dtype.code;
+  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
   CHECK(tensor.dtype.bits == 64 || tensor.dtype.bits == 32)
       << "found " << static_cast<int>(tensor.dtype.bits);
   if (tensor.dtype.bits == 32) {
     const int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
@@ -209,10 +211,9 @@ bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
 }
 
 TVM_REGISTER_GLOBAL("relay.op.memory._make.invoke_tvm_op")
-    .set_body_typed(
-        [](Expr func, Expr inputs, Expr outputs) {
-          return Call(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
-        });
+    .set_body_typed([](Expr func, Expr inputs, Expr outputs) {
+      return Call(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
+    });
 
 RELAY_REGISTER_OP("memory.invoke_tvm_op")
     .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
@@ -257,37 +258,94 @@ RELAY_REGISTER_OP("memory.kill")
     });
 
 TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_func")
-    .set_body_typed(
-        [](Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) {
+    .set_body_typed([](Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) {
       static const Op& op = Op::Get("memory.shape_func");
       auto attrs = make_object<ShapeFuncAttrs>();
       attrs->is_input = is_input;
       return Call(op, {func, inputs, outputs}, Attrs(attrs), {});
     });
 
-static void FlattenTypeAux(const Type& type, std::vector<TensorType>* out) {
+static void FlattenTupleTypeAux(const Type& type, std::vector<TensorType>* out) {
   if (auto tt = type.as<TensorTypeNode>()) {
     out->push_back(GetRef<TensorType>(tt));
   } else if (auto tuple_ty = type.as<TupleTypeNode>()) {
     for (auto field : tuple_ty->fields) {
-      FlattenTypeAux(field, out);
+      FlattenTupleTypeAux(field, out);
     }
   } else {
     LOG(FATAL) << "unsupported " << type;
   }
 }
 
-std::vector<TensorType> FlattenType(const Type& type) {
+std::vector<TensorType> FlattenTupleType(const Type& type) {
   std::vector<TensorType> out;
-  FlattenTypeAux(type, &out);
+  FlattenTupleTypeAux(type, &out);
   return out;
 }
 
-Expr PackByType(const Type& t, const Array<Expr>& exprs) {
-  LOG(FATAL) << "NYI";
-  return Expr();
-}
+static void FromTupleTypeAux(const Type& type, const Expr& expr, std::vector<Expr>* out) {
+  if (type.as<TensorTypeNode>()) {
+    out->push_back(expr);
+  } else if (auto tuple_ty = type.as<TupleTypeNode>()) {
+    for (size_t i = 0; i < tuple_ty->fields.size(); i++) {
+      FromTupleTypeAux(tuple_ty->fields[i], TupleGetItem(expr, i), out);
+    }
+  } else {
+    LOG(FATAL) << "unsupported " << type;
+  }
+}
+
+std::vector<Expr> FromTupleType(const Type& type, const Expr& expr) {
+  std::vector<Expr> out;
+  FromTupleTypeAux(type, expr, &out);
+  return out;
+}
+
+static void ToTupleTypeAux(const Type& type, const std::vector<Expr>& exprs, int* index,
+                           std::vector<Expr>* out) {
+  if (type.as<TensorTypeNode>()) {
+    out->push_back(exprs[*index]);
+    *index += 1;
+  } else if (auto tuple_ty = type.as<TupleTypeNode>()) {
+    std::vector<Expr> tuple_out;
+    for (size_t i = 0; i < tuple_ty->fields.size(); i++) {
+      ToTupleTypeAux(tuple_ty->fields[i], exprs, index, &tuple_out);
+    }
+    out->push_back(Tuple(tuple_out));
+  } else {
+    LOG(FATAL) << "unsupported " << type;
+  }
+}
+
+// Pack the sequence of expressions according to the provided TupleType.
+Expr ToTupleType(const Type& t, const std::vector<Expr>& exprs) {
+  if (t.as<TensorTypeNode>() && exprs.size() == 1) {
+    return exprs[0];
+  } else {
+    std::vector<Expr> out;
+    int index = 0;
+    ToTupleTypeAux(t, exprs, &index, &out);
+    return out[0];
+  }
+}
+
+TVM_REGISTER_GLOBAL("relay.op.memory._make.FlattenTupleType")
+    .set_body_typed([](Type type) {
+      auto types = FlattenTupleType(type);
+      return Array<Type>(types.begin(), types.end());
+    });
+
+TVM_REGISTER_GLOBAL("relay.op.memory._make.FromTupleType")
+    .set_body_typed([](Type type, Expr expr) {
+      auto exprs = FromTupleType(type, expr);
+      return Array<Expr>(exprs.begin(), exprs.end());
+    });
+
+TVM_REGISTER_GLOBAL("relay.op.memory._make.ToTupleType")
+    .set_body_typed([](Type t, Array<Expr> array) {
+      return ToTupleType(t, std::vector<Expr>(array.begin(), array.end()));
+    });
 
 bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 4u);
@@ -298,8 +356,8 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK(func_type != nullptr);
 
   auto tuple = TupleType(func_type->arg_types);
-  auto in_types = FlattenType(tuple);
-  auto out_types = FlattenType(func_type->ret_type);
+  auto in_types = FlattenTupleType(tuple);
+  auto out_types = FlattenTupleType(func_type->ret_type);
   Array<Type> shape_func_ins, shape_func_outs;
   for (size_t i = 0; i < in_types.size(); i++) {
...
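The three C++ helpers all walk the type depth-first, left to right, so the leaf order produced by FlattenTupleType, FromTupleType, and ToTupleType always agrees. A pure-Python model of that ordering follows (illustrative only, no TVM APIs):

```python
# Pure-Python model of the depth-first, left-to-right flattening used above.
# Types are modelled as nested tuples of strings; this is not TVM code.
def flatten(t):
    if isinstance(t, tuple):
        out = []
        for field in t:
            out.extend(flatten(field))
        return out
    return [t]

nested = ("t0", ("t1", "t2"), "t3")
print(flatten(nested))  # ['t0', 't1', 't2', 't3'] -- the same leaf order the C++ helpers produce
```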