util.py 3.04 KB
Newer Older
1 2
"""Internal utilities for parsing a subset of Python into HalideIR."""

3
import ast
4
import inspect
5 6
import logging
import sys
7 8 9 10
import numpy
from .. import api as _api
from .. import make as _make
from .. import expr as _expr
11
from .. import stmt as _stmt
12
from .._ffi.base import numeric_types
13
from ..tensor import Tensor
14
from ..container import Array
15 16 17 18


#pylint: disable=invalid-name
# Concrete argument types (Python numerics plus numpy arrays); see _is_tvm_arg_types.
np_arg_types = tuple(numeric_types) + (numpy.ndarray,)
# Symbolic TVM argument types; see _is_tvm_arg_types.
tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
# Immediate-value expression node types (IntImm, FloatImm, UIntImm).
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)

22

23 24 25 26 27
def _internal_assert(cond, err):
    """Simplify the code segment like if not XXX then raise an error"""
    if not cond:
        raise ValueError(err)

28 29 30 31 32 33 34

# Useful constants. To avoid runtime dependencies, they are built by function
# calls instead of being created when the module is imported.
def make_nop():
    """Build a HalideIR 'no operation' statement.

    Returns
    -------
    An ``Evaluate`` node wrapping the int32 constant ``0``.
    """
    zero = _api.const(0, dtype='int32')
    return _make.Evaluate(zero)


35 36 37 38 39
def is_docstring(node):
    """Check whether a Python AST node is a docstring (a bare string statement).

    Parameters
    ----------
    node : ast.AST
        The node to inspect.

    Returns
    -------
    bool
        True if *node* is an expression statement whose value is a ``str``
        literal, False otherwise.
    """
    if not isinstance(node, ast.Expr):
        return False
    value = node.value
    if sys.version_info >= (3, 8):
        # Since Python 3.8 string literals parse as ast.Constant; ast.Str is
        # deprecated (DeprecationWarning from 3.12) and removed in 3.14.
        return isinstance(value, ast.Constant) and isinstance(value.value, str)
    # Python 2 and Python 3 < 3.8 still produce ast.Str for string literals.
    return isinstance(value, ast.Str)


40 41
def _pruned_source(func):
    """Prune source code's extra leading spaces"""
42 43 44 45 46 47 48 49 50 51 52
    try:
        lines = inspect.getsource(func).split('\n')
        leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
        lines = [line[leading_space:] for line in lines]
        return '\n'.join(lines)
    except IOError as err:
        if sys.version_info[0] == 2 and str(err) == 'could not get source code':
            logging.log(logging.CRITICAL, \
                        'This module is not fully operated under Python2... ' \
                        'Please move to Python3!')
            raise err
53 54


55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
def replace_io(body, rmap):
    """Redirect tensor reads and writes in *body* to the tensors given in *rmap*.

    ``Provide`` (write) and ``Call`` (read) nodes whose ``func`` is a key of
    *rmap* are rebuilt against the mapped object. For every other node the
    callback handed to ``ir_pass.IRTransform`` returns None.

    Parameters
    ----------
    body : Stmt
        The HalideIR statement to rewrite.
    rmap : dict
        Maps an original ``func`` to its replacement. The replacement must
        expose ``op``, ``dtype``, ``name`` and ``value_index`` (presumably a
        ``Tensor``; confirm at call sites).

    Returns
    -------
    Whatever ``ir_pass.IRTransform`` produces for *body*.
    """
    from .. import ir_pass

    def _redirect(node):
        writes = isinstance(node, _stmt.Provide)
        reads = isinstance(node, _expr.Call)
        if not (writes or reads) or node.func not in rmap.keys():
            # NOTE(review): None presumably tells IRTransform to keep the node; confirm.
            return None
        target = rmap[node.func]
        if writes:
            return _make.Provide(target.op, node.value_index, node.value, node.args)
        return _make.Call(target.dtype, target.name, node.args,
                          _expr.Call.Halide, target.op, target.value_index)

    return ir_pass.IRTransform(body, None, _redirect, ['Provide', 'Call'])


72 73 74 75 76
def _is_tvm_arg_types(args):
    """Determine whether a list of elements is a list of tvm arguments or a list of numpy arguments.

    The first element chooses the category; every element is then checked
    against that category.

    Returns
    -------
    bool
        True if all elements are ``tvm_arg_types`` instances, False if all are
        ``np_arg_types`` instances.

    Raises
    ------
    ValueError
        If an element does not belong to the category chosen by the first
        element, or the first element is neither kind.
    """
    is_tvm = isinstance(args[0], tvm_arg_types)
    if is_tvm:
        expected = tvm_arg_types
        template = "Expecting a Var, Tensor or ConstExpr instance but %s get!"
        # The first element already matched, so only the rest are checked.
        pending = args[1:]
    else:
        expected = np_arg_types
        template = "Expect a numpy type but %s get!"
        # On the numpy path the first element must be validated as well.
        pending = args
    for arg in pending:
        _internal_assert(isinstance(arg, expected), template % str(type(arg)))
    return is_tvm