parser.py 23.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20
"""Hybrid Script Parser"""

import ast
import operator
21
import logging
22
import sys
23 24 25 26
import types
import numbers

from enum import Enum
27

28
from .util import _internal_assert
29 30
from . import calls
from . import util
Jian Weng committed
31
from .preprocessor import determine_variable_usage
32 33
from ..api import all as _all
from ..api import any as _any
34

35
from ..container import Array
36
from ..tensor import Tensor, Operation
37
from .. import _api_internal as _tvm_internal
38 39
from .. import expr as _expr
from .. import make as _make
40 41
from .. import stmt as _stmt

42 43 44
from .. import api  as _api
from .. import ir_pass as _ir_pass

45

46 47
def concat_list_to_block(lst):
    """Concatenate a list of Python IR nodes to HalideIR Block

    Parameters
    ----------
    lst : list
        The already lowered statements to be concatenated.

    Returns
    -------
    stmt : Stmt
        A no-op when the list is empty, the element itself when the list
        holds exactly one statement, otherwise a SeqStmt over the list.
    """
    if not lst:
        return util.make_nop()
    if len(lst) == 1:
        # A single statement needs no sequencing wrapper.
        return lst[0]
    return _stmt.SeqStmt(lst)
54 55


56
def visit_list_to_block(visit, lst):
    """Visit and concatenate a list of Python IR nodes to HalideIR Block

    Parameters
    ----------
    visit : callable
        Applied to every node of `lst` that util.is_docstring does not flag.

    lst : list
        The Python AST nodes to be lowered.

    Returns
    -------
    stmt : Stmt
        The lowered statements, with everything equal to a no-op dropped,
        concatenated by concat_list_to_block; a no-op if nothing is left.
    """
    lowered = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
    # Build the reference no-op once instead of once per lowered statement.
    nop = util.make_nop()
    lowered = [stmt for stmt in lowered if not _ir_pass.Equal(stmt, nop)]
    if not lowered:
        return nop
    return concat_list_to_block(lowered)
63 64 65 66 67 68 69 70 71 72 73 74 75 76


class Symbol(Enum):
    """Enumerates types in the symbol table

    The member names matter: the parser looks members up by name when a
    buffer is allocated (scope.title() + "Buffer") and derives the realize
    scope of a buffer from the member name when wrapping up realizations.
    """
    # A Python function taken from the global symbols.
    Callable = 0
    # An argument of the function being lowered.
    Input = 1
    # Buffers allocated with the 'output', 'global', 'local' and 'shared' scopes.
    OutputBuffer = 2
    GlobalBuffer = 3
    LocalBuffer = 4
    SharedBuffer = 5
    # An immediate value bound to a name that is never stored to again.
    ConstVar = 6
    # A scalar variable backed by a one-element placeholder.
    BufferVar = 7
    # The iterator of a lowered loop.
    LoopVar = 8
    # The iterator of a loop that is unrolled while parsing.
    ConstLoopVar = 9
    # The iterator of a loop bound to a thread axis.
    ThreadBind = 10
78 79


80 81 82 83 84 85 86 87 88 89 90 91
def _floordiv(x, y):
    """Floor-divide x by y, going through the api when either side is an ExprOp."""
    symbolic = isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp)
    divide = _api.floordiv if symbolic else operator.floordiv
    return divide(x, y)


def _floormod(x, y):
    """Take x modulo y, going through the api when either side is an ExprOp."""
    symbolic = isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp)
    modulo = _api.floormod if symbolic else operator.mod
    return modulo(x, y)


