parser.py 23.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20
"""Hybrid Script Parser"""

import ast
import operator
21
import logging
22
import sys
23 24 25 26
import types
import numbers

from enum import Enum
27

28
from .util import _internal_assert
29 30
from . import calls
from . import util
Jian Weng committed
31
from .preprocessor import determine_variable_usage
32 33
from ..api import all as _all
from ..api import any as _any
34

35
from ..container import Array
36
from ..tensor import Tensor, Operation
37
from .. import _api_internal as _tvm_internal
38 39 40 41 42
from .. import expr as _expr
from .. import make as _make
from .. import api  as _api
from .. import ir_pass as _ir_pass

43

44 45
def concat_list_to_block(lst):
    """Fold a sequence of HalideIR statements into one right-nested Block.

    An empty sequence yields a no-op statement and a one-element sequence
    yields that element unchanged; otherwise ``[a, b, c]`` becomes
    ``Block(a, Block(b, c))``.

    Parameters
    ----------
    lst : sequence of Stmt
        The statements to concatenate, in execution order.

    Returns
    -------
    stmt : Stmt
        The concatenated statement.
    """
    if not lst:
        return util.make_nop()
    # Walk from the last statement backwards so each earlier statement
    # becomes the head of a Block whose tail is everything after it.
    stmts = iter(reversed(lst))
    tail = next(stmts)
    for head in stmts:
        tail = _make.Block(head, tail)
    return tail


def visit_list_to_block(visit, lst):
    """Visit a list of Python AST statements and join the results into one Block.

    Statements for which ``util.is_docstring`` holds are skipped, and lowered
    statements that are structurally equal to a no-op are dropped before
    concatenation.

    Parameters
    ----------
    visit : callable
        The visitor applied to each remaining statement.
    lst : list of ast.stmt
        The statements to lower, in order.

    Returns
    -------
    stmt : Stmt
        A no-op when nothing remains, otherwise the concatenated Block.
    """
    lowered = [visit(node) for node in lst if not util.is_docstring(node)]
    kept = [stmt for stmt in lowered if not _ir_pass.Equal(stmt, util.make_nop())]
    return concat_list_to_block(kept) if kept else util.make_nop()


class Symbol(Enum):
    """Kinds of entries stored in the hybrid parser's symbol table.

    The numeric values are fixed; member names matter too, because the
    parser checks for the substring ``Buffer`` in them and derives a
    storage-scope string from the ``*Buffer`` prefixes.
    """
    Callable = 0       # a Python function taken from the global context
    Input = 1          # a function argument
    OutputBuffer = 2   # a buffer allocated with scope 'output'
    GlobalBuffer = 3   # a buffer allocated with scope 'global'
    LocalBuffer = 4    # a buffer allocated with scope 'local'
    SharedBuffer = 5   # a buffer allocated with scope 'shared'
    ConstVar = 6       # a never-stored scalar, substituted by its value
    BufferVar = 7      # a mutable scalar held in a 1-element buffer
    LoopVar = 8        # the variable of a lowered (non-unrolled) loop
    ConstLoopVar = 9   # the Python int of a compile-time unrolled loop
    ThreadBind = 10    # an iteration variable bound to a thread


def _floordiv(x, y):
    """Floor division that accepts TVM expressions or plain Python numbers.

    Uses ``_api.floordiv`` as soon as either operand is an ``_expr.ExprOp``;
    only when neither is does it fall back to Python's ``//``.
    """
    if not isinstance(x, _expr.ExprOp) and not isinstance(y, _expr.ExprOp):
        return operator.floordiv(x, y)
    return _api.floordiv(x, y)


def _floormod(x, y):
    """Floor modulo that accepts TVM expressions or plain Python numbers.

    Uses ``_api.floormod`` as soon as either operand is an ``_expr.ExprOp``;
    only when neither is does it fall back to Python's ``%`` (which already
    has floor semantics for Python numbers).
    """
    if not isinstance(x, _expr.ExprOp) and not isinstance(y, _expr.ExprOp):
        return operator.mod(x, y)
    return _api.floormod(x, y)


