expr.py 17.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
"""Expression AST Node in TVM.

Users do not need to deal with expression AST nodes directly,
but they can be helpful for developers doing quick prototyping.
Although they are not displayed in the documentation or the Python file,
each expression node has subfields that can be visited from the Python side.

For example, you can use addexp.a to get the left operand of an Add node.

.. code-block:: python

  x = tvm.var("n")
  y = x + 2
  assert(isinstance(y, tvm.expr.Add))
  assert(y.a == x)
"""
33
# pylint: disable=missing-docstring
tqchen committed
34
from __future__ import absolute_import as _abs
35
from ._ffi.node import NodeBase, NodeGeneric, register_node
36
from ._ffi.runtime_ctypes import TVMType, TypeCode
tqchen committed
37
from . import make as _make
38
from . import generic as _generic
39
from . import _api_internal
tqchen committed
40

41

42 43 44 45 46 47 48 49 50 51 52 53 54
def div_ambiguity_error():
    """Build the error used when ``/`` is applied to two integer operands.

    The error is returned, not raised, so callers write
    ``raise div_ambiguity_error()``.

    Returns
    -------
    err : RuntimeError
        Error whose message tells the caller to pick an explicit
        integer-division function.
    """
    # Adjacent literals are joined at compile time; each piece ends with a
    # single space so the message has no doubled whitespace.
    return RuntimeError(
        "TVM supports multiple types of integer divisions, "
        "please call div, indexdiv/indexmod, floordiv/floormod "
        "or truncdiv/truncmod directly to avoid ambiguity in the code.")


def _dtype_is_int(value):
    if isinstance(value, int):
        return True
    return (isinstance(value, ExprOp) and
            TVMType(value.dtype).type_code == TypeCode.INT)


