expr.py 4.54 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
"""Expression AST Node in TVM.

Users do not need to deal with expression AST nodes directly,
but they can be helpful for developers doing quick prototyping.
Although they are not displayed in the documentation or this Python file,
each expression node has subfields that can be accessed from the Python side.

For example, you can use addexp.a to get the left operand of an Add node.

.. code-block:: python

  x = tvm.var("n")
  y = x + 2
  assert(isinstance(y, tvm.expr.Add))
  assert(y.a == x)
"""
17
# pylint: disable=missing-docstring
tqchen committed
18
from __future__ import absolute_import as _abs
19
from ._ffi.node import NodeBase, register_node
tqchen committed
20
from . import make as _make
tqchen committed
21

22
class ExprOp(object):
    """Operator overloads shared by expression nodes.

    Arithmetic and comparison operators do not compute a Python value; each
    one forwards its operands to a constructor in the ``_make`` module and
    returns whatever that constructor returns.
    """

    def __add__(self, other):
        return _make.Add(self, other)

    def __radd__(self, other):
        # Delegates to __add__, so the operands are passed as (self, other).
        return self.__add__(other)

    def __sub__(self, other):
        return _make.Sub(self, other)

    def __rsub__(self, other):
        return _make.Sub(other, self)

    def __mul__(self, other):
        return _make.Mul(self, other)

    def __rmul__(self, other):
        return _make.Mul(other, self)

    def __div__(self, other):
        return _make.Div(self, other)

    def __rdiv__(self, other):
        return _make.Div(other, self)

    def __truediv__(self, other):
        # True division is routed to the same construction as __div__.
        return self.__div__(other)

    def __rtruediv__(self, other):
        return self.__rdiv__(other)

    def __floordiv__(self, other):
        # Floor division is also routed to __div__.
        return self.__div__(other)

    def __rfloordiv__(self, other):
        return self.__rdiv__(other)

    def __mod__(self, other):
        return _make.Mod(self, other)

    def __rmod__(self, other):
        # Reflected form, mirroring __rsub__ / __rmul__ / __rdiv__, so that
        # ``2 % expr`` builds a node instead of raising TypeError.
        return _make.Mod(other, self)

    def __neg__(self):
        # Negation is expressed as multiplication by -1.
        return self.__mul__(-1)

    def __lt__(self, other):
        return _make.LT(self, other)

    def __le__(self, other):
        return _make.LE(self, other)

    def __eq__(self, other):
        # "==" returns the result of equal(), not a Python bool.
        return self.equal(other)

    def __ne__(self, other):
        return _make.NE(self, other)

    def __gt__(self, other):
        return _make.GT(self, other)

    def __ge__(self, other):
        return _make.GE(self, other)

    def __nonzero__(self):
        """Reject truth-value testing of an expression.

        Raises
        ------
        ValueError
            Always.
        """
        raise ValueError("Cannot use and / or / not operator to Expr, hint: " +
                         "use tvm.all / tvm.any instead")

    def __bool__(self):
        # Python 3 spelling of __nonzero__; same error.
        return self.__nonzero__()

    def equal(self, other):
        """Build an equal check expression with other expr.

        Parameters
        ----------
        other : Expr
            The other expression

        Returns
        -------
        ret : Expr
            The equality expression.
        """
        return _make.EQ(self, other)

    def astype(self, dtype):
        """Cast the expression to other type.

        Parameters
        ----------
        dtype : str
            The type of new expression

        Returns
        -------
        expr : Expr
            Expression with new type
        """
        return _make.static_cast(dtype, self)
ziheng committed
119

tqchen committed
120

121
class Expr(NodeBase, ExprOp):
    """Base class of all tvm Expressions"""

tqchen committed
125 126 127 128 129 130 131 132 133 134 135 136
class ConstExpr(Expr):
    """Shared Python base of the ``*Imm`` node classes in this module."""

class BinaryOpExpr(Expr):
    """Shared Python base of Add, Sub, Mul, Div, Mod, Min and Max."""

class CmpExpr(Expr):
    """Shared Python base of EQ, NE, LT, LE, GT and GE."""

class LogicalExpr(Expr):
    """Shared Python base of And, Or and Not."""