92 93 94 95 96
class HybridParser(ast.NodeVisitor):
    """Python AST visitor pass which finally lowers it to HalideIR"""


    # Maps the AST type of a binary, comparison or boolean operator to the
    # callable that is applied to the already lowered operands.
    _binop_maker = {
        ast.Add     : operator.add,
        ast.Sub     : operator.sub,
        ast.Mult    : operator.mul,
        # operator.div only exists in Python 2.
        ast.Div     : operator.div if sys.version_info[0] == 2 else operator.truediv,
        ast.FloorDiv: _floordiv,
        ast.Mod     : _floormod,
        ast.BitOr   : operator.or_,
        ast.BitAnd  : operator.and_,
        ast.BitXor  : operator.xor,
        ast.Gt      : operator.gt,
        ast.GtE     : operator.ge,
        ast.Lt      : operator.lt,
        ast.LtE     : operator.le,
        ast.Eq      : operator.eq,
        ast.NotEq   : operator.ne,
        ast.And     : _all,
        ast.Or      : _any,
    }


    # Maps the AST type of a unary operator to the callable applied to the
    # already lowered operand.
    _unaryop_maker = {
        ast.USub   : operator.neg,
        ast.Invert : operator.invert,
        ast.Not    : operator.not_
    }


124
    def __init__(self, args, usage, symbols, closure_vars, func_name=None):
        """
        Parameters
        ----------
        args: A list of tvm.placeholder or tvm.var
            Provided by the user, the argument list of the function to be lowered.

        usage: dict
            Provided by the last lower pass. Maps a variable name to a
            3-tuple: the parser compares the first item with the assigned
            ast.Name to detect the declaring assignment, uses the second
            item as the AST node at which the variable is wrapped up, and
            checks the third item for ast.Store.

        symbols : dict
            The symbols of the global context of the function, keyed by
            name. Only the entries whose value is a Python function are
            recorded in the symbol table.

        closure_vars: dict
            A dict of external name reference captured by this function.

        func_name: str
            The name of the function to be lowered; if not provided,
            the compiler will use the name in the AST
        """
        self.args = list(args)
        self.usage = usage.copy()

        # Set up before the symbol table is filled: add_symbol reads
        # self.binds when it is handed a thread bind.
        self.binds = {} # Thread binds
        self.device = 0 # Is it generating device
        # Names bound by `with ... as name`, mapped to the id of the called
        # function; visit_With writes to it.
        self.annotation = {}

        self.symbols = {} # Symbol table
        for k, v in symbols.items():
            if isinstance(v, types.FunctionType):
                self.add_symbol(k, Symbol.Callable, v)

        self.closure_vars = closure_vars

        self.func_name = func_name # The name of the function to be lowered
        self.outputs = [] # Output tensors' name
        self.side_effect = set() # Tensors with side effects
        self.parsed_body = None # The parsed HalideIR body
        self.returned = False # If this function has a valid return
164

165

Jian Weng committed
166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
    def add_symbol(self, key, ty, val): #pylint: disable=invalid-name
        """Record `key` in the symbol table as the pair (ty, val).

        A key that is already present is reported as a name conflict.
        For a thread bind, the first iteration variable seen for a thread
        name is remembered in self.binds; later binds of the same thread
        name must have an equal extent and are redirected to the
        remembered iteration variable.
        """
        if key in self.symbols:
            previous = str(self.symbols[key])
            incoming = str((ty, val))
            _internal_assert(False,
                             "Name conflict in symbol table! [%s] %s -> %s" % (key, previous, incoming))

        self.symbols[key] = ty, val

        if ty != Symbol.ThreadBind:
            return

        thread_name = val.var.name
        if thread_name not in self.binds:
            self.binds[thread_name] = val
            return

        bound = self.binds[thread_name]
        _internal_assert(_ir_pass.Equal(bound.dom.extent, val.dom.extent),
                         "Thread extents should be uniform!")
        self.symbols[key] = ty, bound
185 186 187

    def wrap_up_realize(self, node, body):
        """Wrap up all the variables which will no longer be used

        Every buffer symbol whose usage entry names `node` as its level is
        realized around `body` and then removed from the symbol table.
        Inputs and output buffers are never realized here.

        Parameters
        ----------
        node : ast.AST
            The AST node being closed.

        body : Stmt
            The lowered body of `node`.

        Returns
        -------
        body : Stmt
            `body` wrapped in the Realize / 'realize_scope' statements.
        """
        to_pop = []
        for key, val in self.usage.items():
            _, level, _ = val
            if key not in self.symbols:
                # don't realize the symbols that are never visited
                continue
            if level != node:
                continue

            ty, entry = self.symbols[key] #pylint: disable=invalid-name
            if ty in [Symbol.Input, Symbol.OutputBuffer]:
                continue
            if 'Buffer' not in ty.name:
                continue

            _buf = entry
            # BufferVar is realized globally; for the other buffer kinds the
            # scope is the member name without its "Buffer" suffix.
            _scope = 'global' if ty is Symbol.BufferVar else ty.name[:-6].lower()
            to_pop.append(key)

            if _scope == 'global':
                body = self.wrap_up_binds(body)

            _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
            _dtype = _buf.dtype
            _true = _api.convert(True)
            body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
            body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)

        for elem in to_pop:
            self.symbols.pop(elem)

        return body
