# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
"""
A pass for manifesting explicit memory allocations.
"""
import numpy as np
from ..expr_functor import ExprMutator
from ..scope_builder import ScopeBuilder
from . import transform
from .. import op
from ... import DataType, register_func
from .. import ty, expr
from ..backend import compile_engine


def is_primitive(call):
    """Return whether ``call`` invokes an operator flagged as primitive.

    The check is purely attribute based: ``call`` must expose an ``op``
    whose ``attrs`` carries a ``Primitive`` field that converts to the
    integer 1. An object missing any link of that chain is reported as
    non-primitive instead of raising.

    Parameters
    ----------
    call : object
        The candidate call node.

    Returns
    -------
    bool
        True only when ``int(call.op.attrs.Primitive) == 1``.
    """
    return (hasattr(call, 'op')
            and hasattr(call.op, 'attrs')
            and hasattr(call.op.attrs, 'Primitive')
            and int(call.op.attrs.Primitive) == 1)

# TODO(@jroesch): port to c++ and unify with existing code
class LinearizeRetType:
    """A linear view of a Relay type.

    Handles a linear (depth-first) order for tensor types and nested
    tuple types: ``unpack`` flattens the type into a list of tensor types
    and ``pack`` rebuilds the nested structure from a flat sequence of
    values given in that same order.
    """

    def __init__(self, typ):
        """Initialize the linearizer.

        Parameters
        ----------
        typ : ty.TensorType or ty.TupleType
            The Relay type to view linearly.
        """
        self.typ = typ

    def unpack(self):
        """Return the linear representation of the type.

        Returns
        -------
        list
            The tensor types contained in ``self.typ``, in depth-first,
            left-to-right order.

        Raises
        ------
        Exception
            If a type other than a tensor or tuple type is encountered.
        """
        def _unpack(typ, out):
            # TODO(@jroesch): replace with new flattening pass
            if isinstance(typ, ty.TensorType):
                out.append(typ)
            elif isinstance(typ, ty.TupleType):
                for field_ty in typ.fields:
                    _unpack(field_ty, out)
            else:
                raise Exception("unsupported Relay type: {0}".format(typ))

        output = []
        _unpack(self.typ, output)
        return output

    def pack(self, seq):
        """Repack a linear sequence of values as a nested value.

        Parameters
        ----------
        seq : sequence
            Flat sequence of values ordered like the result of ``unpack``.
            Only ``len(seq)`` and integer indexing are used.

        Returns
        -------
        The single element when ``seq`` has length one; otherwise a value
        mirroring the nesting of ``self.typ``, with every tuple type
        rebuilt as an ``expr.Tuple``.

        Raises
        ------
        Exception
            If a type other than a tensor or tuple type is encountered.
        """
        # A single value is handed back unwrapped, whatever the type is.
        if len(seq) == 1:
            return seq[0]

        # Index of the next unconsumed element of the flat sequence. The
        # values are consumed in the same depth-first order ``unpack``
        # produces, so nested tuple types are rebuilt correctly instead of
        # indexing the flat sequence by per-tuple field position.
        position = 0

        def _pack(typ):
            nonlocal position
            if isinstance(typ, ty.TensorType):
                value = seq[position]
                position += 1
                return value
            if isinstance(typ, ty.TupleType):
                return expr.Tuple([_pack(field_ty) for field_ty in typ.fields])
            raise Exception("unsupported Relay type: {0}".format(typ))

        return _pack(self.typ)