class HybridParser(ast.NodeVisitor):
    """Python AST visitor pass which finally lowers it to HalideIR

    Each ``visit_*`` method lowers one kind of Python AST node. Statements
    become HalideIR statements and expressions become HalideIR expressions.
    For stores into buffers, a ``(tensor, indices)`` pair is returned
    instead, for the enclosing assignment to consume.
    """


    # AST operator type -> builder of the corresponding expression. The
    # comparison operators live here too because visit_Compare reuses it.
    _binop_maker = {
        ast.Add     : operator.add,
        ast.Sub     : operator.sub,
        ast.Mult    : operator.mul,
        ast.Div     : operator.div if sys.version_info[0] == 2 else operator.truediv,
        ast.FloorDiv: _floordiv,
        ast.Mod     : _floormod,
        ast.BitOr   : operator.or_,
        ast.BitAnd  : operator.and_,
        ast.BitXor  : operator.xor,
        ast.Gt      : operator.gt,
        ast.GtE     : operator.ge,
        ast.Lt      : operator.lt,
        ast.LtE     : operator.le,
        ast.Eq      : operator.eq,
        ast.NotEq   : operator.ne,
        ast.And     : _all,
        ast.Or      : _any,
    }


    # AST unary operator type -> builder of the corresponding expression.
    _unaryop_maker = {
        ast.USub   : operator.neg,
        ast.Invert : operator.invert,
        ast.Not    : operator.not_
    }


    def __init__(self, args, usage, symbols, closure_vars, func_name=None):
        """
        Parameters
        ----------
        args: A list of tvm.placeholder or tvm.var
            Provided by the user, the argument list of the function to be lowered.

        usage: dict
            Variable usage collected by the preceding analysis pass. Each value
            is unpacked as ``(declaring AST node, scope node, access contexts)``;
            a variable is realized when the visitor leaves its scope node.

        symbols : dict of str to object
            The global context of the function. Only entries that are Python
            functions are registered (as ``Symbol.Callable``).

        closure_vars: dict
            A dict of external name reference captured by this function.

        func_name: str, optional
            The name of the function to be lowered; if not provided,
            the compiler will use the name in the AST
        """
        self.args = list(args)
        self.usage = usage.copy()

        self.symbols = {} # Symbol table
        for k, v in symbols.items():
            if isinstance(v, types.FunctionType):
                self.add_symbol(k, Symbol.Callable, v)

        self.closure_vars = closure_vars

        self.binds = {} # Thread binds
        self.device = 0 # Is it generating device

        self.func_name = func_name # The name of the function to be lowered
        # Names of 'output' buffers while parsing; replaced by the returned
        # tensors in visit_Return.
        self.outputs = []
        self.side_effect = set() # Tensors with side effects
        self.parsed_body = None # The parsed HalideIR body
        self.returned = False # If this function has a valid return
        # `with f(...) as name:` annotations (name -> called function id),
        # written by visit_With. Previously never initialized.
        self.annotation = {}


    def add_symbol(self, key, ty, val): #pylint: disable=invalid-name
        """Add value to the symbol table context

        Raises (via _internal_assert) on a name conflict. For
        ``Symbol.ThreadBind`` entries, the first IterVar bound to a given
        thread name is recorded in ``self.binds``. A later bind of the same
        thread must have an equal extent and is replaced by the recorded
        IterVar.
        """
        if key in self.symbols:
            old = str(self.symbols[key])
            new = str((ty, val))
            _internal_assert(False,
                             "Name conflict in symbol table! [%s] %s -> %s" % (key, old, new))

        self.symbols[key] = ty, val

        if ty == Symbol.ThreadBind:
            if val.var.name not in self.binds:
                self.binds[val.var.name] = val
                return
            val_ = self.binds[val.var.name]
            _internal_assert(_ir_pass.Equal(val_.dom.extent, val.dom.extent),
                             "Thread extents should be uniform!")
            self.symbols[key] = ty, val_


    def wrap_up_realize(self, node, body):
        """Wrap up all the variables which will no longer be used

        This handles every visited variable whose recorded usage scope is
        ``node`` and that is a buffer other than an input or output. For
        each one, ``body`` is wrapped in a Realize plus a ``realize_scope``
        AttrStmt, and the variable is dropped from the symbol table.
        Before a global-scope buffer is realized, any pending thread-extent
        attributes are attached to the body.
        """
        to_pop = []
        for key, val in self.usage.items():
            _, level, _ = val
            if key not in self.symbols:
                # don't realize the symbols that are never visited
                continue
            if level != node:
                continue

            ty, entry = self.symbols[key] #pylint: disable=invalid-name
            if ty in [Symbol.Input, Symbol.OutputBuffer]:
                continue
            if 'Buffer' not in ty.name:
                continue

            _buf = entry
            # 'GlobalBuffer' -> 'global', 'SharedBuffer' -> 'shared', ...;
            # scalar BufferVars always live in global scope.
            _scope = 'global' if ty is Symbol.BufferVar else ty.name[:-6].lower()
            to_pop.append(key)

            if _scope == 'global':
                body = self.wrap_up_binds(body)

            _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
            _dtype = _buf.dtype
            _true = _api.convert(True)
            body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
            body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)

        for elem in to_pop:
            self.symbols.pop(elem)

        return body


    def wrap_up_binds(self, body):
        """Wrap ``body`` in a ``thread_extent`` AttrStmt per pending thread bind.

        The pending binds are cleared afterwards.
        """
        for _, iter_var in self.binds.items():
            ext = iter_var.dom.extent
            body = _make.AttrStmt(iter_var, 'thread_extent', ext, body)
        self.binds = {}
        return body


    #pylint: disable=invalid-name, missing-docstring
    def visit_Module(self, node):
        _internal_assert(len(node.body) == 1, \
                         "Only one-function source code will be fed to this parser!")
        return self.visit(node.body[0])


    def visit_FunctionDef(self, node):
        """Bind the arguments, lower the body, then realize/bind what is left."""
        # Adjacent literals instead of a backslash-continued string, which
        # embedded the continuation line's indentation in the message.
        _internal_assert(len(node.args.args) == len(self.args),
                         "The number of arguments passed to the "
                         "function should be the same as it is defined!")
        if self.func_name is None:
            self.func_name = node.name
        for idx, arg in enumerate(node.args.args):
            _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
            self.add_symbol(getattr(arg, _attr), Symbol.Input, self.args[idx])
        res = visit_list_to_block(self.visit, node.body)
        res = self.wrap_up_realize(node, res)
        return self.wrap_up_binds(res)


    def visit_Expr(self, node):
        return self.visit(node.value)


    def visit_Name(self, node):
        """Resolve a name through closure variables, then the symbol table."""
        name = node.id
        if sys.version_info[0] == 2 and name in ['True', 'False']:
            return _api.convert(ast.literal_eval(name))

        if name in self.closure_vars:
            return _api.convert(self.closure_vars[name])

        # Check before indexing; indexing first raised KeyError and made the
        # assertion unreachable.
        _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
        ty, entry = self.symbols[name]
        if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
            return entry
        if ty is Symbol.ThreadBind:
            return entry.var
        if ty is Symbol.ConstVar:
            return entry if isinstance(node.ctx, ast.Load) else None
        if ty is Symbol.BufferVar:
            # Scalars are stored in element 0 of a 1-element buffer.
            if isinstance(node.ctx, ast.Load):
                return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
                                  _expr.Call.Halide, entry.op, entry.value_index)
            return entry, [_api.const(0, 'int32')]
        # Do I need any assertion here?
        return entry


    def visit_Num(self, node):
        """Lower a numeric literal to an int32/float32/bool constant."""
        if isinstance(node.n, numbers.Integral):
            dtype = "int32"
        elif isinstance(node.n, float):
            dtype = "float32"
        else:
            _internal_assert(isinstance(node.n, bool),
                             "The data type should be one of (int, float, bool)")
            dtype = "bool"
        return _api.const(node.n, dtype)


    def visit_NameConstant(self, node):
        return _api.convert(node.value)


    def visit_AugAssign(self, node):
        """Lower ``buf[idx] op= rhs`` to ``buf[idx] = buf[idx] op rhs``."""
        buf = self.visit(node.target)
        rhs = self.visit(node.value)
        if isinstance(buf, tuple):
            _internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
            buf, args = buf
        else:
            args = [_api.const(0, 'int32')]
        _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")

        read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
        value = HybridParser._binop_maker[type(node.op)](read, rhs)

        # Write back to the same output the value was read from (this used
        # to hard-code output 0).
        return _make.Provide(buf.op, buf.value_index, value, args)


    def visit_Assign(self, node):
        """Lower an assignment: tensor-call binding, declaration, or store."""
        rhs = self.visit(node.value)
        if isinstance(rhs, Operation):
            # NOTE(review): `a, b = f(...)` appears to arrive as a single
            # ast.Tuple target, so only single-output calls seem to satisfy
            # the check below — confirm before relying on multi-output.
            rmap = {}
            _internal_assert(len(node.targets) == rhs.num_outputs, \
                             "Unable to detuple the outs to targets")
            for i in range(rhs.num_outputs):
                _internal_assert(isinstance(node.targets[i], ast.Name),
                                 "You should bind a pure name to the tensors")
                self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i))
                rmap[rhs.outputs[i].op] = rhs.output(i)
            return util.replace_io(rhs.body, rmap)

        _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
        lhs = node.targets[0]
        if isinstance(rhs, _expr.Expr):
            rhs = _ir_pass.Simplify(rhs)
        if isinstance(lhs, ast.Name):
            #TODO: support defined intermediate buffer later
            lhs_ = lhs
            lhs = lhs.id
            if lhs in self.symbols:
                ty, _ = self.symbols[lhs]
                _internal_assert(ty != Symbol.LoopVar, \
                                 "Loop variable cannot be overwritten!")
            decl, _, rw = self.usage[lhs]
            if decl == lhs_:
                # This assignment is the variable's declaration.
                _internal_assert(lhs not in self.symbols,
                                 "This value should not be defined before this point!")
                if isinstance(rhs, tuple):
                    # A buffer allocation: (shape, dtype, scope).
                    shape, dtype, scope = rhs
                    ph = _api.placeholder(shape, dtype=dtype, name=lhs)
                    self.add_symbol(lhs, getattr(Symbol, scope.title() + "Buffer"), ph)
                    if scope == 'output':
                        self.outputs.append(lhs)
                    return util.make_nop()
                if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
                    # Immediate that is never stored to again: substitute by value.
                    self.add_symbol(lhs, Symbol.ConstVar, rhs)
                else:
                    _internal_assert(self.device == 0,
                                     "Single variable not supported in devices' side!\n" + \
                                     "If you are using GPU, please allocate a 'local' spad " + \
                                     "outside the bind body")
                    ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
                    self.add_symbol(lhs, Symbol.BufferVar, ph)
            lhs = self.visit(lhs_)
            if lhs is not None:
                buf, args = lhs
                return _make.Provide(buf.op, 0, rhs, args)
            return util.make_nop()

        lhs, args = self.visit(lhs)
        _internal_assert(isinstance(lhs, Tensor), \
                         "An array access's LHS is expected to be a expr.Call!")
        res = _make.Provide(lhs.op, lhs.value_index, rhs, args)
        return res


    def visit_Index(self, node):
        if isinstance(node.value, ast.Tuple):
            return self.visit(node.value)
        return [self.visit(node.value)]


    def visit_Attribute(self, node):
        buf = self.visit(node.value)
        return getattr(buf, node.attr)


    def visit_Subscript(self, node):
        """Lower ``arr[idx]``: constant Array indexing, a load, or a store target."""
        args = self.visit(node.slice)
        # Python >= 3.9 drops ast.Index, so a single index arrives as a bare
        # expression instead of the list visit_Index builds.
        if not isinstance(args, (list, tuple)):
            args = [args]
        arr = self.visit(node.value)
        if isinstance(arr, Array):
            for i in args:
                if isinstance(i, numbers.Integral):
                    arr = arr[i]
                else:
                    _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
                                     "All indices are supposed to be constants")
                    arr = arr[i.value]
            return arr
        if isinstance(node.ctx, ast.Load):
            return _make.Call(arr.dtype, arr.name, args,
                              _expr.Call.Halide, arr.op, arr.value_index)
        return arr, args


    def visit_With(self, node):
        """Record ``with f(...) as name`` in self.annotation and lower the body."""
        if sys.version_info[0] < 3:
            context = node.context_expr
            option = node.optional_vars
        else:
            _internal_assert(len(node.items) == 1, "Only one with element is supported so far!")
            context = node.items[0].context_expr
            option = node.items[0].optional_vars
        _internal_assert(isinstance(context, ast.Call), "The object must be a Python func call!")
        _internal_assert(isinstance(option, ast.Name), "The object after 'as' must be an id!")
        self.annotation[option.id] = context.func.id
        return visit_list_to_block(self.visit, node.body)


    def visit_If(self, node):
        """Lower an if statement, folding it when the condition is a constant."""
        cond = _ir_pass.CanonicalSimplify(self.visit(node.test))

        # Return no IfThenElse if proven
        if isinstance(cond, _expr.UIntImm):
            if cond.value:
                return visit_list_to_block(self.visit, node.body)
            if node.orelse:
                return visit_list_to_block(self.visit, node.orelse)
            return util.make_nop()

        if_body = visit_list_to_block(self.visit, node.body)

        if node.orelse:
            else_body = visit_list_to_block(self.visit, node.orelse)
        else:
            else_body = None
        return _make.IfThenElse(cond, if_body, else_body)


    def visit_IfExp(self, node):
        cond = self.visit(node.test)
        if_body = self.visit(node.body)
        else_body = self.visit(node.orelse)
        return _make.Select(cond, if_body, else_body)


    def visit_Compare(self, node):
        """Lower a (possibly chained) comparison ``a < b < c`` to a conjunction."""
        _internal_assert(len(node.ops) == len(node.comparators),
                         "#compare ops != #comparators")
        operands = [self.visit(node.left)]
        operands += [self.visit(i) for i in node.comparators]
        res = [HybridParser._binop_maker[type(op)](lhs, rhs)
               for op, lhs, rhs in zip(node.ops, operands[:-1], operands[1:])]
        return _all(*res)


    def visit_BoolOp(self, node):
        n = len(node.values)
        if n == 1:
            # Defensive only: Python parses `not x` as ast.UnaryOp.
            _internal_assert(isinstance(node.op, ast.Not), \
                             "Unary is supposed to be not!")
            return operator.not_(self.visit(node.values[0]))
        _internal_assert(isinstance(node.op, (ast.And, ast.Or)), \
                         "Binary is supposed to be and/or!")
        values = [self.visit(i) for i in node.values]
        return HybridParser._binop_maker[type(node.op)](*values)


    def visit_UnaryOp(self, node):
        operand = self.visit(node.operand)
        return HybridParser._unaryop_maker[type(node.op)](operand)


    def visit_BinOp(self, node):
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        return HybridParser._binop_maker[type(node.op)](lhs, rhs)


    def visit_Call(self, node):
        """Lower a call to an intrinsic from `calls` or to a context function."""
        # Yet, no function pointer supported
        _internal_assert(isinstance(node.func, ast.Name), \
                         "Only id-function function call is supported so far!")

        func_id = node.func.id
        args = [self.visit(i) for i in node.args]
        # Intrinsics'
        if hasattr(calls, func_id):
            return getattr(calls, func_id)(func_id, args)
        # Contexts'
        _internal_assert(func_id in self.symbols, \
                         "The function called (%s) is not in the context either!" % func_id)
        ty, entry = self.symbols[func_id]
        _internal_assert(ty is Symbol.Callable, \
                         "Are you sure what you call is a function?!")
        outs = entry(*args)
        op = outs.op if isinstance(outs, Tensor) else outs[0].op
        return op


    def visit_For(self, node):
        """Lower a for loop: unrolled, a HalideIR For, or a thread bind.

        ``self.visit(node.iter)`` yields ``(iter_var, low, ext, for_type)``.
        A tuple ``for_type`` requests compile-time unrolling. A non-None
        ``iter_var`` means a thread bind. Otherwise a For loop is emitted.
        """
        iter_var, low, ext, for_type = self.visit(node.iter)
        _internal_assert(isinstance(node.target, ast.Name), \
                         "The loop iterator should be a variable!")

        _name = node.target.id

        if isinstance(for_type, tuple):
            # Unroll in Python: lower the body once per iteration with the
            # loop variable bound to a Python int.
            low = _ir_pass.CanonicalSimplify(low)
            ext = _ir_pass.CanonicalSimplify(ext)
            _internal_assert(isinstance(low, _expr.ConstExpr) and
                             isinstance(ext, _expr.ConstExpr), \
                             "Const range should start from a const " + \
                             "and iterate const times")

            low, ext = low.value, ext.value
            # Arbitrary threshold for warning about very large unrolls.
            if ext > 114514:
                logging.log(logging.CRITICAL, \
                            '[Warning] Are you sure to unroll a large loop in Python?')

            bodies = []
            for i in range(low, low + ext):
                self.add_symbol(_name, Symbol.ConstLoopVar, i)
                body = visit_list_to_block(self.visit, node.body)
                body = self.wrap_up_realize(node, body)
                bodies.append(body)
                self.symbols.pop(_name)
            return concat_list_to_block(bodies)

        if iter_var is None:
            _internal_assert(for_type is not None, "The loop iterating function parse error!")
            # The emitted loop starts at 0; a non-zero lower bound is folded
            # into the value the loop variable's name resolves to.
            offset = iter_var = _api.var(_name)
            if not _ir_pass.Equal(low, _api.const(0, 'int32')):
                offset = iter_var + low
            self.add_symbol(_name, Symbol.LoopVar, offset)
            _body = visit_list_to_block(self.visit, node.body)
        else:
            _internal_assert(for_type is None, "The loop bind function parse error!")
            self.add_symbol(_name, Symbol.ThreadBind, iter_var)
            self.device += 1
            _body = visit_list_to_block(self.visit, node.body)
            self.device -= 1

        _body = self.wrap_up_realize(node, _body)

        if for_type is None:
            res = _body
        else:
            _internal_assert(not isinstance(for_type, tuple), \
                            "Micro expansion should be handled before!")
            res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)

        self.symbols.pop(_name)
        return res


    def visit_Return(self, node):
        """Record the returned tensors as the function outputs; emits a no-op."""
        # ConstLoopVar and ThreadBind entries, like LoopVar, only live in the
        # symbol table while their loop body is visited, so any of them being
        # present means this return sits inside a loop.
        loop_kinds = (Symbol.LoopVar, Symbol.ConstLoopVar, Symbol.ThreadBind)
        _internal_assert(all(ty not in loop_kinds for ty, _ in self.symbols.values()), \
                         "Return should not be in a loop body!")
        ids = []
        if isinstance(node.value, ast.Name):
            ids = [node.value.id]
        else:
            _internal_assert(isinstance(node.value, ast.Tuple), \
                             "You should return either a single tensor or a tuple")
            _internal_assert(all(isinstance(i, ast.Name) for i in node.value.elts), \
                             "What do you return?")
            ids = [i.id for i in node.value.elts]
        _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
        if len(ids) < len(self.outputs):
            logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
        self.outputs = [self.symbols[i][1] for i in ids]
        self.returned = True
        return util.make_nop()


    def visit_Tuple(self, node):
        return tuple(self.visit(i) for i in node.elts)


    def visit_Str(self, node):
        return node.s


    def visit_Assert(self, node):
        """Lower ``assert test[, msg]`` to an AssertStmt wrapping a no-op."""
        test = self.visit(node.test)
        # `assert cond` without a message has node.msg == None, which the
        # visitor cannot lower, so fall back to a default message.
        msg = self.visit(node.msg) if node.msg is not None else "Assertion failed"
        mesg = _api.convert(msg)
        return _make.AssertStmt(test, mesg, util.make_nop())