221 222


Jian Weng committed
223 224 225 226 227 228 229 230
    def wrap_up_binds(self, body):
        """Wrap `body` in one 'thread_extent' AttrStmt per recorded thread bind.

        The recorded binds are cleared afterwards.
        """
        for thread_axis in self.binds.values():
            extent = thread_axis.dom.extent
            body = _make.AttrStmt(thread_axis, 'thread_extent', extent, body)
        self.binds = {}
        return body


231 232
    #pylint: disable=invalid-name, missing-docstring
    def visit_Module(self, node):
        """Lower a module, which must consist of exactly one statement."""
        _internal_assert(len(node.body) == 1,
                         "Only one-function source code will be fed to this parser!")
        return self.visit(node.body[0])


    def visit_FunctionDef(self, node):
        """Lower a function definition.

        The formal parameters are bound, in order, to the arguments the
        parser was created with; the body is lowered and then wrapped in
        the pending realizations and thread binds.
        """
        # Implicit concatenation keeps the message free of the indentation
        # that a backslash continuation inside the literal would embed.
        _internal_assert(len(node.args.args) == len(self.args),
                         "The number of arguments passed to the "
                         "function should be the same as it is defined!")
        if self.func_name is None:
            self.func_name = node.name
        # To make py2 and 3 compatible
        _attr = 'id' if sys.version_info[0] < 3 else 'arg'
        for formal, actual in zip(node.args.args, self.args):
            self.add_symbol(getattr(formal, _attr), Symbol.Input, actual)
        res = visit_list_to_block(self.visit, node.body)
        res = self.wrap_up_realize(node, res)
        return self.wrap_up_binds(res)
250 251 252 253 254 255 256


    def visit_Expr(self, node):
        """Lower an expression statement by lowering the expression it wraps."""
        wrapped = node.value
        return self.visit(wrapped)


    def visit_Name(self, node):
        """Resolve a name to what it stands for.

        Closure variables are converted through the api. Any other name is
        looked up in the symbol table, and what is returned depends on the
        kind of the symbol and on whether the name is loaded or stored:
        a stored ConstVar yields None and a stored BufferVar yields the
        pair (buffer, indices) instead of a value.
        """
        name = node.id
        if sys.version_info[0] == 2 and name in ['True', 'False']:
            return _api.convert(ast.literal_eval(name))

        if name in self.closure_vars:
            return _api.convert(self.closure_vars[name])

        # Check before indexing, so an unknown name is reported through the
        # parser's own assertion instead of a bare KeyError.
        _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
        ty, entry = self.symbols[name]
        if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
            return entry
        if ty is Symbol.ThreadBind:
            return entry.var
        if ty is Symbol.ConstVar:
            return entry if isinstance(node.ctx, ast.Load) else None
        if ty is Symbol.BufferVar:
            if isinstance(node.ctx, ast.Load):
                return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')],
                                  _expr.Call.Halide, entry.op, entry.value_index)
            return entry, [_api.const(0, 'int32')]
        # Do I need any assertion here?
        return entry
279 280 281


    def visit_Num(self, node):
        """Lower a numeric literal to a constant of dtype bool, int32 or float32."""
        _internal_assert(isinstance(node.n, (numbers.Integral, float)),
                         "The data type should be one of (int, float, bool)")
        # bool is itself a numbers.Integral, so it has to be tested first;
        # otherwise the bool case can never be selected.
        if isinstance(node.n, bool):
            dtype = "bool"
        elif isinstance(node.n, numbers.Integral):
            dtype = "int32"
        else:
            dtype = "float32"
        return _api.const(node.n, dtype)
