"""Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from . import _api_internal

@register_node
class Array(NodeBase):
    """Array container of TVM.

    You do not need to create Array explicitly.
    Normally python list and tuple will be converted automatically
    to Array during tvm function call.
    You may get Array in return values of TVM function call.
    """
    def __getitem__(self, i):
        """Return the element at index ``i``, or a list for a slice.

        Parameters
        ----------
        i : int or slice
            Element index. Negative values count from the end, as for a
            Python list. A slice may use any start/stop/step, including
            negative ones.

        Returns
        -------
        The element stored at ``i``; for a slice, a Python ``list`` of
        the selected elements.

        Raises
        ------
        IndexError
            If ``i`` is an int outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if isinstance(i, slice):
            # slice.indices resolves None, negative bounds and negative
            # steps exactly like list slicing, and always yields indices
            # in [0, size), so they can go straight to the backend.
            return [_api_internal._ArrayGetItem(self, idx)
                    for idx in range(*i.indices(size))]

        if i < -size or i >= size:
            raise IndexError("array index out of range")
        if i < 0:
            # Normalize so the backend only ever sees a non-negative
            # in-range index, as it did for the original positive case.
            i += size
        return _api_internal._ArrayGetItem(self, i)

    def __len__(self):
        """Return the number of elements in the array."""
        return _api_internal._ArraySize(self)


@register_node
class Map(NodeBase):
    """Map container of TVM.

    You do not need to create Map explicitly.
    Normally python dict will be converted automatically
    to Map during tvm function call.
    You may get Map in return values of TVM function call.
    """
    def __getitem__(self, k):
        """Return the value associated with key ``k``."""
        return _api_internal._MapGetItem(self, k)

    def __contains__(self, k):
        """Return True if ``k`` is a key of the map."""
        return _api_internal._MapCount(self, k) != 0

    def __iter__(self):
        """Iterate over the keys of the map.

        Defined explicitly because, without ``__iter__``, Python falls back
        to the legacy sequence protocol and calls ``__getitem__(0)``,
        ``__getitem__(1)``, ... -- i.e. looks up integer keys instead of
        walking the map's entries.
        """
        for key, _ in self.items():
            yield key

    def items(self):
        """Get the items from the map.

        Returns
        -------
        list of (key, value) tuples
        """
        akvs = _api_internal._MapItems(self)
        # _MapItems yields a flat sequence laid out as k0, v0, k1, v1, ...
        return [(akvs[i], akvs[i + 1]) for i in range(0, len(akvs), 2)]

    def __len__(self):
        """Return the number of entries in the map."""
        return _api_internal._MapSize(self)


@register_node
class Range(NodeBase):
    """Represent a range in TVM.

    You do not need to create Range explicitly.
    Python list and tuple will be converted automatically to Range in api functions.
    """

@register_node
class LoweredFunc(NodeBase):
    """Represent a LoweredFunc in TVM."""
    # Integer codes naming the kind of lowered function. Presumably these
    # must match the corresponding enum on the C++ side -- TODO confirm
    # before changing any value.
    MixedFunc = 0
    HostFunc = 1
    DeviceFunc = 2