55
class ExprOp(object):
    """Operator-overloading mixin shared by TVM expressions.

    Each Python operator forwards to a constructor from ``_generic`` or
    ``_make`` so that ordinary Python syntax builds expression trees.
    ``==`` and ``!=`` instead return deferred ``EqualOp`` / ``NotEqualOp``
    objects, which act as an identity check in a boolean context and as an
    EQ/NE expression otherwise.

    Reflected operators are provided for every binary operator so that an
    expression still works when a plain Python number is the left operand
    (e.g. ``7 % x`` or ``1 << x``).
    """

    # -- arithmetic ---------------------------------------------------------

    def __add__(self, other):
        return _generic.add(self, other)

    def __radd__(self, other):
        # Addition commutes, so the reflected form reuses __add__
        # (``self`` is passed as the first operand).
        return self.__add__(other)

    def __sub__(self, other):
        return _generic.subtract(self, other)

    def __rsub__(self, other):
        return _generic.subtract(other, self)

    def __mul__(self, other):
        return _generic.multiply(self, other)

    def __rmul__(self, other):
        return _generic.multiply(other, self)

    # ``/`` between two integers is rejected: TVM has several integer
    # division semantics and the caller must choose one explicitly.
    def __div__(self, other):
        # Python 2 spelling of ``/``.
        if _dtype_is_int(self) and _dtype_is_int(other):
            raise div_ambiguity_error()
        return _generic.divide(self, other)

    def __rdiv__(self, other):
        if _dtype_is_int(self) and _dtype_is_int(other):
            raise div_ambiguity_error()
        return _generic.divide(other, self)

    def __truediv__(self, other):
        if _dtype_is_int(self) and _dtype_is_int(other):
            raise div_ambiguity_error()
        return _generic.divide(self, other)

    def __rtruediv__(self, other):
        if _dtype_is_int(self) and _dtype_is_int(other):
            raise div_ambiguity_error()
        return _generic.divide(other, self)

    def __floordiv__(self, other):
        return _generic.floordiv(self, other)

    def __rfloordiv__(self, other):
        return _generic.floordiv(other, self)

    def __mod__(self, other):
        return _make._OpFloorMod(self, other)

    def __rmod__(self, other):
        # Mirror of __mod__ so ``int % expr`` works like ``int // expr``.
        return _make._OpFloorMod(other, self)

    def __neg__(self):
        # Negate by multiplying with a -1 constant of the same dtype.
        neg_one = _api_internal._const(-1, self.dtype)
        return self.__mul__(neg_one)

    # -- bitwise ------------------------------------------------------------

    def __lshift__(self, other):
        return _make.left_shift(self, other)

    def __rlshift__(self, other):
        return _make.left_shift(other, self)

    def __rshift__(self, other):
        return _make.right_shift(self, other)

    def __rrshift__(self, other):
        return _make.right_shift(other, self)

    def __and__(self, other):
        return _make.bitwise_and(self, other)

    def __rand__(self, other):
        return _make.bitwise_and(other, self)

    def __or__(self, other):
        return _make.bitwise_or(self, other)

    def __ror__(self, other):
        return _make.bitwise_or(other, self)

    def __xor__(self, other):
        return _make.bitwise_xor(self, other)

    def __rxor__(self, other):
        return _make.bitwise_xor(other, self)

    def __invert__(self):
        # Bitwise NOT is expressed as a pure intrinsic call.
        return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)

    # -- comparisons --------------------------------------------------------

    def __lt__(self, other):
        return _make._OpLT(self, other)

    def __le__(self, other):
        return _make._OpLE(self, other)

    def __eq__(self, other):
        # Deferred: resolves to same_as in a boolean context, EQ otherwise.
        return EqualOp(self, other)

    def __ne__(self, other):
        # Deferred: resolves to "not same_as" in a boolean context, NE otherwise.
        return NotEqualOp(self, other)

    def __gt__(self, other):
        return _make._OpGT(self, other)

    def __ge__(self, other):
        return _make._OpGE(self, other)

    # -- truth value --------------------------------------------------------

    def __nonzero__(self):
        # Expressions have no Python truth value; refuse rather than let
        # ``and`` / ``or`` / ``not`` silently produce a wrong result.
        raise ValueError("Cannot use and / or / not operator to Expr, hint: " +
                         "use tvm.all / tvm.any instead")

    def __bool__(self):
        return self.__nonzero__()

    # -- helpers ------------------------------------------------------------

    def equal(self, other):
        """Build an equal check expression with other expr.

        Parameters
        ----------
        other : Expr
            The other expression

        Returns
        -------
        ret : Expr
            The equality expression.
        """
        return _make._OpEQ(self, other)

    def astype(self, dtype):
        """Cast the expression to other type.

        Parameters
        ----------
        dtype : str
            The type of new expression

        Returns
        -------
        expr : Expr
            Expression with new type
        """
        return _generic.cast(self, dtype)


class EqualOp(NodeGeneric, ExprOp):
    """Lazily-resolved result of ``a == b``.

    ``a == b`` on expressions is ambiguous: in a boolean context it should
    answer "are these the same node?" (``NodeBase.same_as``), while inside
    an expression it should become an equality node. This wrapper defers
    that choice until the value is used.

    Parameters
    ----------
    a : Expr
        Left operand.

    b : Expr
        Right operand.
    """

    # Instances exist only on the Python side (C++ never manipulates them),
    # so plain object identity is enough to implement ``same_as``.
    same_as = object.__eq__

    def __init__(self, a, b):
        self.a, self.b = a, b

    def __nonzero__(self):
        identical = self.a.same_as(self.b)
        return identical

    def __bool__(self):
        return self.__nonzero__()

    def asnode(self):
        """Materialize the comparison as an EQ expression node."""
        return _make._OpEQ(self.a, self.b)


class NotEqualOp(NodeGeneric, ExprOp):
    """Lazily-resolved result of ``a != b``.

    In a boolean context this answers "are ``a`` and ``b`` different
    nodes?" (the negation of ``NodeBase.same_as``); when converted to a
    node it becomes an NE expression.

    Parameters
    ----------
    a : Expr
        Left operand.

    b : Expr
        Right operand.
    """

    # Instances exist only on the Python side (C++ never manipulates them),
    # so plain object identity is enough to implement ``same_as``.
    same_as = object.__eq__

    def __init__(self, a, b):
        self.a, self.b = a, b

    def __nonzero__(self):
        identical = self.a.same_as(self.b)
        return not identical

    def __bool__(self):
        return self.__nonzero__()

    def asnode(self):
        """Materialize the comparison as an NE expression node."""
        return _make._OpNE(self.a, self.b)


