"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""

from .. import api as _api
from .. import expr as _expr
from .. import make as _make
from ..container import Array
from .. import ir_pass
from ..stmt import For
from .util import _internal_assert

#pylint: disable=redefined-builtin

# Maps every loop-intrinsic name accepted by `_range` to the `For` loop kind
# that `_range` hands back as `for_type`.
LOOP_INTRIN = {
    'range'       : For.Serial,
    'unroll'      : For.Unrolled,
    'parallel'    : For.Parallel,
    'vectorize'   : For.Vectorized,
    # NOTE(review): unlike the other entries this value is a 1-tuple;
    # presumably so that the consumer can tell `const_range` apart from
    # `unroll` -- confirm against the caller before "simplifying" it.
    'const_range' : (For.Unrolled, ),
}


def _range(annotation, args):
    """Translate a loop intrinsic call into a loop description.

    Parameters
    ----------
    annotation : str
        Name of the loop intrinsic; it is looked up in ``LOOP_INTRIN``, so an
        unknown name raises ``KeyError``.
    args : sequence
        Either ``(extent,)`` -- the loop then starts at an int32 zero -- or
        ``(low, high)``.

    Returns
    -------
    tuple
        ``(iter_var, low, ext, for_type)``. ``iter_var`` is always ``None``
        here, ``ext`` is ``high - low`` whenever ``low`` is not equal to the
        int32 zero constant, and ``for_type`` is the ``LOOP_INTRIN`` entry for
        ``annotation``.
    """
    # `len` is redefined as an intrinsic further down in this module, so the
    # builtin is not reachable by name; call `__len__` directly instead.
    n = args.__len__()
    zero = _api.const(0, dtype='int32')
    if n == 1:
        low, ext = zero, args[0]
    else:
        _internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
        low, ext = args[0], args[1]
    # Only rewrite the extent when the loop does not start at zero.
    if not ir_pass.Equal(low, zero):
        ext = ext - low
    for_type = LOOP_INTRIN[annotation]
    iter_var = None
    return iter_var, low, ext, for_type


range = unroll = vectorize = parallel = const_range = _range #pylint: disable=invalid-name


def bind(func_id, args):
    """Translate a ``bind`` call into a thread-bound loop description.

    Parameters
    ----------
    func_id : str
        Must be exactly ``"bind"``.
    args : sequence
        ``(thread_tag, extent)``; ``thread_tag`` must be a ``str`` and is
        passed to ``_api.thread_axis``.

    Returns
    -------
    tuple
        ``(iter_var, low, ext, for_type)`` in the same layout ``_range``
        returns: ``iter_var`` is the result of ``_api.thread_axis``, ``low``
        is an int32 zero constant, ``ext`` is ``args[1]`` and ``for_type`` is
        always ``None``.
    """
    _internal_assert(func_id == "bind", "This function cannot be directly invoked!")
    # `len` is redefined as an intrinsic in this module, hence `__len__`.
    _internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!")
    _internal_assert(isinstance(args[0], str),
                     "A loop bind's first argument should be a string!")
    iter_var = _api.thread_axis(args[0])
    low, ext = _api.const(0, "int32"), args[1]
    for_type = None
    return iter_var, low, ext, for_type


def _math_intrin(func_id, args):
    """Forward a math intrinsic call to the same-named function in ``intrin``.

    ``func_id`` is the attribute name looked up on the ``intrin`` module and
    ``args`` are unpacked as its positional arguments; whatever that function
    returns is returned unchanged.
    """
    # NOTE(review): the import is done at call time rather than at module
    # level; presumably to avoid an import cycle -- confirm before hoisting.
    from .. import intrin as _intrin_module
    target = getattr(_intrin_module, func_id)
    return target(*args)


sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name


def _min_max(func_id, args):
    """Build the two-operand min/max node for a ``min``/``max`` call.

    Parameters
    ----------
    func_id : str
        Name of the intrinsic being called; its title-cased form (e.g.
        ``'min'`` -> ``'Min'``) selects the constructor on ``_make``.
    args : sequence
        Exactly two operands.

    Returns
    -------
    object
        Whatever the selected ``_make`` constructor returns for the operands.
    """
    # `len` is redefined as an intrinsic in this module, hence `__len__`.
    _internal_assert(args.__len__() == 2, "Max/Min function should have 2 elements")
    return getattr(_make, func_id.title())(args[0], args[1])


min = max = _min_max #pylint: disable=invalid-name


def _allocate_tensor(func_id, args):
    """Handling TVM tensor allocation.
    You may refer hybrid.intrin.allocate for more details.

    Parameters
    ----------
    func_id : str
        Name of the intrinsic being called; ``'output_tensor'`` is treated
        specially (see below), any other name is treated as a plain allocate.
    args : sequence
        ``(shape[, dtype[, scope]])``. ``shape`` must convert to an ``Array``
        and contain only ``_expr.Expr`` items. ``dtype`` must be a ``str``
        starting with ``'int'`` or ``'float'`` and defaults to ``'float32'``.
        ``scope`` must be a ``str`` and may not be given for
        ``'output_tensor'``.

    Returns
    -------
    tuple
        ``(shape, dtype, scope)``. When no scope is given it is ``'output'``
        for ``'output_tensor'`` and ``'global'`` otherwise.
    """
    # `len` is redefined as an intrinsic in this module, hence `__len__`.
    n = args.__len__()
    _internal_assert(isinstance(_api.convert(args[0]), Array),
                     "allocate's first argument should be a tuple of shape!")
    shape = args[0]
    for i in shape:
        _internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
    if n > 1:
        _internal_assert(isinstance(args[1], str),
                         "The data type should be an str")
        # `startswith` accepts a tuple of prefixes: one call covers both.
        _internal_assert(args[1].startswith(('int', 'float')),
                         "The data type should be either int or float!")
        dtype = args[1]
    else:
        dtype = 'float32'
    if n > 2:
        _internal_assert(isinstance(args[2], str),
                         "The data scope should be an string")
        _internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
        scope = args[2]
    else:
        scope = 'global' if func_id != 'output_tensor' else 'output'
    return (shape, dtype, scope)


output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name


def len(func_id, args):
    """Interpret the ``len`` function.

    Parameters
    ----------
    func_id : str
        Must be exactly ``"len"``.
    args : sequence
        Exactly one argument: either an object whose ``__len__`` works, or an
        object with a one-dimensional ``shape``.

    Returns
    -------
    object
        ``_api.convert`` applied to ``args[0].__len__()`` when that succeeds,
        otherwise ``_api.convert`` applied to ``args[0].shape[0]``.
    """
    _internal_assert(args.__len__() == 1, "Only 1 argument is expected!")
    _internal_assert(func_id == "len", "This function cannot be directly invoked!")
    try:
        return _api.convert(args[0].__len__())
    except Exception: #pylint: disable=broad-except
        # Deliberately broad: any failure of the `__len__` path falls back to
        # the shape. `Exception` (not a bare `except`) is used so that
        # KeyboardInterrupt / SystemExit still propagate.
        _internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len")
        return _api.convert(args[0].shape[0])