container.py 2.91 KB
Newer Older
1
"""Container data structures used in TVM DSL."""
tqchen committed
2
from __future__ import absolute_import as _abs
3
from ._ffi.node import NodeBase, register_node
4
from . import _api_internal
tqchen committed
5 6 7

@register_node
class Array(NodeBase):
    """Array container of TVM.

    You do not need to create Array explicitly.
    Normally python list and tuple will be converted automatically
    to Array during tvm function call.
    You may get Array in return values of TVM function call.
    """
    def __getitem__(self, i):
        """Return the element at index ``i``, or a list of elements for a slice.

        Parameters
        ----------
        i : int or slice
            Integer index (negative values count from the end), or a slice.
            Slices follow python list semantics: bounds are clamped to the
            array size and negative steps are supported.

        Returns
        -------
        The element at ``i``; for a slice, a python list of the selected elements.

        Raises
        ------
        IndexError
            If an integer index is out of range.
        ValueError
            If a slice has a step of zero.
        """
        size = len(self)
        if isinstance(i, slice):
            # slice.indices normalizes negative bounds, clamps out-of-range
            # bounds and resolves default start/stop for negative steps, so every
            # produced index is already in [0, size) and can be fetched directly.
            return [_api_internal._ArrayGetItem(self, idx)
                    for idx in range(*i.indices(size))]

        if i < -size or i >= size:
            raise IndexError("Array index out of range. Array size: {}, got index {}"
                             .format(size, i))
        if i < 0:
            i += size
        return _api_internal._ArrayGetItem(self, i)

    def __len__(self):
        """Return the number of elements in the array."""
        return _api_internal._ArraySize(self)
tqchen committed
35

36

tqchen committed
37
@register_node
class EnvFunc(NodeBase):
    """Environment function.

    A global function object which can be serialized by its name.
    """
    def __call__(self, *args):
        """Invoke the environment function with ``args`` and return the result."""
        invoke = _api_internal._EnvFuncCall
        return invoke(self, *args)

    @property
    def func(self):
        """The packed function obtained via ``_EnvFuncGetPackedFunc``."""
        getter = _api_internal._EnvFuncGetPackedFunc
        return getter(self)


@register_node
class Map(NodeBase):
    """Map container of TVM.

    You do not need to create Map explicitly.
    Normally python dict will be converted automatically to Map during tvm function call.
    You can use convert to turn a dict[NodeBase->NodeBase] into a Map.
    """
    def __getitem__(self, k):
        """Return the value stored under key ``k``."""
        return _api_internal._MapGetItem(self, k)

    def __contains__(self, k):
        """Return True if ``k`` is a key of the map."""
        return _api_internal._MapCount(self, k) != 0

    def items(self):
        """Get the items from the map.

        Returns
        -------
        items : list of (key, value) tuples
        """
        # _MapItems is consumed as a flat sequence of alternating keys and
        # values: [k0, v0, k1, v1, ...].
        akvs = _api_internal._MapItems(self)
        return [(akvs[i], akvs[i + 1]) for i in range(0, len(akvs), 2)]

    def __len__(self):
        """Return the number of entries in the map."""
        return _api_internal._MapSize(self)
tqchen committed
72

73 74

@register_node
class StrMap(Map):
    """A special map container whose keys are strings.

    You can use convert to turn a dict[str->NodeBase] into a Map.
    """
    def items(self):
        """Get the items from the map.

        Returns
        -------
        items : list of (key, value) tuples, where each key is the ``.value``
            of the stored key node.
        """
        flat = _api_internal._MapItems(self)
        pairs = []
        for pos in range(0, len(flat), 2):
            pairs.append((flat[pos].value, flat[pos + 1]))
        return pairs


@register_node
class Range(NodeBase):
    """Represent a range in TVM.

    You do not need to create Range explicitly.
    Python list and tuple are converted to Range automatically in api functions.
    """
94 95

@register_node
class LoweredFunc(NodeBase):
    """Represent a LoweredFunc in TVM."""
    # Integer tags distinguishing kinds of lowered function; the exact
    # meaning of each kind is defined on the C++ side (not visible here).
    MixedFunc, HostFunc, DeviceFunc = 0, 1, 2