291 292


293 294 295 296
    def visit_NameConstant(self, node):
        """Convert the constant carried by the node through the api."""
        constant = node.value
        return _api.convert(constant)


297
    def visit_AugAssign(self, node):
        """Lower `target op= value` to a Provide of `read(target) op value`.

        The target either resolves to a (buffer, indices) pair or to a
        bare tensor, in which case element 0 is addressed.
        """
        buf = self.visit(node.target)
        rhs = self.visit(node.value)
        if isinstance(buf, tuple):
            _internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
            buf, args = buf
        else:
            args = [_api.const(0, 'int32')]
        _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")

        read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
        value = HybridParser._binop_maker[type(node.op)](read, rhs)

        # Write back to the same output of the op that was just read from,
        # rather than always to output 0.
        return _make.Provide(buf.op, buf.value_index, value, args)
311 312


313
    def visit_Assign(self, node):
        """Lower an assignment.

        Three shapes are handled:

        * the right hand side is an Operation: every target name is bound
          to the matching output and the body of the operation is inlined;
        * the target is a plain name: on its declaring assignment the name
          is entered into the symbol table as a buffer, a constant or a
          one-element buffer variable, then the value is stored if the
          name resolves to a buffer;
        * anything else is treated as an element store into a tensor.
        """
        rhs = self.visit(node.value)
        if isinstance(rhs, Operation):
            rmap = {}
            _internal_assert(len(node.targets) == rhs.num_outputs,
                             "Unable to detuple the outs to targets")
            for i in range(rhs.num_outputs):
                _internal_assert(isinstance(node.targets[i], ast.Name),
                                 "You should bind a pure name to the tensors")
                self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i))
                rmap[rhs.outputs[i].op] = rhs.output(i)
            return util.replace_io(rhs.body, rmap)

        _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
        lhs = node.targets[0]
        if isinstance(rhs, _expr.PrimExpr):
            rhs = _ir_pass.Simplify(rhs)
        if isinstance(lhs, ast.Name):
            #TODO: support defined intermediate buffer later
            lhs_ = lhs
            lhs = lhs.id
            if lhs in self.symbols:
                ty, _ = self.symbols[lhs]
                _internal_assert(ty != Symbol.LoopVar,
                                 "Loop variable cannot be overwritten!")
            decl, _, rw = self.usage[lhs]
            if decl == lhs_:
                # This is the assignment that declares the name.
                _internal_assert(lhs not in self.symbols,
                                 "This value should not be defined before this point!")
                if isinstance(rhs, tuple):
                    # A (shape, dtype, scope) triple allocates a buffer.
                    shape, dtype, scope = rhs
                    ph = _api.placeholder(shape, dtype=dtype, name=lhs)
                    self.add_symbol(lhs, getattr(Symbol, scope.title() + "Buffer"), ph)
                    if scope == 'output':
                        self.outputs.append(lhs)
                    return util.make_nop()
                if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
                    self.add_symbol(lhs, Symbol.ConstVar, rhs)
                else:
                    _internal_assert(self.device == 0,
                                     "Single variable not supported in devices' side!\n" + \
                                     "If you are using GPU, please allocate a 'local' spad " + \
                                     "outside the bind body")
                    ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
                    self.add_symbol(lhs, Symbol.BufferVar, ph)
            lhs = self.visit(lhs_)
            if lhs is not None:
                buf, args = lhs
                return _make.Provide(buf.op, 0, rhs, args)
            # A stored constant resolves to None: nothing has to be emitted.
            return util.make_nop()

        lhs, args = self.visit(lhs)
        _internal_assert(isinstance(lhs, Tensor),
                         "An array access's LHS is expected to be a expr.Call!")
        return _make.Provide(lhs.op, lhs.value_index, rhs, args)
369 370 371 372


    def visit_Index(self, node):
        """Lower an index: a tuple index is lowered as a tuple, anything
        else is lowered and wrapped in a one-element list."""
        if isinstance(node.value, ast.Tuple):
            return self.visit(node.value)
        return [self.visit(node.value)]


377
    def visit_Attribute(self, node):
        """Lower `value.attr` by reading the attribute off the lowered value."""
        buf = self.visit(node.value)
        return getattr(buf, node.attr)

