"""Expression AST Node in TVM.

Users do not need to deal with expression AST nodes directly,
but the nodes can be helpful for developers doing quick prototyping.
Although their fields are not listed in the documentation or in this Python file,
each expression node has subfields that can be accessed from the Python side.

For example, you can use addexp.a to get the left operand of an Add node.

.. code-block:: python

  x = tvm.var("n")
  y = x + 2
  assert(isinstance(y, tvm.expr.Add))
  assert(y.a == x)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from . import make as _make
from . import _api_internal

def _pure_intrin(dtype, name, args):
    """Build a ``Call`` node that invokes a pure intrinsic.

    Parameters
    ----------
    dtype : str
        The data type of the resulting expression.
    name : str
        The intrinsic name, e.g. ``"bitwise_and"``.
    args : list
        The call operands, in order.

    Returns
    -------
    ret : Expr
        The call expression.
    """
    # The trailing ``None, 0`` are exactly what the operator overloads have
    # always passed to _make.Call; their meaning is defined by the C++ Call
    # constructor (TODO: confirm they are the function handle / value index).
    return _make.Call(dtype, name, args, Call.PureIntrinsic, None, 0)


class ExprOp(object):
    """Mixin that overloads Python operators to build expression nodes.

    Every operator returns a new expression built through ``_make``; nothing
    is evaluated eagerly. The reflected variants (``__r*__``) allow an
    expression to appear on the right-hand side of a plain Python value,
    e.g. ``2 - x`` or ``3 % x``.
    """

    # -- arithmetic ---------------------------------------------------------

    def __add__(self, other):
        return _make.Add(self, other)

    def __radd__(self, other):
        # ``other + self`` is stored as Add(self, other); addition commutes,
        # so the operand swap does not change the value.
        return self.__add__(other)

    def __sub__(self, other):
        return _make.Sub(self, other)

    def __rsub__(self, other):
        return _make.Sub(other, self)

    def __mul__(self, other):
        return _make.Mul(self, other)

    def __rmul__(self, other):
        return _make.Mul(other, self)

    # ``__div__``/``__rdiv__`` are the Python 2 division hooks; Python 3 calls
    # ``__truediv__``. True and floor division both build the same Div node.
    def __div__(self, other):
        return _make.Div(self, other)

    def __rdiv__(self, other):
        return _make.Div(other, self)

    def __truediv__(self, other):
        return self.__div__(other)

    def __rtruediv__(self, other):
        return self.__rdiv__(other)

    def __floordiv__(self, other):
        return self.__div__(other)

    def __rfloordiv__(self, other):
        return self.__rdiv__(other)

    def __mod__(self, other):
        return _make.Mod(self, other)

    def __rmod__(self, other):
        # Without this hook ``3 % x`` raised TypeError even though ``3 - x``
        # and ``3 * x`` already worked.
        return _make.Mod(other, self)

    def __neg__(self):
        # Negation is encoded as multiplication by a -1 constant of the same
        # dtype as this expression.
        neg_one = _api_internal._const(-1, self.dtype)
        return self.__mul__(neg_one)

    # -- shifts and bitwise operators (lowered to pure intrinsic calls) -----
    # The result dtype is always taken from ``self``: in the reflected forms
    # ``other`` is a plain Python value that carries no dtype of its own.

    def __lshift__(self, other):
        return _pure_intrin(self.dtype, "shift_left", [self, other])

    def __rlshift__(self, other):
        return _pure_intrin(self.dtype, "shift_left", [other, self])

    def __rshift__(self, other):
        return _pure_intrin(self.dtype, "shift_right", [self, other])

    def __rrshift__(self, other):
        return _pure_intrin(self.dtype, "shift_right", [other, self])

    def __and__(self, other):
        return _pure_intrin(self.dtype, "bitwise_and", [self, other])

    def __rand__(self, other):
        return _pure_intrin(self.dtype, "bitwise_and", [other, self])

    def __or__(self, other):
        return _pure_intrin(self.dtype, "bitwise_or", [self, other])

    def __ror__(self, other):
        return _pure_intrin(self.dtype, "bitwise_or", [other, self])

    def __xor__(self, other):
        return _pure_intrin(self.dtype, "bitwise_xor", [self, other])

    def __rxor__(self, other):
        return _pure_intrin(self.dtype, "bitwise_xor", [other, self])

    def __invert__(self):
        return _pure_intrin(self.dtype, "bitwise_not", [self])

    # -- comparisons (each builds a comparison node, not a Python bool) -----

    def __lt__(self, other):
        return _make.LT(self, other)

    def __le__(self, other):
        return _make.LE(self, other)

    def __eq__(self, other):
        return self.equal(other)

    def __ne__(self, other):
        return _make.NE(self, other)

    def __gt__(self, other):
        return _make.GT(self, other)

    def __ge__(self, other):
        return _make.GE(self, other)

    # -- truth testing --------------------------------------------------------
    # A symbolic expression has no Python truth value, so ``and``/``or``/``not``
    # and ``if expr:`` are rejected with a hint towards tvm.all / tvm.any.

    def __nonzero__(self):
        raise ValueError("Cannot use and / or / not operator to Expr, hint: " +
                         "use tvm.all / tvm.any instead")

    def __bool__(self):
        # Python 3 name of the Python 2 ``__nonzero__`` hook.
        return self.__nonzero__()

    def equal(self, other):
        """Build an equal check expression with other expr.

        Parameters
        ----------
        other : Expr
            The other expression

        Returns
        -------
        ret : Expr
            The equality expression.
        """
        return _make.EQ(self, other)

    def astype(self, dtype):
        """Cast the expression to other type.

        Parameters
        ----------
        dtype : str
            The type of new expression

        Returns
        -------
        expr : Expr
            Expression with new type
        """
        return _make.static_cast(dtype, self)


class Expr(NodeBase, ExprOp):
    """Common base class of every TVM expression node.

    Combines the FFI node behaviour of ``NodeBase`` with the Python operator
    overloads provided by ``ExprOp``.
    """
    # NOTE(review): NodeBase precedes ExprOp in the MRO, so any dunder that
    # NodeBase defines itself (e.g. __eq__/__ne__/__hash__, if present) wins
    # over the ExprOp version on Expr instances — confirm this is intended.

class ConstExpr(Expr):
    """Shared Python base for constant expression nodes.

    Subclassed by IntImm, UIntImm, FloatImm and StringImm. It is not itself
    registered with ``register_node``.
    """

class BinaryOpExpr(Expr):
    """Shared Python base for binary arithmetic nodes.

    Subclassed by Add, Sub, Mul, Div, Mod, Min and Max. It is not itself
    registered with ``register_node``.
    """

class CmpExpr(Expr):
    """Shared Python base for comparison nodes.

    Subclassed by EQ, NE, LT, LE, GT and GE. It is not itself registered
    with ``register_node``.
    """

class LogicalExpr(Expr):
    """Shared Python base for logical nodes.

    Subclassed by And, Or and Not. It is not itself registered with
    ``register_node``.
    """

@register_node("Variable")
class Var(Expr):
    """Symbolic variable expression.

    Bound to the node type name ``"Variable"``, which differs from this
    Python class name; hence the explicit argument to ``register_node``.
    """

@register_node
class Reduce(Expr):
    """Reduction expression node."""

@register_node
class FloatImm(ConstExpr):
    """Floating-point immediate (constant) expression node."""

@register_node
class IntImm(ConstExpr):
    """Signed integer immediate (constant) expression node."""

@register_node
class UIntImm(ConstExpr):
    """Unsigned integer immediate (constant) expression node."""

@register_node
class StringImm(ConstExpr):
    """String immediate (constant) expression node."""

@register_node
class Cast(Expr):
    """Type-cast expression node."""

@register_node
class Add(BinaryOpExpr):
    """Addition node, produced by ``+`` on expressions."""

@register_node
class Sub(BinaryOpExpr):
    """Subtraction node, produced by ``-`` on expressions."""

@register_node
class Mul(BinaryOpExpr):
    """Multiplication node, produced by ``*`` on expressions."""

@register_node
class Div(BinaryOpExpr):
    """Division node, produced by both ``/`` and ``//`` on expressions."""