def parse_python(src, args, symbols, closure_vars):
    """The helper function of calling the AST visitor

    Parameters
    ----------
    src : ast.node or str
        If an ast.node, then directly lower it.
        If a str, then parse it to ast and lower it.

    args : list of Tensors or Vars
        The argument lists to the function.
        It is NOT encouraged to write a function without arguments.
        It is NOT encouraged to write a function with side effect.

    symbols : dict of str to object
        The global context of the function.

    closure_vars: dict
        A dict of external name reference captured by this function.

    Returns
    -------
    parser : HybridParser
        The parser instance after visiting. ``parser.parsed_body`` holds the
        lowered HalideIR statement, and ``parser.outputs`` the returned
        tensors.
    """
    root = ast.parse(src) if isinstance(src, str) else src
    # The old call `_internal_assert(root, ast.AST)` passed the class as the
    # error message and only checked truthiness; check the type instead.
    _internal_assert(isinstance(root, ast.AST),
                     "The source to parse should be a str or an ast.AST node!")

    var_usage = determine_variable_usage(root, args, symbols, closure_vars)
    parser = HybridParser(args, var_usage, symbols, closure_vars)
    parser.parsed_body = parser.visit(root)
    _internal_assert(parser.returned, 'No valid return found in the function body!')
    return parser