tqchen committed
137 138
@register_node("Variable")
class Var(Expr):
    """Symbolic variable."""
tqchen committed
141

tqchen committed
142 143
@register_node
class Reduce(Expr):
    """``Reduce`` node class; declares no Python-side members of its own."""

tqchen committed
146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
@register_node
class FloatImm(ConstExpr):
    """``FloatImm`` node class; a ``ConstExpr`` with no extra Python members."""

@register_node
class IntImm(ConstExpr):
    """``IntImm`` node class; a ``ConstExpr`` with no extra Python members."""

@register_node
class UIntImm(ConstExpr):
    """``UIntImm`` node class; a ``ConstExpr`` with no extra Python members."""

@register_node
class StringImm(ConstExpr):
    """``StringImm`` node class; a ``ConstExpr`` with no extra Python members."""

@register_node
class Cast(Expr):
    """``Cast`` node class; declares no Python-side members of its own."""

@register_node
class Add(BinaryOpExpr):
    """``Add`` node class; per the module docstring, ``x + 2`` is an ``Add`` whose ``a`` is ``x``."""

@register_node
class Sub(BinaryOpExpr):
    """``Sub`` node class; a ``BinaryOpExpr`` with no extra Python members."""

@register_node
class Mul(BinaryOpExpr):
    """``Mul`` node class; a ``BinaryOpExpr`` with no extra Python members."""

@register_node
class Div(BinaryOpExpr):
    """``Div`` node class; a ``BinaryOpExpr`` with no extra Python members."""

@register_node
class Mod(BinaryOpExpr):
    """``Mod`` node class; a ``BinaryOpExpr`` with no extra Python members."""

@register_node
class Min(BinaryOpExpr):
    """``Min`` node class; a ``BinaryOpExpr`` with no extra Python members."""

@register_node
class Max(BinaryOpExpr):
    """``Max`` node class; a ``BinaryOpExpr`` with no extra Python members."""

@register_node
class EQ(CmpExpr):
    """``EQ`` node class; a ``CmpExpr`` with no extra Python members."""

@register_node
class NE(CmpExpr):
    """``NE`` node class; a ``CmpExpr`` with no extra Python members."""

@register_node
class LT(CmpExpr):
    """``LT`` node class; a ``CmpExpr`` with no extra Python members."""

@register_node
class LE(CmpExpr):
    """``LE`` node class; a ``CmpExpr`` with no extra Python members."""

@register_node
class GT(CmpExpr):
    """``GT`` node class; a ``CmpExpr`` with no extra Python members."""

@register_node
class GE(CmpExpr):
    """``GE`` node class; a ``CmpExpr`` with no extra Python members."""

@register_node
class And(LogicalExpr):
    """``And`` node class; a ``LogicalExpr`` with no extra Python members."""

@register_node
class Or(LogicalExpr):
    """``Or`` node class; a ``LogicalExpr`` with no extra Python members."""

@register_node
class Not(LogicalExpr):
    """``Not`` node class; a ``LogicalExpr`` with no extra Python members."""

@register_node
class Select(Expr):
    """``Select`` node class; declares no Python-side members of its own."""

@register_node
class Load(Expr):
    """``Load`` node class; declares no Python-side members of its own."""

@register_node
class Ramp(Expr):
    """``Ramp`` node class; declares no Python-side members of its own."""
tqchen committed
241

tqchen committed
242 243
@register_node
class Broadcast(Expr):
    """``Broadcast`` node class; declares no Python-side members of its own."""
245

tqchen committed
246
@register_node
class Shuffle(Expr):
    """``Shuffle`` node class; declares no Python-side members of its own."""

@register_node
class Call(Expr):
    """``Call`` node class; also exposes integer constants as class attributes."""
    # NOTE(review): these look like call-kind tags that presumably mirror an
    # enumeration on the native side -- confirm before renumbering.
    Extern = 0
    ExternCPlusPlus = 1
    PureExtern = 2
    Halide = 3
    Intrinsic = 4
    PureIntrinsic = 5
258

259

tqchen committed
260 261
@register_node
class Let(Expr):
    """``Let`` node class; declares no Python-side members of its own."""