@register_node
class Mod(BinaryOpExpr):
    """Modulo node, produced by ``%`` on expressions."""

@register_node
class Min(BinaryOpExpr):
    """Binary minimum node."""

@register_node
class Max(BinaryOpExpr):
    """Binary maximum node."""

@register_node
class EQ(CmpExpr):
    """Equality comparison node, produced by ``ExprOp.equal``."""

@register_node
class NE(CmpExpr):
    """Inequality comparison node, produced by ``!=`` on expressions."""

@register_node
class LT(CmpExpr):
    """Less-than comparison node, produced by ``<`` on expressions."""

@register_node
class LE(CmpExpr):
    """Less-or-equal comparison node, produced by ``<=`` on expressions."""

@register_node
class GT(CmpExpr):
    """Greater-than comparison node, produced by ``>`` on expressions."""

@register_node
class GE(CmpExpr):
    """Greater-or-equal comparison node, produced by ``>=`` on expressions."""

@register_node
class And(LogicalExpr):
    """Logical conjunction node.

    Python's ``and`` keyword cannot build this node, because expressions
    raise ValueError when truth-tested.
    """

@register_node
class Or(LogicalExpr):
    """Logical disjunction node.

    Python's ``or`` keyword cannot build this node, because expressions
    raise ValueError when truth-tested.
    """

@register_node
class Not(LogicalExpr):
    """Logical negation node.

    Python's ``not`` keyword cannot build this node, because expressions
    raise ValueError when truth-tested.
    """

@register_node
class Select(Expr):
    """Conditional selection expression node."""

@register_node
class Load(Expr):
    """Load expression node (fields are defined by the C++ ``Load`` node)."""

@register_node
class Ramp(Expr):
    """Ramp expression node (fields are defined by the C++ ``Ramp`` node)."""

@register_node
class Broadcast(Expr):
    """Broadcast expression node (fields are defined by the C++ node)."""

@register_node
class Shuffle(Expr):
    """Shuffle expression node (fields are defined by the C++ node)."""

@register_node
class Call(Expr):
    """Function or intrinsic call expression node.

    The integer class attributes are call-type codes that are passed to
    ``_make.Call`` (for example, ``Call.PureIntrinsic`` is used when the
    shift and bitwise operators are built).
    """
    # Call-type codes, numbered 0..5 in declaration order.
    (Extern,
     ExternCPlusPlus,
     PureExtern,
     Halide,
     Intrinsic,
     PureIntrinsic) = range(6)


@register_node
class Let(Expr):
    """Let-binding expression node."""