class Expr(ExprOp, NodeBase):
    """Base class of all TVM expression nodes."""

    # ExprOp defines __eq__ without __hash__, so Python 3 sets
    # ExprOp.__hash__ to None and, through the MRO, Expr would become
    # unhashable. Re-expose NodeBase's hash explicitly.
    # See https://docs.python.org/3/reference/datamodel.html#object.__hash__
    __hash__ = NodeBase.__hash__


class ConstExpr(Expr):
    """Common base of the constant nodes (FloatImm, IntImm, UIntImm, StringImm)."""


class BinaryOpExpr(Expr):
    """Common base of the two-operand nodes (Add, Sub, Mul, Div, Mod, Min, Max, ...)."""


class CmpExpr(Expr):
    """Common base of the comparison nodes (EQ, NE, LT, LE, GT, GE)."""


class LogicalExpr(Expr):
    """Common base of the logical nodes (And, Or, Not)."""


@register_node("Variable")
class Var(Expr):
    """Symbolic variable.

    Parameters
    ----------
    name : str
        The name

    dtype : str
        The data type (a type string, as for the other nodes in this module).
    """
    def __init__(self, name, dtype):
        self.__init_handle_by_constructor__(
            _api_internal._Var, name, dtype)


@register_node
class Reduce(Expr):
    """Reduction node.

    Parameters
    ----------
    combiner : CommReducer
        Reducer used to combine values.

    src : list of Expr
        Source expressions being reduced.

    rdom : list of IterVar
        Iteration domain of the reduction.

    condition : Expr
        Condition attached to the reduction.

    value_index : int
        Index of the result value this node refers to.
    """

    def __init__(self, combiner, src, rdom, condition, value_index):
        ctor_args = (combiner, src, rdom, condition, value_index)
        self.__init_handle_by_constructor__(_make.Reduce, *ctor_args)


@register_node
class FloatImm(ConstExpr):
    """Floating-point constant.

    Parameters
    ----------
    dtype : str
        Data type of the constant.

    value : float
        Constant value.
    """

    def __init__(self, dtype, value):
        self.__init_handle_by_constructor__(_make.FloatImm, dtype, value)


@register_node
class IntImm(ConstExpr):
    """Signed integer constant.

    Parameters
    ----------
    dtype : str
        Data type of the constant.

    value : int
        Constant value.
    """

    def __init__(self, dtype, value):
        self.__init_handle_by_constructor__(_make.IntImm, dtype, value)

    def __int__(self):
        """Return the stored ``value`` so ``int(node)`` works."""
        return self.value


@register_node
class UIntImm(ConstExpr):
    """Unsigned integer constant.

    Parameters
    ----------
    dtype : str
        Data type of the constant.

    value : int
        Constant value.
    """

    def __init__(self, dtype, value):
        self.__init_handle_by_constructor__(_make.UIntImm, dtype, value)


@register_node
class StringImm(ConstExpr):
    """String constant.

    Compares by value: it is equal to another constant whose ``value`` is
    equal, or to any non-constant object equal to its ``value`` (e.g. a
    plain ``str``).

    Parameters
    ----------
    value : str
        The string value of the constant.
    """
    def __init__(self, value):
        self.__init_handle_by_constructor__(
            _make.StringImm, value)

    def __eq__(self, other):
        if isinstance(other, ConstExpr):
            return self.value == other.value
        return self.value == other

    def __ne__(self, other):
        if isinstance(other, ConstExpr):
            return self.value != other.value
        return self.value != other

    def __hash__(self):
        # Defining __eq__ in this class body makes Python 3 set __hash__ to
        # None, which left StringImm unhashable unlike every other Expr.
        # Hash by value so that objects comparing equal above also hash
        # equally (the Python hash/eq contract).
        return hash(self.value)