def source_to_op(src, args, symbols, closure_vars):
    """Lower a hybrid-script function to a HybridOp and return its outputs.

    Parameters
    ----------
    src : ast.node or str
        If an ast.node, then directly lower it.
        If a str, then parse it to ast and lower it.

    args : list of Tensors or Vars
        The argument lists to the function.
        It is NOT encouraged to write a function without arguments.
        It is NOT encouraged to write a function with side effect.

    symbols : dict of str to object
        The global context of the function.

    closure_vars: dict
        A dict of external name reference captured by this function.

    Returns
    -------
    res : Tensor or list of Tensor
        The output tensors of the formed HybridOp. This is a single Tensor
        when the function has exactly one output, otherwise a list.
    """
    parser = parse_python(src, args, symbols, closure_vars)

    input_tensors = []

    def _collect(arg):
        # Depth-first flatten of (possibly nested) Arrays of tensors; any
        # other argument kind (e.g. Vars) contributes nothing.
        if isinstance(arg, Tensor):
            input_tensors.append(arg)
        elif isinstance(arg, Array):
            for elem in arg:
                _collect(elem)

    for arg in args:
        _collect(arg)

    op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
                                 parser.outputs, parser.parsed_body)
    outputs = [op.output(idx) for idx in range(len(parser.outputs))]
    if len(outputs) == 1:
        return outputs[0]
    return outputs