# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Internal utilities for parsing Python subset to HalideIR"""

import ast
import inspect
import logging
import sys
import numpy
from .. import api as _api
from .. import make as _make
from .. import expr as _expr
from .. import stmt as _stmt
from .._ffi.base import numeric_types
from ..tensor import Tensor
from ..container import Array


#pylint: disable=invalid-name
# Types accepted as "numpy-side" arguments: plain numeric types plus ndarrays.
np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
# Types accepted as "tvm-side" arguments.
tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
# Immediate (literal) expression node types.
halide_imm_types = (_expr.IntImm, _expr.FloatImm)


def _internal_assert(cond, err):
    """Raise ``ValueError(err)`` unless ``cond`` is truthy.

    Shorthand for the recurring "if not XXX then raise an error" pattern.

    Parameters
    ----------
    cond : object
        The condition to check; any truthy value passes.
    err : str
        The message carried by the raised ``ValueError``.

    Raises
    ------
    ValueError
        When ``cond`` is falsy.
    """
    if cond:
        return
    raise ValueError(err)


# Useful constants. To avoid runtime dependencies, we use function calls to return them.
def make_nop():
    """Return a 'no operation' node in HalideIR.

    The node is an ``Evaluate`` statement wrapping the ``int32`` constant 0.
    """
    return _make.Evaluate(_api.const(0, dtype='int32'))


def is_docstring(node):
    """Check whether a Python AST node is a docstring-style expression.

    Parameters
    ----------
    node : ast.AST
        The node to inspect.

    Returns
    -------
    bool
        True when ``node`` is an expression statement whose value is a
        string literal, False otherwise.
    """
    if not isinstance(node, ast.Expr):
        return False
    value = node.value
    # Python 3.8+ parses string literals into ast.Constant. ast.Str is
    # deprecated there and removed in Python 3.14, so it must not be touched
    # on modern interpreters.
    constant_cls = getattr(ast, 'Constant', None)
    if constant_cls is not None and isinstance(value, constant_cls):
        return isinstance(value.value, str)
    if sys.version_info < (3, 8):
        # Older parsers still produce the dedicated ast.Str node.
        return isinstance(value, ast.Str)
    return False


def _pruned_source(func):
    """Return the source of ``func`` with its extra leading spaces removed.

    The number of leading spaces on the first source line is stripped from
    every line, so the definition starts at column 0.

    Parameters
    ----------
    func : object
        Any object accepted by ``inspect.getsource``.

    Returns
    -------
    str
        The re-indented source code.

    Raises
    ------
    IOError
        When the source code cannot be retrieved. The error is always
        propagated; under Python 2 a critical hint is logged first.
    """
    try:
        source = inspect.getsource(func)
    except IOError as err:
        if sys.version_info[0] == 2 and str(err) == 'could not get source code':
            logging.log(logging.CRITICAL,
                        'This module is not fully operated under Python2... '
                        'Please move to Python3!')
        # Always propagate: swallowing the error would make this function
        # silently return None.
        raise
    lines = source.split('\n')
    leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
    return '\n'.join(line[leading_space:] for line in lines)


def replace_io(body, rmap):
    """Replace tensor usages in ``body`` according to the mapping ``rmap``.

    Parameters
    ----------
    body : Stmt
        The statement to transform.
    rmap : dict
        Maps the ``func`` of a ``Provide``/``Call`` node to its replacement;
        the replacement's ``op``, ``dtype``, ``name`` and ``value_index``
        attributes are read here.

    Returns
    -------
    Stmt
        The result of ``ir_pass.IRTransform`` applied to ``body``.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import
    # at module load time -- confirm before moving to the top of the file.
    from .. import ir_pass

    def replace(op):
        """Rebuild a matching Provide/Call node against its replacement."""
        if isinstance(op, _stmt.Provide) and op.func in rmap:
            buf = rmap[op.func]
            return _make.Provide(buf.op, op.value_index, op.value, op.args)
        if isinstance(op, _expr.Call) and op.func in rmap:
            buf = rmap[op.func]
            return _make.Call(buf.dtype, buf.name, op.args,
                              _expr.Call.Halide, buf.op, buf.value_index)
        # NOTE(review): None presumably tells IRTransform to keep the node
        # as-is -- confirm against the IRTransform contract.
        return None

    return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])


def _is_tvm_arg_types(args):
    """Determine whether ``args`` holds tvm arguments or numpy arguments.

    The kind is decided by the first element; every other element must be
    of the same kind.

    Parameters
    ----------
    args : sequence
        The arguments to classify; indexed and sliced, so it must be a
        non-empty sequence.

    Returns
    -------
    bool
        True if all elements are tvm argument types, False if all elements
        are numpy argument types.

    Raises
    ------
    ValueError
        When the elements are neither all tvm types nor all numpy types.
    """
    if isinstance(args[0], tvm_arg_types):
        for elem in args[1:]:
            _internal_assert(isinstance(elem, tvm_arg_types),
                             "Expecting a Var, Tensor or ConstExpr instance but %s get!"
                             % str(type(elem)))
        return True

    # The first element is not a tvm type, so every element (including the
    # first) has to be a numpy type.
    for elem in args:
        _internal_assert(isinstance(elem, np_arg_types),
                         "Expect a numpy type but %s get!" % str(type(elem)))
    return False