@register_node
class Cast(Expr):
    """Cast expression.

    Parameters
    ----------
    dtype : str
        The target data type.

    value : Expr
        The expression to cast.
    """
    def __init__(self, dtype, value):
        self.__init_handle_by_constructor__(
            _make.Cast, dtype, value)


@register_node
class Add(BinaryOpExpr):
    """Add node; ``x + y`` on expressions produces one.

    Parameters
    ----------
    a : Expr
        Left-hand operand.

    b : Expr
        Right-hand operand.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.Add, a, b)


@register_node
class Sub(BinaryOpExpr):
    """Sub node.

    Parameters
    ----------
    a : Expr
        Left-hand operand.

    b : Expr
        Right-hand operand.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.Sub, a, b)


@register_node
class Mul(BinaryOpExpr):
    """Mul node.

    Parameters
    ----------
    a : Expr
        Left-hand operand.

    b : Expr
        Right-hand operand.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.Mul, a, b)


@register_node
class Div(BinaryOpExpr):
    """Div node.

    Parameters
    ----------
    a : Expr
        Left-hand operand (dividend).

    b : Expr
        Right-hand operand (divisor).
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.Div, a, b)


@register_node
class Mod(BinaryOpExpr):
    """Mod node.

    Parameters
    ----------
    a : Expr
        Left-hand operand.

    b : Expr
        Right-hand operand.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.Mod, a, b)


@register_node
class FloorDiv(BinaryOpExpr):
    """FloorDiv node.

    Parameters
    ----------
    a : Expr
        Left-hand operand (dividend).

    b : Expr
        Right-hand operand (divisor).
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.FloorDiv, a, b)


@register_node
class FloorMod(BinaryOpExpr):
    """FloorMod node.

    Parameters
    ----------
    a : Expr
        Left-hand operand.

    b : Expr
        Right-hand operand.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.FloorMod, a, b)


@register_node
class Min(BinaryOpExpr):
    """Min node.

    Parameters
    ----------
    a : Expr
        First operand.

    b : Expr
        Second operand.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.Min, a, b)


@register_node
class Max(BinaryOpExpr):
    """Max node.

    Parameters
    ----------
    a : Expr
        First operand.

    b : Expr
        Second operand.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.Max, a, b)


@register_node
class EQ(CmpExpr):
    """EQ (equality comparison) node.

    Parameters
    ----------
    a : Expr
        Left-hand side of the comparison.

    b : Expr
        Right-hand side of the comparison.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.EQ, a, b)


@register_node
class NE(CmpExpr):
    """NE (inequality comparison) node.

    Parameters
    ----------
    a : Expr
        Left-hand side of the comparison.

    b : Expr
        Right-hand side of the comparison.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.NE, a, b)


@register_node
class LT(CmpExpr):
    """LT (less-than comparison) node.

    Parameters
    ----------
    a : Expr
        Left-hand side of the comparison.

    b : Expr
        Right-hand side of the comparison.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.LT, a, b)


@register_node
class LE(CmpExpr):
    """LE (less-than-or-equal comparison) node.

    Parameters
    ----------
    a : Expr
        Left-hand side of the comparison.

    b : Expr
        Right-hand side of the comparison.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.LE, a, b)


@register_node
class GT(CmpExpr):
    """GT (greater-than comparison) node.

    Parameters
    ----------
    a : Expr
        Left-hand side of the comparison.

    b : Expr
        Right-hand side of the comparison.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.GT, a, b)


@register_node
class GE(CmpExpr):
    """GE (greater-than-or-equal comparison) node.

    Parameters
    ----------
    a : Expr
        Left-hand side of the comparison.

    b : Expr
        Right-hand side of the comparison.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.GE, a, b)


@register_node
class And(LogicalExpr):
    """Logical And node.

    Parameters
    ----------
    a : Expr
        Left-hand operand.

    b : Expr
        Right-hand operand.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.And, a, b)


