# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Hybrid Script Parser"""

import ast
import operator
import logging
import sys
import types
import numbers

from enum import Enum

from .util import _internal_assert, _apply_indices
from . import calls
from . import util
from .preprocessor import determine_variable_usage
from ..api import all as _all
from ..api import any as _any
from ..container import Array
from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal
from .. import expr as _expr
from .. import stmt as _stmt
from .. import make as _make
from .. import api as _api
from .. import ir_pass as _ir_pass


def concat_list_to_block(lst):
    """Concatenate a list of Python IR nodes to HalideIR Block"""
    n = len(lst)
    if n == 1:
        return lst[0]
    body = lst[n - 1]
    for i in range(1, n):
        stmt = lst[n - 1 - i]
        if isinstance(stmt, _stmt.AssertStmt):
            body = _make.AssertStmt(stmt.condition, stmt.message, body)
        else:
            body = _make.Block(stmt, body)
    return body


def visit_list_to_block(visit, lst):
    """Visit and concatenate a list of Python IR nodes to HalideIR Block"""
    lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
    lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
    if not lst:
        return util.make_nop()
    return concat_list_to_block(lst)
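
# Illustration of the folding above (a sketch only, not executed): given parsed
# statements [s1, s2, s3], concat_list_to_block nests them right to left into
# Block(s1, Block(s2, s3)), while an AssertStmt wraps the remaining statements
# as its own body, i.e. AssertStmt(cond, msg, <rest of the block>).
# visit_list_to_block additionally drops docstrings and no-op statements and
# returns a single no-op when nothing remains.
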
class Symbol(Enum):
    """Enumerates types in the symbol table"""
    Callable = 0
    Input = 1
    OutputBuffer = 2
    GlobalBuffer = 3
    LocalBuffer = 4
    SharedBuffer = 5
    ConstVar = 6
    BufferVar = 7
    LoopVar = 8
    ConstLoopVar = 9
    ThreadBind = 10


class HybridParser(ast.NodeVisitor):
    """Python AST visitor pass which finally lowers it to HalideIR"""

    _binop_maker = {
        ast.Add     : operator.add,
        ast.Sub     : operator.sub,
        ast.Mult    : operator.mul,
        ast.Div     : operator.div if sys.version_info[0] == 2 else operator.truediv,
        ast.FloorDiv: operator.div if sys.version_info[0] == 2 else operator.truediv,
        ast.Mod     : operator.mod,
        ast.BitOr   : operator.or_,
        ast.BitAnd  : operator.and_,
        ast.BitXor  : operator.xor,
        ast.Gt      : operator.gt,
        ast.GtE     : operator.ge,
        ast.Lt      : operator.lt,
        ast.LtE     : operator.le,
        ast.Eq      : operator.eq,
        ast.NotEq   : operator.ne,
        ast.And     : _all,
        ast.Or      : _any,
    }

    _unaryop_maker = {
        ast.USub   : operator.neg,
        ast.Invert : operator.invert,
        ast.Not    : operator.not_
    }

    def __init__(self, args, usage, symbols, closure_vars, func_name=None):
        """
        Parameters
        ----------
        args: A list of tvm.placeholder or tvm.var
            Provided by the user, the argument list of the function to be lowered.

        usage: dict
            The variables used in this function, collected by the preceding
            preprocessing pass.

        symbols : dict of str to object
            The symbols visible in the global context of the function;
            callables among them are registered in the symbol table.

        closure_vars: dict
            A dict of external name reference captured by this function.

        func_name: str, optional
            The name of the function to be lowered; if not provided, the
            compiler will use the name in the AST.
        """
        self.args = list(args)
        self.usage = usage.copy()

        self.symbols = {} # Symbol table
        for k, v in symbols.items():
            if isinstance(v, types.FunctionType):
                self.add_symbol(k, Symbol.Callable, v)

        self.closure_vars = closure_vars

        self.binds = {} # Thread binds
        self.device = 0 # Is it generating device
        self.annotation = {} # 'with ... as' annotations recorded by visit_With

        self.func_name = func_name # The name of the function to be lowered
        self.outputs = [] # Output tensors' name
        self.side_effect = set() # Tensors with side effects
        self.parsed_body = None # The parsed HalideIR body
        self.returned = False # If this function has a valid return

    def add_symbol(self, key, ty, val): #pylint: disable=invalid-name
        """Add value to the symbol table context"""
        if key in self.symbols.keys():
            old = str(self.symbols[key])
            new = str((ty, val))
            _internal_assert(False,
                             "Name conflict in symbol table! [%s] %s -> %s" % (key, old, new))

        self.symbols[key] = ty, val

        if ty == Symbol.ThreadBind:
            if val.var.name not in self.binds.keys():
                self.binds[val.var.name] = val
                return
            val_ = self.binds[val.var.name]
            _internal_assert(_ir_pass.Equal(val_.dom.extent, val.dom.extent),
                             "Thread extents should be uniform!")
            self.symbols[key] = ty, val_

    def wrap_up_realize(self, node, body):
        """Wrap up all the variables which will no longer be used"""
        to_pop = []
        for key, val in self.usage.items():
            _, level, _ = val
            if level != node:
                continue
            _internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key)

            ty, entry = self.symbols[key] #pylint: disable=invalid-name
            if ty in [Symbol.Input, Symbol.OutputBuffer]:
                continue
            elif 'Buffer' in ty.name:
                _buf = entry
                _scope = 'global' if ty is Symbol.BufferVar else ty.name[:-6].lower()
                to_pop.append(key)
            else:
                continue

            if _scope == 'global':
                body = self.wrap_up_binds(body)

            _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
            _dtype = _buf.dtype
            _true = _api.convert(True)
            body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
            body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)

        for elem in to_pop:
            self.symbols.pop(elem)

        return body

    def wrap_up_binds(self, body):
        for _, iter_var in self.binds.items():
            ext = iter_var.dom.extent
            body = _make.AttrStmt(iter_var, 'thread_extent', ext, body)
        self.binds = {}
        return body
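
    # Illustration of the two wrappers above (a sketch only, not executed):
    # for an intermediate buffer `b` declared at this `node` with 'shared'
    # scope, wrap_up_realize emits
    #     AttrStmt(b.op, 'realize_scope', 'shared',
    #              Realize(b.op, 0, b.dtype, [range(0, extent), ...], True, body))
    # while wrap_up_binds wraps every accumulated thread bind as
    #     AttrStmt(iter_var, 'thread_extent', extent, body)
    # and then clears self.binds so each bind is emitted only once.
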
    #pylint: disable=invalid-name, missing-docstring
    def visit_Module(self, node):
        _internal_assert(len(node.body) == 1, \
                         "Only one-function source code will be fed to this parser!")
        return self.visit(node.body[0])

    def visit_FunctionDef(self, node):
        _internal_assert(len(node.args.args) == len(self.args), \
                         "The number of arguments passed to the "
                         "function should be the same as it is defined!")
        if self.func_name is None:
            self.func_name = node.name
        for idx, arg in enumerate(node.args.args):
            _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
            self.add_symbol(getattr(arg, _attr), Symbol.Input, self.args[idx])
        res = visit_list_to_block(self.visit, node.body)
        res = self.wrap_up_realize(node, res)
        return self.wrap_up_binds(res)

    def visit_Expr(self, node):
        return self.visit(node.value)

    def visit_Name(self, node):
        name = node.id
        if sys.version_info[0] == 2 and name in ['True', 'False']:
            return _api.convert(ast.literal_eval(name))

        if name in self.closure_vars:
            return _api.convert(self.closure_vars[name])

        _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
        ty, entry = self.symbols[name]
        if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
            return entry
        if ty is Symbol.ThreadBind:
            return entry.var
        if ty is Symbol.ConstVar:
            return entry if isinstance(node.ctx, ast.Load) else None
        if ty is Symbol.BufferVar:
            if isinstance(node.ctx, ast.Load):
                return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
                                  _expr.Call.Halide, entry.op, entry.value_index)
            return entry, [_api.const(0, 'int32')]
        # Do I need any assertion here?
        return entry

    def visit_Num(self, node):
        if isinstance(node.n, numbers.Integral):
            dtype = "int32"
        elif isinstance(node.n, float):
            dtype = "float32"
        else:
            _internal_assert(isinstance(node.n, bool),
                             "The data type should be one of (int, float, bool)")
            dtype = "bool"
        return _api.const(node.n, dtype)

    def visit_NameConstant(self, node):
        return _api.convert(node.value)

    def visit_AugAssign(self, node):
        buf = self.visit(node.target)
        rhs = self.visit(node.value)
        if isinstance(buf, tuple):
            _internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
            buf, args = buf
        else:
            args = [_api.const(0, 'int32')]
        _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")

        read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
        value = HybridParser._binop_maker[type(node.op)](read, rhs)

        return _make.Provide(buf.op, 0, value, args)

    def visit_Assign(self, node):
        rhs = self.visit(node.value)
        if isinstance(rhs, Operation):
            rmap = {}
            _internal_assert(len(node.targets) == rhs.num_outputs, \
                             "Unable to detuple the outs to targets")
            for i in range(rhs.num_outputs):
                _internal_assert(isinstance(node.targets[i], ast.Name),
                                 "You should bind a pure name to the tensors")
                self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i))
                rmap[rhs.outputs[i].op] = rhs.output(i)
            return util.replace_io(rhs.body, rmap)

        _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
        lhs = node.targets[0]
        if isinstance(rhs, _expr.Expr):
            rhs = _ir_pass.Simplify(rhs)
        if isinstance(lhs, ast.Name):
            #TODO: support defined intermediate buffer later
            lhs_ = lhs
            lhs = lhs.id
            if lhs in self.symbols.keys():
                ty, _ = self.symbols[lhs]
                _internal_assert(ty != Symbol.LoopVar, \
                                 "Loop variable cannot be overwritten!")
            decl, _, rw = self.usage[lhs]
            if decl == lhs_:
                _internal_assert(lhs not in self.symbols.keys(),
                                 "This value should not be defined before this point!")
                if isinstance(rhs, tuple):
                    shape, dtype, scope = rhs
                    ph = _api.placeholder(shape, dtype=dtype, name=lhs)
                    self.add_symbol(lhs, getattr(Symbol, scope.title() + "Buffer"), ph)
                    if scope == 'output':
                        self.outputs.append(lhs)
                    return util.make_nop()
                if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
                    self.add_symbol(lhs, Symbol.ConstVar, rhs)
                else:
                    _internal_assert(self.device == 0,
                                     "Single variable not supported in devices' side!\n" + \
                                     "If you are using GPU, please allocate a 'local' spad " + \
                                     "outside the bind body")
                    ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
                    self.add_symbol(lhs, Symbol.BufferVar, ph)
            lhs = self.visit(lhs_)
            if lhs is not None:
                buf, args = lhs
                return _make.Provide(buf.op, 0, rhs, args)
            return util.make_nop()

        lhs, args = self.visit(lhs)
        _internal_assert(isinstance(lhs, Tensor), \
                         "An array access's LHS is expected to be an expr.Call!")
        res = _make.Provide(lhs.op, lhs.value_index, rhs, args)
        return res
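
    # Summary of the assignment cases above (a sketch only): a 3-tuple RHS of
    # (shape, dtype, scope), as produced by a buffer-allocating intrinsic,
    # declares a new buffer symbol (an 'output' scope also registers it as an
    # output); a constant RHS that is never stored to again becomes a ConstVar;
    # any other scalar RHS is backed by a single-element placeholder
    # (BufferVar) and is only allowed on the host side; a subscripted LHS
    # lowers to a Provide into the addressed tensor.
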
    def visit_Index(self, node):
        if isinstance(node.value, ast.Tuple):
            return self.visit(node.value)
        return [self.visit(node.value)]

    def visit_Attribute(self, node):
        _internal_assert(isinstance(node.value, ast.Name), \
                         "For attribute access, only names are supported so far!")
        buf = self.visit(node.value)
        return getattr(buf, node.attr)

    def visit_Subscript(self, node):
        args = self.visit(node.slice)
        if isinstance(node.value, ast.Name):
            if node.value.id in self.closure_vars:
                args = ast.literal_eval(str(args))
                return _api.convert(_apply_indices(self.closure_vars[node.value.id], args))

            buf = self.visit(node.value)
            if isinstance(buf, Array):
                for i in args:
                    if isinstance(i, numbers.Integral):
                        buf = buf[i]
                    else:
                        _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
                                         "All indices are supposed to be constants")
                        buf = buf[i.value]
                return buf
            if isinstance(node.ctx, ast.Load):
                return _make.Call(buf.dtype, buf.name, args, \
                                  _expr.Call.Halide, buf.op, buf.value_index)
            return buf, args

        shape = self.visit(node.value)
        _internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!")
        args = args[0]
        #TODO: maybe support non-constant value later?
        _internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \
                         "So far only constant shape access supported!")
        return shape[args.value]

    def visit_With(self, node):
        if sys.version_info[0] < 3:
            context = node.context_expr
            option = node.optional_vars
        else:
            _internal_assert(len(node.items) == 1, "Only one with element is supported so far!")
            context = node.items[0].context_expr
            option = node.items[0].optional_vars
        _internal_assert(isinstance(context, ast.Call), "The object must be a Python func call!")
        _internal_assert(isinstance(option, ast.Name), "The object after 'as' must be an id!")
        self.annotation[option.id] = context.func.id
        return visit_list_to_block(self.visit, node.body)

    def visit_If(self, node):
        cond = self.visit(node.test)

        # Return no IfThenElse if proven
        if isinstance(cond, _expr.UIntImm):
            if cond.value:
                return visit_list_to_block(self.visit, node.body)
            if node.orelse:
                return visit_list_to_block(self.visit, node.orelse)
            return util.make_nop()

        if_body = visit_list_to_block(self.visit, node.body)

        if node.orelse:
            else_body = visit_list_to_block(self.visit, node.orelse)
        else:
            else_body = None
        return _make.IfThenElse(cond, if_body, else_body)

    def visit_IfExp(self, node):
        cond = self.visit(node.test)
        if_body = self.visit(node.body)
        else_body = self.visit(node.orelse)
        return _make.Select(cond, if_body, else_body)

    def visit_Compare(self, node):
        _internal_assert(len(node.ops) == len(node.comparators),
                         "#compare ops != #comparators")
        ops = [self.visit(node.left)]
        ops += [self.visit(i) for i in node.comparators]
        res = []
        for i in range(len(node.ops)):
            lhs = ops[i]
            rhs = ops[i + 1]
            res.append(HybridParser._binop_maker[type(node.ops[i])](lhs, rhs))
        return _all(*res)

    def visit_BoolOp(self, node):
        n = len(node.values)
        if n == 1:
            _internal_assert(isinstance(node.op, ast.Not), \
                             "Unary is supposed to be not!")
            return operator.not_(self.visit(node.values[0]))
        _internal_assert(isinstance(node.op, (ast.And, ast.Or)), \
                         "Binary is supposed to be and/or!")
        values = [self.visit(i) for i in node.values]
        return HybridParser._binop_maker[type(node.op)](*values)

    def visit_UnaryOp(self, node):
        operand = self.visit(node.operand)
        return HybridParser._unaryop_maker[type(node.op)](operand)

    def visit_BinOp(self, node):
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        return HybridParser._binop_maker[type(node.op)](lhs, rhs)
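
    # Note on the boolean lowering above (a sketch only): a chained comparison
    # such as `a < b < c` is folded by visit_Compare into _all(a < b, b < c),
    # and `and`/`or` are mapped by visit_BoolOp onto the element-wise
    # _all/_any helpers, so Python never tries to take the truth value of a
    # TVM expression.
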
    def visit_Call(self, node):
        # Yet, no function pointer supported
        _internal_assert(isinstance(node.func, ast.Name), \
                         "Only id-function function call is supported so far!")

        func_id = node.func.id
        args = [self.visit(i) for i in node.args]
        # Intrinsics'
        if hasattr(calls, func_id):
            return getattr(calls, func_id)(func_id, args)
        # Contexts'
        _internal_assert(func_id in self.symbols.keys(), \
                         "The function called (%s) is not in the context either!" % func_id)
        ty, entry = self.symbols[func_id]
        _internal_assert(ty is Symbol.Callable, \
                         "Are you sure what you call is a function?!")
        outs = entry(*args)
        op = outs.op if isinstance(outs, Tensor) else outs[0].op
        return op

    def visit_For(self, node):
        iter_var, low, ext, for_type = self.visit(node.iter)
        _internal_assert(isinstance(node.target, ast.Name), \
                         "The loop iterator should be a variable!")
        _name = node.target.id

        if isinstance(for_type, tuple):
            low = _ir_pass.Simplify(low)
            ext = _ir_pass.Simplify(ext)
            _internal_assert(isinstance(low, _expr.ConstExpr) and
                             isinstance(ext, _expr.ConstExpr), \
                             "Const range should start from a const "
                             "and iterate const times")

            low, ext = low.value, ext.value
            if ext > 114514:
                logging.log(logging.CRITICAL, \
                            '[Warning] Are you sure to unroll a large loop in Python?')

            bodies = []
            for i in range(low, low + ext):
                self.add_symbol(_name, Symbol.ConstLoopVar, i)
                body = visit_list_to_block(self.visit, node.body)
                body = self.wrap_up_realize(node, body)
                bodies.append(body)
                self.symbols.pop(_name)
            return concat_list_to_block(bodies)

        if iter_var is None:
            _internal_assert(for_type is not None, "The loop iterating function parse error!")
            offset = iter_var = _api.var(_name)
            if not _ir_pass.Equal(low, _api.const(0, 'int32')):
                offset = iter_var + low
            self.add_symbol(_name, Symbol.LoopVar, offset)
            _body = visit_list_to_block(self.visit, node.body)
        else:
            _internal_assert(for_type is None, "The loop bind function parse error!")
            self.add_symbol(_name, Symbol.ThreadBind, iter_var)
            self.device += 1
            _body = visit_list_to_block(self.visit, node.body)
            self.device -= 1

        _body = self.wrap_up_realize(node, _body)

        if for_type is None:
            res = _body
        else:
            _internal_assert(not isinstance(for_type, tuple), \
                             "Micro expansion should be handled before!")
            res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)

        self.symbols.pop(_name)
        return res
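
    # The three loop forms accepted above (a sketch; the intrinsic names that
    # produce (iter_var, low, ext, for_type) live in the companion `calls`
    # module): a constant range (for_type is a tuple) is fully unrolled at
    # parse time, re-binding the iterator as a ConstLoopVar for every copy of
    # the body; a plain serial/unrolled/parallel/vectorized loop (iter_var is
    # None) becomes a HalideIR For node starting at zero, with any non-zero
    # lower bound folded into the loop-variable offset; and a thread bind
    # supplies its own IterVar, recorded as a ThreadBind symbol and emitted
    # later as a 'thread_extent' AttrStmt by wrap_up_binds.
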
    def visit_Return(self, node):
        _internal_assert(all(ty != Symbol.LoopVar for ty, _ in self.symbols.values()), \
                         "Return should not be in a loop body!")
        ids = []
        if isinstance(node.value, ast.Name):
            ids = [node.value.id]
        else:
            _internal_assert(isinstance(node.value, ast.Tuple), \
                             "You should return either a single tensor or a tuple")
            _internal_assert(all(isinstance(i, ast.Name) for i in node.value.elts), \
                             "What do you return?")
            ids = [i.id for i in node.value.elts]
        _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
        if len(ids) < len(self.outputs):
            logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
        self.outputs = [self.symbols[i][1] for i in ids]
        self.returned = True
        return util.make_nop()

    def visit_Tuple(self, node):
        return tuple(self.visit(i) for i in node.elts)

    def visit_Str(self, node):
        return node.s

    def visit_Assert(self, node):
        test = self.visit(node.test)
        mesg = _api.convert(self.visit(node.msg))
        return _make.AssertStmt(test, mesg, util.make_nop())


def parse_python(src, args, symbols, closure_vars):
    """The helper function of calling the AST visitor

    Parameters
    ----------
    src : ast.node or str
        If an ast.node, then directly lower it.
        If a str, then parse it to ast and lower it.

    args : list of Tensors or Vars
        The argument lists to the function.
        It is NOT encouraged to write a function without arguments.
        It is NOT encouraged to write a function with side effect.

    symbols : dict of str to object
        The symbols visible in the global context of the function.

    closure_vars: dict
        A dict of external name reference captured by this function.

    Returns
    -------
    parser : HybridParser
        The parser instance; the lowered Halide IR body is stored in
        parser.parsed_body.
    """
    root = ast.parse(src) if isinstance(src, str) else src
    _internal_assert(isinstance(root, ast.AST),
                     "Expect an AST node or a source string to parse!")
    var_usage = determine_variable_usage(root, args, symbols, closure_vars)
    parser = HybridParser(args, var_usage, symbols, closure_vars)
    parser.parsed_body = parser.visit(root)
    _internal_assert(parser.returned, 'No valid return found in the function body!')
    return parser


def source_to_op(src, args, symbols, closure_vars):
    """Wrap the parsed hybrid function into a HybridOp and return its outputs

    Parameters
    ----------
    src : ast.node or str
        If an ast.node, then directly lower it.
        If a str, then parse it to ast and lower it.

    args : list of Tensors or Vars
        The argument lists to the function.
        It is NOT encouraged to write a function without arguments.
        It is NOT encouraged to write a function with side effect.

    symbols : dict of str to object
        The symbols visible in the global context of the function.

    closure_vars: dict
        A dict of external name reference captured by this function.

    Returns
    -------
    res : Tensor or list of Tensors
        The output tensor(s) of the formed HybridOp; a single tensor is
        returned when the op has exactly one output.
    """
    parser = parse_python(src, args, symbols, closure_vars)

    input_tensors = []
    for i in args:
        if isinstance(i, Tensor):
            input_tensors.append(i)
    op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
                                 parser.outputs, parser.parsed_body)
    res = [op.output(i) for i in range(len(parser.outputs))]
    return res[0] if len(res) == 1 else res
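

if __name__ == "__main__":
    # A minimal usage sketch, not part of the library proper. It assumes this
    # module is run as part of the tvm package and that the companion `calls`
    # module provides the usual hybrid intrinsics (`range`, `output_tensor`);
    # those intrinsic names are assumptions about the runtime, not something
    # defined in this file.
    _DEMO_SRC = """
def add_one(a):
    b = output_tensor((8, ), 'float32')
    for i in range(8):
        b[i] = a[i] + 1.0
    return b
"""
    # The demo body only uses intrinsics, so an empty global context and no
    # captured closure variables are enough here.
    _a = _api.placeholder((8, ), name='a', dtype='float32')
    _out = source_to_op(_DEMO_SRC, [_a], {}, {})
    print(_out)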