class ManifestAllocPass(ExprMutator):
    """A pass for explicitly manifesting all memory allocations in Relay.

    Calls to primitive operators are rewritten so that their outputs are
    allocated up front (``alloc_storage`` + ``alloc_tensor``) and the
    operator is then invoked through ``invoke_tvm_op`` on those buffers.
    """

    def __init__(self, target_host):
        """Set up the memory operators and the initial binding scope.

        Parameters
        ----------
        target_host
            Target handed to ``lower_shape_func`` when a shape function
            has to be compiled for a dynamically shaped call.
        """
        self.invoke_tvm = op.memory.invoke_tvm_op
        self.alloc_storage = op.memory.alloc_storage
        self.alloc_tensor = op.memory.alloc_tensor
        self.shape_func = op.memory.shape_func
        # Stack of scope builders; new bindings go into the innermost one.
        self.scopes = [ScopeBuilder()]
        self.target_host = target_host
        # dtype used for every shape and size computation.
        self.compute_dtype = "int64"
        super().__init__()

    def current_scope(self):
        """Return the innermost scope builder currently being filled."""
        return self.scopes[-1]

    def shape_of(self, e):
        """Return a ``shape_of`` expression for ``e`` in the compute dtype."""
        return op.shape_of(e, self.compute_dtype)

    def visit_tuple(self, tup):
        """Visit every tuple field, let-binding constant fields.

        A field that is an ``expr.Constant`` after visiting is bound in
        the current scope under the name 'const' and replaced by the
        resulting variable.
        """
        scope = self.current_scope()
        new_fields = []
        for field in tup.fields:
            field = self.visit(field)
            if isinstance(field, expr.Constant):
                field = scope.let('const', field)
            new_fields.append(field)
        return expr.Tuple(new_fields)

    def compute_alignment(self, dtype):
        """Return the allocation alignment for ``dtype`` as an int64 constant.

        The alignment is the byte width of one element
        (``bits // 8 * lanes``), raised to at least 64.
        """
        dtype = DataType(dtype)
        # MAGIC CONSTANT FROM device_api.h
        align = max((dtype.bits // 8) * dtype.lanes, 64)
        return expr.const(align, dtype="int64")

    def compute_storage_in_relay(self, shape, dtype):
        """Build a Relay expression computing a tensor's storage size.

        Used when the shape is itself a Relay expression: the result is
        ``prod(shape) * ((bits * lanes + 7) / 8)`` expressed with Relay
        operators in the compute dtype.
        """
        dtype = DataType(dtype)
        els = op.prod(shape)
        num = expr.const(dtype.bits * dtype.lanes, self.compute_dtype)
        num = num + expr.const(7, self.compute_dtype)
        div = expr.const(8, self.compute_dtype)
        return els * (num / div)

    def compute_storage(self, tensor_type):
        """Return the storage size of a statically shaped tensor type.

        The size is the product of all dimensions times the per-element
        byte width ``(bits * lanes + 7) // 8``, returned as a constant in
        the compute dtype.
        """
        dtype = DataType(tensor_type.dtype)
        size = 1
        for dim in tensor_type.shape:
            size *= int(dim)
        size *= (dtype.bits * dtype.lanes + 7) // 8
        return expr.const(size, dtype=self.compute_dtype)

    def make_static_allocation(self, scope, tensor_type, i):
        """Allocate a tensor with a statically known shape.

        Parameters
        ----------
        scope : ScopeBuilder
            Scope receiving the ``storage_<i>`` and ``tensor_<i>`` bindings.
        tensor_type : ty.TensorType
            Type of the tensor; every dimension must convert to ``int``.
        i : int
            Index used to name the two bindings.

        Returns
        -------
        The variable bound to the allocated tensor.
        """
        shape = [int(sh) for sh in tensor_type.shape]
        if len(shape) == 0:
            # Scalars: build the empty array explicitly in the compute dtype.
            shape = expr.const(np.array([]).astype(
                self.compute_dtype), dtype=self.compute_dtype)
        else:
            shape = expr.const(np.array(shape), dtype=self.compute_dtype)
        size = self.compute_storage(tensor_type)
        alignment = self.compute_alignment(tensor_type.dtype)
        dtype = tensor_type.dtype
        sto = scope.let("storage_{0}".format(i), self.alloc_storage(
            size, alignment, dtype))
        # TODO(@jroesch): There is a bug with typing based on the constant shape.
        tensor = self.alloc_tensor(sto, shape, dtype, tensor_type.shape)
        return scope.let("tensor_{0}".format(i), tensor)

    def visit_let(self, let):
        """Rewrite a chain of nested lets into a single fresh scope.

        A new scope builder is pushed while the chain is processed so
        that allocations created for the bound values land in that scope.
        """
        scope = ScopeBuilder()

        self.scopes.append(scope)
        while isinstance(let, expr.Let):
            new_val = self.visit(let.value)
            scope.let(let.var, new_val)
            let = let.body

        new_body = self.visit(let)
        scope.ret(new_body)
        self.scopes.pop()

        return scope.get()

    def visit_call(self, call):
        """Manifest the output allocations of a primitive call.

        Non-primitive calls are delegated to ``ExprMutator``. For a
        primitive call the outputs are allocated in the current scope and
        the operator is invoked on them via ``invoke_tvm_op``. When the
        return type contains ``Any`` dimensions, the output shapes are
        first computed at runtime through the operator's shape function.

        Returns
        -------
        The expression referring to the call's output(s).

        Raises
        ------
        Exception
            If the shape function reports an unsupported input state.
        """
        if not is_primitive(call):
            return super().visit_call(call)

        # NOTE(review): the original comment here said that ANF makes
        # visiting the arguments unnecessary, yet they are visited; the
        # behavior is kept as is.
        scope = self.current_scope()
        new_args = [self.visit(arg) for arg in call.args]
        ins = expr.Tuple(new_args)
        ret_type = call.checked_type
        view = LinearizeRetType(ret_type)
        out_types = view.unpack()

        is_dynamic = ty.type_has_any(ret_type)
        # TODO(@jroesch): restore this code, more complex than it seems
        # for arg in call.args:
        #     is_dynamic = is_dynamic or arg.checked_type.is_dynamic()

        if not is_dynamic:
            outs = [self.make_static_allocation(scope, out_ty, i)
                    for i, out_ty in enumerate(out_types)]
            output = expr.Tuple(outs)
            invoke = self.invoke_tvm(call.op, ins, output)
            scope.let("", invoke)
            return view.pack(output)

        shape_func_ins = []
        engine = compile_engine.get()
        cfunc = engine.lower_shape_func(call.op, self.target_host)
        input_states = cfunc.shape_func_param_states

        is_inputs = []
        input_pos = 0
        for arg, state in zip(new_args, input_states):
            state = int(state)
            # Pass Shapes
            if state == 2:
                if isinstance(arg.type_annotation, ty.TupleType):
                    num_fields = len(arg.type_annotation.fields)
                    for j in range(num_fields):
                        let_in_arg = scope.let(
                            "in_arg_{0}".format(input_pos + j),
                            expr.TupleGetItem(arg, j))
                        sh_of = self.visit(self.shape_of(let_in_arg))
                        shape_func_ins.append(
                            scope.let("in_shape_{0}".format(input_pos + j), sh_of))
                    input_pos += num_fields
                else:
                    sh_of = self.visit(self.shape_of(arg))
                    shape_func_ins.append(
                        scope.let("in_shape_{0}".format(input_pos), sh_of))
                    input_pos += 1
                is_inputs.append(0)
            # Pass Inputs
            elif state == 1:
                new_arg = self.visit(arg)
                shape_func_ins.append(
                    scope.let("in_shape_{0}".format(input_pos), new_arg))
                input_pos += 1
                is_inputs.append(1)
            # TODO(@jroesch): handle 3rd case
            else:
                raise Exception("unsupported shape function input state")

        # Statically sized buffers that receive the computed output shapes.
        out_shapes = []
        for i, out in enumerate(cfunc.outputs):
            tt = ty.TensorType(out.shape, out.dtype)
            alloc = self.make_static_allocation(scope, tt, i)
            alloc = scope.let("shape_func_out_{0}".format(i), alloc)
            out_shapes.append(alloc)

        shape_call = self.shape_func(
            call.op,
            expr.Tuple(shape_func_ins),
            expr.Tuple(out_shapes), is_inputs)

        scope.let("shape_func", shape_call)

        # Enumerate explicitly so each storage gets its own index; relying
        # on the index leaked from the loop above named them all alike.
        storages = []
        for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)):
            size = self.compute_storage_in_relay(
                out_shape, out_type.dtype)
            alignment = self.compute_alignment(out_type.dtype)
            sto = scope.let("storage_{i}".format(i=i), self.alloc_storage(
                size, alignment, out_type.dtype))
            storages.append(sto)

        outs = []
        sh_ty_storage = zip(out_shapes, out_types, storages)
        for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage):
            alloc = self.alloc_tensor(
                storage,
                out_shape,
                out_type.dtype,
                out_type.shape)
            alloc = scope.let("out_{i}".format(i=i), alloc)
            outs.append(alloc)

        tuple_outs = expr.Tuple(outs)
        invoke = self.invoke_tvm(call.op, ins, tuple_outs)
        scope.let("", invoke)
        return outs[0] if len(outs) == 1 else tuple_outs


@transform.function_pass(opt_level=0)
class ManifestAlloc:
    """Function-pass wrapper that applies ManifestAllocPass to a function."""

    def __init__(self, target_host):
        # Kept so that each transformed function gets a mutator built for
        # the same target host.
        self.target_host = target_host

    def transform_function(self, func, mod, _):
        """Return ``func`` rewritten by a fresh ``ManifestAllocPass``.

        ``mod`` first imports "core.rly" from the standard library; the
        third argument is unused.
        """
        # TODO(@jroesch): Is there a way to do one shot initialization?
        # can we have def pass_init?
        mod.import_from_std("core.rly")
        return ManifestAllocPass(self.target_host).visit(func)


# Register the ManifestAlloc pass class under the global function name
# "relay.transform.ManifestAlloc".
register_func("relay.transform.ManifestAlloc", ManifestAlloc)