container.py 2.33 KB
Newer Older
1
"""Container data structures used in TVM DSL."""
tqchen committed
2
from __future__ import absolute_import as _abs
3
from ._ffi.node import NodeBase, register_node
4
from . import _api_internal
tqchen committed
5 6 7

@register_node
class Array(NodeBase):
    """Array container of TVM.

    You do not need to create Array explicitly.
    Normally python list and tuple will be converted automatically
    to Array during tvm function call.
    You may get Array in return values of TVM function call.
    """
    def __getitem__(self, i):
        """Return the element at position ``i``, or a python list for a slice.

        Parameters
        ----------
        i : int or slice
            Integer position (negative values count from the end, as for
            built-in sequences) or a slice object.

        Returns
        -------
        The element fetched through ``_ArrayGetItem`` for an integer index,
        or a python ``list`` of such elements for a slice.

        Raises
        ------
        IndexError
            If an integer index falls outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if isinstance(i, slice):
            # slice.indices clamps the bounds and resolves negative values and
            # negative steps exactly like built-in sequences do.
            start, stop, step = i.indices(size)
            return [self[idx] for idx in range(start, stop, step)]

        # Raising IndexError at i == size is also what terminates iteration,
        # since no __iter__ is defined in this class and python falls back
        # to __getitem__.
        if i < -size or i >= size:
            raise IndexError("array index out of range")
        if i < 0:
            # Normalize so only non-negative positions reach _ArrayGetItem.
            i += size
        return _api_internal._ArrayGetItem(self, i)

    def __len__(self):
        """Return the number of elements as reported by ``_ArraySize``."""
        return _api_internal._ArraySize(self)
tqchen committed
28

29

tqchen committed
30 31
@register_node
class Map(NodeBase):
    """Map container of TVM.

    There is no need to construct a Map by hand.
    A python dict is normally converted automatically to Map during a tvm function call.
    Use convert to turn a dict[NodeBase -> NodeBase] into a Map.
    """
    def __getitem__(self, k):
        """Return the result of ``_MapGetItem`` for key ``k``."""
        value = _api_internal._MapGetItem(self, k)
        return value

    def __contains__(self, k):
        """Return True when ``_MapCount`` reports a non-zero count for ``k``."""
        occurrences = _api_internal._MapCount(self, k)
        return occurrences != 0

    def items(self):
        """Get the items from the map as a list of ``(key, value)`` tuples."""
        # The result of _MapItems is read as a flat sequence of alternating
        # key, value entries: even positions are keys, the following odd
        # positions are their values.
        flat = _api_internal._MapItems(self)
        pairs = []
        for pos in range(0, len(flat), 2):
            pairs.append((flat[pos], flat[pos + 1]))
        return pairs

    def __len__(self):
        """Return the size reported by ``_MapSize``."""
        size = _api_internal._MapSize(self)
        return size
tqchen committed
51

52 53

@register_node
class StrMap(Map):
    """A special map container that has str as key.

    Use convert to turn a dict[str -> NodeBase] into a Map.
    """
    def items(self):
        """Get the items from the map as a list of ``(key, value)`` tuples.

        Each key entry is unwrapped through its ``value`` attribute before
        being returned (presumably yielding a python str -- confirm against
        the node type used for keys).
        """
        # The result of _MapItems is read as a flat sequence of alternating
        # key, value entries.
        flat = _api_internal._MapItems(self)
        pairs = []
        for pos in range(0, len(flat), 2):
            pairs.append((flat[pos].value, flat[pos + 1]))
        return pairs


@register_node
class Range(NodeBase):
    """Represent range in TVM.

    There is no need to construct a Range by hand.
    Python list and tuple will be converted automatically to Range in api functions.
    """
73 74

@register_node
class LoweredFunc(NodeBase):
    """Represent a LoweredFunc in TVM."""
    # Integer tags exposed as class attributes.
    # NOTE(review): presumably these identify the kind of lowered function
    # (mixed / host-only / device-only) and must stay in sync with the
    # backend definition -- confirm before changing the values.
    MixedFunc, HostFunc, DeviceFunc = 0, 1, 2