381 382
    def visit_Subscript(self, node):
        """Lower `value[slice]`.

        Indexing an Array selects elements with constant indices. For
        anything else a load yields a Halide call, and a store yields the
        pair (value, indices) for the assignment to consume.
        """
        args = self.visit(node.slice)
        if not isinstance(args, (list, tuple)):
            # Since Python 3.9 a simple index is stored directly as the
            # slice instead of being wrapped in ast.Index, so visit_Index
            # no longer turns it into a list.
            args = [args]
        arr = self.visit(node.value)
        if isinstance(arr, Array):
            for i in args:
                if isinstance(i, numbers.Integral):
                    arr = arr[i]
                else:
                    _internal_assert(isinstance(i, (_expr.IntImm,)),
                                     "All indices are supposed to be constants")
                    arr = arr[i.value]
            return arr
        if isinstance(node.ctx, ast.Load):
            return _make.Call(arr.dtype, arr.name, args,
                              _expr.Call.Halide, arr.op, arr.value_index)
        return arr, args
397 398 399 400 401 402

    def visit_With(self, node):
        """Lower `with call(...) as name:` by recording the id of the called
        function under `name` and lowering the body."""
        if sys.version_info[0] < 3:
            context = node.context_expr
            option = node.optional_vars
        else:
            _internal_assert(len(node.items) == 1, "Only one with element is supported so far!")
            context = node.items[0].context_expr
            option = node.items[0].optional_vars
        _internal_assert(isinstance(context, ast.Call), "The object must be a Python func call!")
        _internal_assert(isinstance(option, ast.Name), "The object after 'as' must be an id!")
        # The mapping may not exist yet on this instance; create it on first
        # use instead of failing with an AttributeError.
        annotation = getattr(self, 'annotation', None)
        if annotation is None:
            annotation = self.annotation = {}
        annotation[option.id] = context.func.id
        return visit_list_to_block(self.visit, node.body)
410 411 412


    def visit_If(self, node):
        """Lower an if statement.

        When the simplified condition is an integer immediate, only the
        branch that is taken is lowered; otherwise an IfThenElse is built.
        """
        cond = _ir_pass.CanonicalSimplify(self.visit(node.test))

        # Return no IfThenElse if proven
        if isinstance(cond, _expr.IntImm):
            if cond.value:
                return visit_list_to_block(self.visit, node.body)
            if node.orelse:
                return visit_list_to_block(self.visit, node.orelse)
            return util.make_nop()

        if_body = visit_list_to_block(self.visit, node.body)
        else_body = visit_list_to_block(self.visit, node.orelse) if node.orelse else None
        return _make.IfThenElse(cond, if_body, else_body)


    def visit_IfExp(self, node):
        """Lower `body if test else orelse` to a Select over the lowered parts."""
        parts = [self.visit(part) for part in (node.test, node.body, node.orelse)]
        return _make.Select(*parts)


    def visit_Compare(self, node):
        """Lower a (possibly chained) comparison.

        Each adjacent pair of operands is compared with its operator and
        the individual results are combined into one condition.
        """
        _internal_assert(len(node.ops) == len(node.comparators),
                         "#compare ops != #comparators")
        operands = [self.visit(node.left)]
        operands += [self.visit(i) for i in node.comparators]
        res = [HybridParser._binop_maker[type(op)](lhs, rhs)
               for op, lhs, rhs in zip(node.ops, operands[:-1], operands[1:])]
        return _all(*res)


    def visit_BoolOp(self, node):
        """Lower `and` / `or` over the lowered operands; a single operand
        is only accepted together with a `not` operator."""
        n = len(node.values)
        if n == 1:
            _internal_assert(isinstance(node.op, ast.Not),
                             "Unary is supposed to be not!")
            return operator.not_(self.visit(node.values[0]))
        _internal_assert(isinstance(node.op, (ast.And, ast.Or)),
                         "Binary is supposed to be and/or!")
        values = [self.visit(i) for i in node.values]
        return HybridParser._binop_maker[type(node.op)](*values)
