# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from . import _api_internal

@register_node
class Array(NodeBase):
    """Array container of TVM.

    You do not need to create Array explicitly.
    Normally python list and tuple will be converted automatically
    to Array during tvm function call.
    You may get Array in return values of TVM function call.
    """
    def __getitem__(self, i):
        """Return the element at index ``i``, or a list of elements for a slice.

        Parameters
        ----------
        i : int or slice
            Element index (negative values count from the end) or a slice.
            Slices follow python list semantics: bounds are clamped to the
            array and negative steps are supported.

        Returns
        -------
        The element at ``i`` or, when ``i`` is a slice, a python ``list``
        of the selected elements.

        Raises
        ------
        IndexError
            If an integer index is out of range.
        ValueError
            If a slice step of zero is given.
        """
        # len() goes through the FFI, so query it once.
        size = len(self)
        if isinstance(i, slice):
            # slice.indices clamps start/stop to [0, size] and picks correct
            # defaults for negative steps, matching built-in list slicing.
            return [self[idx] for idx in range(*i.indices(size))]

        if i < -size or i >= size:
            raise IndexError("Array index out of range. Array size: {}, got index {}"
                             .format(size, i))
        if i < 0:
            i += size
        return _api_internal._ArrayGetItem(self, i)

    def __len__(self):
        """Return the number of elements in the array."""
        return _api_internal._ArraySize(self)


@register_node
class EnvFunc(NodeBase):
    """Global function object addressed by name.

    Because the function is identified by its name, it can be serialized.
    """
    def __call__(self, *args):
        """Invoke the environment function with ``args``."""
        invoke = _api_internal._EnvFuncCall
        return invoke(self, *args)

    @property
    def func(self):
        """The function object obtained from ``_EnvFuncGetPackedFunc``."""
        fetch_packed = _api_internal._EnvFuncGetPackedFunc
        return fetch_packed(self)


@register_node
class Map(NodeBase):
    """Map container of TVM.

    You do not need to create Map explicitly.
    Normally python dict will be converted automatically to Map during tvm function call.
    You can use convert to turn a dict[NodeBase -> NodeBase] into a Map.
    """
    def __getitem__(self, k):
        """Return the value stored under key ``k``."""
        return _api_internal._MapGetItem(self, k)

    def __contains__(self, k):
        """Return True if ``k`` is a key of the map."""
        return _api_internal._MapCount(self, k) != 0

    def items(self):
        """Get the items from the map.

        Returns
        -------
        list of (key, value) tuples
        """
        # _MapItems yields keys and values flattened as [k0, v0, k1, v1, ...].
        akvs = _api_internal._MapItems(self)
        return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]

    def get(self, key, default=None):
        """Return the value for ``key`` if it is in the map, else ``default``.

        Parameters
        ----------
        key : object
            The key to look up.
        default : object, optional
            Value returned when ``key`` is absent (``None`` by default).
        """
        return self[key] if key in self else default

    def __len__(self):
        """Return the number of entries in the map."""
        return _api_internal._MapSize(self)


@register_node
class StrMap(Map):
    """Map container whose keys are strings.

    A dict[str->NodeBase] can be turned into one of these with convert.
    """
    def items(self):
        """Get the items from the map"""
        # Reuse the base pairing, then unwrap each key through its ``.value``.
        return [(key.value, val) for key, val in Map.items(self)]


@register_node
class Range(NodeBase):
    """A range node in TVM.

    There is usually no need to build a Range by hand: API functions
    automatically convert Python lists and tuples into a Range.
    """


@register_node
class LoweredFunc(NodeBase):
    """A LoweredFunc node in TVM."""
    # Integer function-kind codes. NOTE(review): presumably mirrored by the
    # native side, so keep the numeric values unchanged.
    MixedFunc, HostFunc, DeviceFunc = 0, 1, 2