util.py 3.98 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18
"""Internal utilities for parsing Python subset to HalideIR"""

import ast
import inspect
21 22
import logging
import sys
23 24 25 26
import numpy
from .. import api as _api
from .. import make as _make
from .. import expr as _expr
from .. import stmt as _stmt
from .._ffi.base import numeric_types
from ..tensor import Tensor
from ..container import Array
31 32 33 34

#pylint: disable=invalid-name
# Types accepted on the plain-Python side: every numeric type plus numpy.ndarray.
# Checked by _is_tvm_arg_types when the first argument is not a tvm type.
np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
# Types accepted on the tvm side; checked by _is_tvm_arg_types.
tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
# Integer, float and unsigned-integer immediate expression classes from `expr`.
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)


39 40 41 42 43
def _internal_assert(cond, err):
    """Simplify the code segment like if not XXX then raise an error"""
    if not cond:
        raise ValueError(err)

44 45 46 47 48 49 50

# Useful constants. To avoid runtime dependencies, we use function calls to return them.
def make_nop():
    """Build a 'no operation' node in HalideIR.

    The node is an Evaluate statement wrapping the int32 constant 0; a new
    one is constructed on every call.
    """
    zero = _api.const(0, dtype='int32')
    return _make.Evaluate(zero)

51 52 53 54 55
def is_docstring(node):
    """Checks if a Python AST node is a docstring.

    A node qualifies when it is an expression statement (``ast.Expr``) whose
    value is a string literal.

    Parameters
    ----------
    node : ast.AST
        The AST node to inspect.

    Returns
    -------
    bool
        True if the node is a bare string-literal expression, False otherwise.
    """
    if not isinstance(node, ast.Expr):
        return False
    value = node.value
    # Python >= 3.8 parses every literal as ast.Constant, so only a str payload
    # counts (bytes, numbers, None, ... are rejected).
    constant_cls = getattr(ast, 'Constant', None)
    if constant_cls is not None and isinstance(value, constant_cls):
        return isinstance(value.value, str)
    # Older interpreters produce ast.Str nodes. Compare by class name so the
    # deprecated (and, on recent Pythons, removed) ast.Str attribute is never
    # looked up.
    return type(value).__name__ == 'Str'

56 57
def _pruned_source(func):
    """Prune source code's extra leading spaces"""
58 59 60 61 62 63 64 65 66 67 68
        lines = inspect.getsource(func).split('\n')
        leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
        lines = [line[leading_space:] for line in lines]
        return '\n'.join(lines)
    except IOError as err:
        if sys.version_info[0] == 2 and str(err) == 'could not get source code':
            logging.log(logging.CRITICAL, \
                        'This module is not fully operated under Python2... ' \
                        'Please move to Python3!')
            raise err
69 70

71 72 73 74 75 76 77 78
def replace_io(body, rmap):
    """Replacing tensors usage according to the dict given.

    Parameters
    ----------
    body : Stmt
        The statement to rewrite; handed to ``ir_pass.IRTransform``.

    rmap : dict
        Maps the ``func`` of a Provide/Call node to its replacement. Each
        replacement must expose ``op``, ``dtype``, ``name`` and
        ``value_index``, as read below.

    Returns
    -------
    The result of ``ir_pass.IRTransform`` applied to ``body``.
    """
    # Imported locally, as in the original (presumably to avoid a circular
    # import at module load time -- confirm).
    from .. import ir_pass

    def replace(op):
        """Return the rewritten node, or None when `op` is not in `rmap`."""
        if isinstance(op, _stmt.Provide) and op.func in rmap:
            buf = rmap[op.func]
            # NOTE(review): the Provide branch keeps op.value_index while the
            # Call branch uses buf.value_index -- confirm this is intended.
            return _make.Provide(buf.op, op.value_index, op.value, op.args)
        if isinstance(op, _expr.Call) and op.func in rmap:
            buf = rmap[op.func]
            return _make.Call(buf.dtype, buf.name, op.args,
                              _expr.Call.Halide, buf.op, buf.value_index)
        return None

    # `replace` is the third argument and only 'Provide' and 'Call' node types
    # are requested (presumably the post-order callback -- confirm against
    # IRTransform's signature).
    return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])

88 89 90 91 92
def _is_tvm_arg_types(args):
    """Determine whether a list of elements is a list of tvm arguments or a
    list of numpy arguments. If neither is true, raise a value error.

    The first element decides which family is expected; every other element
    must then belong to the same family.

    Parameters
    ----------
    args : sequence
        The arguments to classify. Must be non-empty, because ``args[0]`` is
        inspected unconditionally.

    Returns
    -------
    bool
        True if every element is an instance of ``tvm_arg_types``; False if
        every element is an instance of ``np_arg_types``.

    Raises
    ------
    ValueError
        If an element does not belong to the family chosen by ``args[0]``,
        or if ``args[0]`` belongs to neither family.
    """
    if isinstance(args[0], tvm_arg_types):
        for elem in args[1:]:
            _internal_assert(isinstance(elem, tvm_arg_types),
                             "Expecting a Var, Tensor or ConstExpr instance but %s get!"
                             % str(type(elem)))
        return True

    # The numpy-side message is the same for the first and the remaining
    # elements, so a single in-order pass covers all of them.
    for elem in args:
        _internal_assert(isinstance(elem, np_arg_types),
                         "Expect a numpy type but %s get!" % str(type(elem)))
    return False
104 105 106 107 108 109

def _apply_indices(value, indices):
    """Apply multidimensional index"""
    if indices:
        return _apply_indices(value[indices[0]], indices[1:])
    return value