462 463 464 465 466 467 468 469 470 471 472 473 474 475 476


    def visit_UnaryOp(self, node):
        """Lower a unary operation by applying the maker registered for its operator."""
        value = self.visit(node.operand)
        make_unary = HybridParser._unaryop_maker[type(node.op)]
        return make_unary(value)


    def visit_BinOp(self, node):
        """Lower a binary operation by applying the maker registered for its operator."""
        operands = [self.visit(side) for side in (node.left, node.right)]
        make_binary = HybridParser._binop_maker[type(node.op)]
        return make_binary(*operands)


    def visit_Call(self, node):
        """Lower a call to a plainly named function.

        A name that the `calls` module defines is dispatched there.
        Otherwise the name must be a Callable in the symbol table; it is
        invoked on the lowered arguments and the op of its (first) output
        tensor is returned.
        """
        # Yet, no function pointer supported
        _internal_assert(isinstance(node.func, ast.Name),
                         "Only id-function function call is supported so far!")

        func_id = node.func.id
        args = [self.visit(i) for i in node.args]
        # Intrinsics'
        if hasattr(calls, func_id):
            return getattr(calls, func_id)(func_id, args)
        # Contexts'
        _internal_assert(func_id in self.symbols,
                         "The function called (%s) is not in the context either!" % func_id)
        ty, entry = self.symbols[func_id]
        _internal_assert(ty is Symbol.Callable,
                         "Are you sure what you call is a function?!")
        outs = entry(*args)
        return outs.op if isinstance(outs, Tensor) else outs[0].op
494 495 496 497


    def visit_For(self, node):
        """Lower a for loop.

        The lowered iterator is a 4-tuple (iter_var, low, ext, for_type)
        and selects one of three forms:

        * `for_type` is a tuple: the loop is unrolled while parsing, which
          requires constant bounds;
        * `iter_var` is None: a For statement over a fresh loop variable;
        * otherwise: the loop variable is bound to the given thread axis
          and only the body is emitted.
        """
        iter_var, low, ext, for_type = self.visit(node.iter)
        _internal_assert(isinstance(node.target, ast.Name),
                         "The loop iterator should be a variable!")

        _name = node.target.id

        if isinstance(for_type, tuple):
            low = _ir_pass.CanonicalSimplify(low)
            ext = _ir_pass.CanonicalSimplify(ext)
            _internal_assert(isinstance(low, _expr.ConstExpr) and
                             isinstance(ext, _expr.ConstExpr),
                             "Const range should start from a const " + \
                             "and iterate const times")

            low, ext = low.value, ext.value
            # Only a warning: large loops are still unrolled.
            if ext > 114514:
                logging.log(logging.CRITICAL,
                            '[Warning] Are you sure to unroll a large loop in Python?')

            bodies = []
            for i in range(low, low + ext):
                # The iterator is a plain Python int for this copy of the body.
                self.add_symbol(_name, Symbol.ConstLoopVar, i)
                body = visit_list_to_block(self.visit, node.body)
                body = self.wrap_up_realize(node, body)
                bodies.append(body)
                self.symbols.pop(_name)
            return concat_list_to_block(bodies)

        if iter_var is None:
            _internal_assert(for_type is not None, "The loop iterating function parse error!")
            offset = iter_var = _api.var(_name)
            # The emitted loop starts at 0, so the body sees the variable
            # shifted by the lower bound.
            if not _ir_pass.Equal(low, _api.const(0, 'int32')):
                offset = iter_var + low
            self.add_symbol(_name, Symbol.LoopVar, offset)
            _body = visit_list_to_block(self.visit, node.body)
        else:
            _internal_assert(for_type is None, "The loop bind function parse error!")
            self.add_symbol(_name, Symbol.ThreadBind, iter_var)
            self.device += 1
            _body = visit_list_to_block(self.visit, node.body)
            self.device -= 1

        _body = self.wrap_up_realize(node, _body)

        if for_type is None:
            res = _body
        else:
            _internal_assert(not isinstance(for_type, tuple),
                             "Micro expansion should be handled before!")
            res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)

        self.symbols.pop(_name)
        return res


