"""Internal utilities for parsing Python subset to HalideIR"""

import ast
import inspect
import logging
import sys
import numpy
from .intrin import HYBRID_GLOBALS
from .._ffi.base import numeric_types
from .. import api as _api
from .. import make as _make
from .. import expr as _expr
from ..tensor import Tensor


#pylint: disable=invalid-name
# Argument types that _is_tvm_arg_types treats as numpy-style inputs:
# every type in numeric_types, plus numpy.ndarray.
np_arg_types = tuple(numeric_types) + (numpy.ndarray,)
# Argument types that _is_tvm_arg_types treats as TVM inputs.
tvm_arg_types = (Tensor, _expr.Var, _expr.ConstExpr)
# HalideIR immediate node types (integer, float and unsigned-integer).
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)

def _internal_assert(cond, err):
    """Simplify the code segment like if not XXX then raise an error"""
    if not cond:
        raise ValueError(err)


# Useful constants. To avoid runtime dependencies, we use function calls to return them.
def make_nop():
    """Build a HalideIR ``Evaluate`` statement over the int32 constant 0.

    Evaluating a bare constant does nothing, so the returned node serves as
    a 'no operation' placeholder.
    """
    zero = _api.const(0, dtype='int32')
    return _make.Evaluate(zero)


def is_docstring(node):
    """Checks if a Python AST node is a docstring.

    A docstring statement is an ``ast.Expr`` whose value is a string literal.
    From Python 3.8 on, every literal is parsed as ``ast.Constant`` and the
    legacy ``ast.Str`` alias is deprecated (and later removed), so the check
    inspects ``ast.Constant`` there and only falls back to ``ast.Str`` on
    older interpreters.

    Returns True if ``node`` is a bare string-literal expression statement,
    False otherwise (including for bytes or numeric literals).
    """
    if not isinstance(node, ast.Expr):
        return False
    if sys.version_info >= (3, 8):
        # Avoid touching ast.Str here: it warns on 3.12+ and is gone in 3.14.
        return isinstance(node.value, ast.Constant) and isinstance(node.value.value, str)
    return isinstance(node.value, ast.Str)


def _pruned_source(func):
    """Prune source code's extra leading spaces"""
40 41 42 43 44 45 46 47 48 49 50
    try:
        lines = inspect.getsource(func).split('\n')
        leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
        lines = [line[leading_space:] for line in lines]
        return '\n'.join(lines)
    except IOError as err:
        if sys.version_info[0] == 2 and str(err) == 'could not get source code':
            logging.log(logging.CRITICAL, \
                        'This module is not fully operated under Python2... ' \
                        'Please move to Python3!')
            raise err


def _is_tvm_arg_types(args):
    """Decide whether ``args`` is a list of TVM arguments or a list of numpy arguments.

    The kind of ``args[0]`` selects which check applies to the whole list:
    - returns True when every element is an instance of ``tvm_arg_types``;
    - returns False when every element is an instance of ``np_arg_types``.
    A mismatching element raises ValueError (through ``_internal_assert``).
    """
    if isinstance(args[0], tvm_arg_types):
        for arg in args[1:]:
            _internal_assert(isinstance(arg, tvm_arg_types),
                             "Expecting a Var, Tensor or ConstExpr instance but %s get!"
                             % str(type(arg)))
        return True

    # Not a TVM argument list: every element, the first included, must be numpy-like.
    for arg in args:
        _internal_assert(isinstance(arg, np_arg_types),
                         "Expect a numpy type but %s get!" % str(type(arg)))
    return False


def _enter_hybrid_runtime(func):
    """Install the hybrid runtime names into ``func``'s global namespace.

    Every entry of ``HYBRID_GLOBALS`` is written into ``func.__globals__``.
    If a name was already bound there, its previous value is recorded first.
    The returned list of ``(name, previous_value)`` pairs is what
    ``_restore_runtime`` expects in order to undo the change.
    """
    scope = func.__globals__
    shadowed = []
    for name, value in list(HYBRID_GLOBALS.items()):
        if name in scope:
            shadowed.append((name, scope[name]))
        scope[name] = value
    return shadowed


def _restore_runtime(func, intersect):
    """Roll back the hybrid runtime names injected into ``func``'s globals.

    First removes every ``HYBRID_GLOBALS`` name from ``func.__globals__``
    (raising KeyError if one is missing). Then re-binds each
    ``(name, value)`` pair in ``intersect``, i.e. the original bindings the
    hybrid names had shadowed.
    """
    scope = func.__globals__
    for name in list(HYBRID_GLOBALS):
        del scope[name]
    for name, value in intersect:
        scope[name] = value