@register_node
class Or(LogicalExpr):
    """Logical Or node.

    Parameters
    ----------
    a : Expr
        Left-hand operand.

    b : Expr
        Right-hand operand.
    """

    def __init__(self, a, b):
        self.__init_handle_by_constructor__(_make.Or, a, b)


@register_node
class Not(LogicalExpr):
    """Logical Not node.

    Parameters
    ----------
    a : Expr
        The operand to negate.
    """

    def __init__(self, a):
        self.__init_handle_by_constructor__(_make.Not, a)


@register_node
class Select(Expr):
    """Select node: picks ``true_value`` or ``false_value`` by ``condition``.

    Note
    ----
    Select may evaluate both ``true_value`` and ``false_value``. When only
    the branch that is actually taken should be evaluated, use
    :any:`tvm.if_then_else` instead.

    Parameters
    ----------
    condition : Expr
        Condition deciding which value is selected.

    true_value : Expr
        Result when the condition holds.

    false_value : Expr
        Result when the condition does not hold.
    """

    def __init__(self, condition, true_value, false_value):
        ctor_args = (condition, true_value, false_value)
        self.__init_handle_by_constructor__(_make.Select, *ctor_args)


@register_node
class Load(Expr):
    """Load node.

    Parameters
    ----------
    dtype : str
        Data type of the loaded value.

    buffer_var : Var
        Variable naming the buffer being read.

    index : Expr
        Position read from the buffer.

    predicate : Expr
        Predicate attached to the load.
    """

    def __init__(self, dtype, buffer_var, index, predicate):
        ctor_args = (dtype, buffer_var, index, predicate)
        self.__init_handle_by_constructor__(_make.Load, *ctor_args)


@register_node
class Ramp(Expr):
    """Ramp node.

    Parameters
    ----------
    base : Expr
        The base expression.

    stride : Expr
        The stride of the ramp.

    lanes : int
        The lanes of the expression.
    """
    def __init__(self, base, stride, lanes):
        self.__init_handle_by_constructor__(
            _make.Ramp, base, stride, lanes)


@register_node
class Broadcast(Expr):
    """Broadcast node.

    Parameters
    ----------
    value : Expr
        Expression to broadcast.

    lanes : int
        Number of lanes in the result.
    """

    def __init__(self, value, lanes):
        self.__init_handle_by_constructor__(_make.Broadcast, value, lanes)


@register_node
class Shuffle(Expr):
    """Shuffle node.

    Parameters
    ----------
    vectors : Array of Expr
        Input vectors.

    indices : Array of indices
        Shuffle indices.
    """

    def __init__(self, vectors, indices):
        self.__init_handle_by_constructor__(_make.Shuffle, vectors, indices)


@register_node
class Call(Expr):
    """Call node.

    Parameters
    ----------
    dtype : str
        Data type returned by the call.

    name : str
        Name of the callee.

    args : list of Expr
        Arguments passed to the call.

    call_type : int
        Kind of call; one of the class-level codes below.

    func : Operation, optional
        Operation if call_type is Halide

    value_index : int
        Index of the output value.
    """

    # Codes accepted for ``call_type`` (e.g. ExprOp.__invert__ uses
    # Call.PureIntrinsic).
    Extern = 0
    ExternCPlusPlus = 1
    PureExtern = 2
    Halide = 3
    Intrinsic = 4
    PureIntrinsic = 5

    def __init__(self, dtype, name, args, call_type, func, value_index):
        ctor_args = (dtype, name, args, call_type, func, value_index)
        self.__init_handle_by_constructor__(_make.Call, *ctor_args)


@register_node
class Let(Expr):
    """Let node.

    Parameters
    ----------
    var : Var
        Variable introduced by the binding.

    value : Expr
        Value to be bound to ``var``.

    body : Expr
        Expression evaluated with the binding in place.
    """

    def __init__(self, var, value, body):
        ctor_args = (var, value, body)
        self.__init_handle_by_constructor__(_make.Let, *ctor_args)