552
    def visit_Return(self, node):
        """Record the returned tensors as the outputs of the function.

        The returned value must be a name or a tuple of distinct names,
        and the return must not appear while a loop variable is live.
        Nothing is emitted for the statement itself.
        """
        _internal_assert(all(ty != Symbol.LoopVar for ty, _ in self.symbols.values()),
                         "Return should not be in a loop body!")
        if isinstance(node.value, ast.Name):
            ids = [node.value.id]
        else:
            _internal_assert(isinstance(node.value, ast.Tuple),
                             "You should return either a single tensor or a tuple")
            _internal_assert(all(isinstance(i, ast.Name) for i in node.value.elts),
                             "What do you return?")
            ids = [i.id for i in node.value.elts]
        _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
        if len(ids) < len(self.outputs):
            logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
        # From here on self.outputs holds symbol table entries, not names.
        self.outputs = [self.symbols[i][1] for i in ids]
        self.returned = True
        return util.make_nop()


    def visit_Tuple(self, node):
        """Lower a tuple to a Python tuple of its lowered elements."""
        return tuple(map(self.visit, node.elts))
574 575


576 577 578 579
    def visit_Str(self, node):
        """Return the Python string held by the node."""
        text = node.s
        return text


580 581 582 583 584 585
    def visit_Assert(self, node):
        """Lower an assert statement to an AssertStmt with a no-op body."""
        test = self.visit(node.test)
        # `assert cond` has no message node (node.msg is None), and visiting
        # None fails inside ast.NodeVisitor; fall back to a generic message.
        text = self.visit(node.msg) if node.msg is not None else "Assertion failed!"
        mesg = _api.convert(text)
        return _make.AssertStmt(test, mesg, util.make_nop())


586
def parse_python(src, args, symbols, closure_vars):
    """The helper function of calling the AST visitor

    Parameters
    ----------
    src : ast.node or str
        If an ast.node, then directly lower it.
        If a str, then parse it to ast and lower it.

    args : list of Tensors or Vars
        The argument lists to the function.
        It is NOT encouraged to write a function without arguments.
        It is NOT encouraged to write a function with side effect.

    symbols : dict
        The symbols of the global context of the function, keyed by name.

    closure_vars: dict
        A dict of external name reference captured by this function.

    Returns
    -------
    parser : HybridParser
        The parser instance; its `parsed_body` holds the result Halide IR.
    """
    root = ast.parse(src) if isinstance(src, str) else src
    # Test the type: passing the root itself as the condition would use
    # ast.AST as the error message and accept any truthy object.
    _internal_assert(isinstance(root, ast.AST),
                     "The source to be parsed should be an ast.AST or a str!")
    var_usage = determine_variable_usage(root, args, symbols, closure_vars)
    parser = HybridParser(args, var_usage, symbols, closure_vars)
    parser.parsed_body = parser.visit(root)
    _internal_assert(parser.returned, 'No valid return found in the function body!')
    return parser
618 619


620
def source_to_op(src, args, symbols, closure_vars):
    """Another level of wrapper

    Parameters
    ----------
    src : ast.node or str
        If an ast.node, then directly lower it.
        If a str, then parse it to ast and lower it.

    args : list of Tensors or Vars
        The argument lists to the function.
        It is NOT encouraged to write a function without arguments.
        It is NOT encouraged to write a function with side effect.

    symbols : dict
        The symbols of the global context of the function, keyed by name.

    closure_vars: dict
        A dict of external name reference captured by this function.

    Returns
    -------
    res : list of output tensors
        The result of output tensors of the formed OpNode; the tensor
        itself rather than a list when there is exactly one output.
    """
    parser = parse_python(src, args, symbols, closure_vars)

    input_tensors = []

    def get_input_tensors(arg):
        # Collect tensors, descending into (nested) Arrays; anything else
        # is not an input tensor of the op and is skipped.
        if isinstance(arg, Tensor):
            input_tensors.append(arg)
        elif isinstance(arg, Array):
            for i in arg:
                get_input_tensors(i)

    for i in args:
        get_input_tensors(i)
    op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
                                 parser.outputs, parser.parsed_body)
    res = [op.output(i) for i in range(len(parser.outputs))]
    return res[0] if len(res) == 1 else res