# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Expression AST Node in TVM.

User do not need to deal with expression AST node directly.
But they can be helpful for developer to do quick proptyping.
While not displayed in the document and python file.
Each expression node have subfields that can be visited from python side.

For example, you can use addexp.a to get the left operand of an Add node.

.. code-block:: python

  x = tvm.var("n")
  y = x + 2
  assert(isinstance(y, tvm.expr.Add))
  assert(y.a == x)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, NodeGeneric, register_node
from . import make as _make
from . import generic as _generic
from . import _api_internal


class ExprOp(object):
    def __add__(self, other):
        return _generic.add(self, other)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        return _generic.subtract(self, other)

    def __rsub__(self, other):
        return _generic.subtract(other, self)

    def __mul__(self, other):
        return _generic.multiply(self, other)

    def __rmul__(self, other):
        return _generic.multiply(other, self)

    def __div__(self, other):
        return _generic.divide(self, other)

    def __rdiv__(self, other):
        return _generic.divide(other, self)

    def __truediv__(self, other):
        return self.__div__(other)

    def __rtruediv__(self, other):
        return self.__rdiv__(other)

    def __floordiv__(self, other):
        return self.__div__(other)

    def __rfloordiv__(self, other):
        return self.__rdiv__(other)

    def __mod__(self, other):
        return _make._OpMod(self, other)

    def __neg__(self):
        neg_one = _api_internal._const(-1, self.dtype)
        return self.__mul__(neg_one)

    def __lshift__(self, other):
        return _make.left_shift(self, other)

    def __rshift__(self, other):
        return _make.right_shift(self, other)

    def __and__(self, other):
        return _make.bitwise_and(self, other)

    def __or__(self, other):
        return _make.bitwise_or(self, other)

    def __xor__(self, other):
        return _make.bitwise_xor(self, other)

    def __invert__(self):
        return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)

    def __lt__(self, other):
        return _make._OpLT(self, other)

    def __le__(self, other):
        return _make._OpLE(self, other)

    def __eq__(self, other):
        return EqualOp(self, other)

    def __ne__(self, other):
        return NotEqualOp(self, other)

    def __gt__(self, other):
        return _make._OpGT(self, other)

    def __ge__(self, other):
        return _make._OpGE(self, other)

    def __nonzero__(self):
        raise ValueError("Cannot use and / or / not operator to Expr, hint: " +
                         "use tvm.all / tvm.any instead")

    def __bool__(self):
        return self.__nonzero__()

    def equal(self, other):
        """Build an equal check expression with other expr.

        Parameters
        ----------
        other : Expr
            The other expression

        Returns
        -------
        ret : Expr
            The equality expression.
        """
        return _make._OpEQ(self, other)

    def astype(self, dtype):
        """Cast the expression to other type.

        Parameters
        ----------
        dtype : str
            The type of new expression

        Returns
        -------
        expr : Expr
            Expression with new type
        """
        return _generic.cast(self, dtype)


class EqualOp(NodeGeneric, ExprOp):
    """Deferred equal operator.

    This is used to support sugar that a == b can either
    mean NodeBase.same_as or NodeBase.equal.

    Parameters
    ----------
    a : Expr
        Left operand.

    b : Expr
        Right operand.
    """
    # This class is not manipulated by C++. So use python's identity check function is sufficient
    same_as = object.__eq__

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __nonzero__(self):
        return self.a.same_as(self.b)

    def __bool__(self):
        return self.__nonzero__()

    def asnode(self):
        """Convert node."""
        return _make._OpEQ(self.a, self.b)


class NotEqualOp(NodeGeneric, ExprOp):
    """Deferred NE operator.

    This is used to support sugar that a != b can either
    mean not NodeBase.same_as or make.NE.

    Parameters
    ----------
    a : Expr
        Left operand.

    b : Expr
        Right operand.
    """
    # This class is not manipulated by C++. So use python's identity check function is sufficient
    same_as = object.__eq__

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __nonzero__(self):
        return not self.a.same_as(self.b)

    def __bool__(self):
        return self.__nonzero__()

    def asnode(self):
        """Convert node."""
        return _make._OpNE(self.a, self.b)


class Expr(ExprOp, NodeBase):
    """Base class of all tvm Expressions"""
    # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__
    # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
    __hash__ = NodeBase.__hash__


class ConstExpr(Expr):
    pass

class BinaryOpExpr(Expr):
    pass

class CmpExpr(Expr):
    pass

class LogicalExpr(Expr):
    pass

@register_node("Variable")
class Var(Expr):
    """Symbolic variable.

    Parameters
    ----------
    name : str
        The name

    dtype : int
        The data type
    """
    def __init__(self, name, dtype):
        self.__init_handle_by_constructor__(
            _api_internal._Var, name, dtype)


@register_node
class Reduce(Expr):
    """Reduce node.

    Parameters
    ----------
    combiner : CommReducer
        The combiner.

    src : list of Expr
        The source expression.

    rdom : list of IterVar
        The iteration domain

    condition : Expr
        The reduce condition.

    value_index : int
        The value index.
    """
    def __init__(self, combiner, src, rdom, condition, value_index):
        self.__init_handle_by_constructor__(
            _make.Reduce, combiner, src, rdom,
            condition, value_index)


@register_node
class FloatImm(ConstExpr):
    """Float constant.

    Parameters
    ----------
    dtype : str
        The data type

    value : float
        The constant value.
    """
    def __init__(self, dtype, value):
        self.__init_handle_by_constructor__(
            _make.FloatImm, dtype, value)

@register_node
class IntImm(ConstExpr):
    """Int constant.

    Parameters
    ----------
    dtype : str
        The data type

    value : int
        The constant value.
    """
    def __init__(self, dtype, value):
        self.__init_handle_by_constructor__(
            _make.IntImm, dtype, value)

    def __int__(self):
        return self.value


@register_node
class UIntImm(ConstExpr):
    """UInt constant.

    Parameters
    ----------
    dtype : str
        The data type

    value : int
        The constant value.
    """
    def __init__(self, dtype, value):
        self.__init_handle_by_constructor__(
            _make.UIntImm, dtype, value)


@register_node
class StringImm(ConstExpr):
    """String constant.

    Parameters
    ----------
    value : str
        The value of the function.
    """
    def __init__(self, value):
        self.__init_handle_by_constructor__(
            _make.StringImm, value)

    def __eq__(self, other):
        if isinstance(other, ConstExpr):
            return self.value == other.value
        return self.value == other

    def __ne__(self, other):
        if isinstance(other, ConstExpr):
            return self.value != other.value
        return self.value != other


@register_node
class Cast(Expr):
    """Cast expression.

    Parameters
    ----------
    dtype : str
        The data type

    value : Expr
        The value of the function.
    """
    def __init__(self, dtype, value):
        self.__init_handle_by_constructor__(
            _make.Cast, dtype, value)


@register_node
class Add(BinaryOpExpr):
    """Add node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.Add, a, b)


@register_node
class Sub(BinaryOpExpr):
    """Sub node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.Sub, a, b)


@register_node
class Mul(BinaryOpExpr):
    """Mul node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.Mul, a, b)


@register_node
class Div(BinaryOpExpr):
    """Div node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.Div, a, b)


@register_node
class Mod(BinaryOpExpr):
    """Mod node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.Mod, a, b)


@register_node
class FloorDiv(BinaryOpExpr):
    """FloorDiv node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.FloorDiv, a, b)


@register_node
class FloorMod(BinaryOpExpr):
    """FloorMod node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.FloorMod, a, b)


@register_node
class Min(BinaryOpExpr):
    """Min node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.Min, a, b)


@register_node
class Max(BinaryOpExpr):
    """Max node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.Max, a, b)


@register_node
class EQ(CmpExpr):
    """EQ node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.EQ, a, b)


@register_node
class NE(CmpExpr):
    """NE node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.NE, a, b)


@register_node
class LT(CmpExpr):
    """LT node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.LT, a, b)


@register_node
class LE(CmpExpr):
    """LE node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.LE, a, b)


@register_node
class GT(CmpExpr):
    """GT node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.GT, a, b)


@register_node
class GE(CmpExpr):
    """GE node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.GE, a, b)


@register_node
class And(LogicalExpr):
    """And node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.And, a, b)


@register_node
class Or(LogicalExpr):
    """Or node.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.
    """
    def __init__(self, a, b):
        self.__init_handle_by_constructor__(
            _make.Or, a, b)


@register_node
class Not(LogicalExpr):
    """Not node.

    Parameters
    ----------
    a : Expr
        The input value
    """
    def __init__(self, a):
        self.__init_handle_by_constructor__(
            _make.Not, a)


@register_node
class Select(Expr):
    """Select node.

    Note
    ----
    Select may compute both true_value and false_value.
    Use :any:`tvm.if_then_else` instead if you want to
    get a conditional expression that only evaluates
    the correct branch.

    Parameters
    ----------
    condition : Expr
        The condition expression.

    true_value : Expr
        The value to take when condition is true.

    false_value : Expr
        The value to take when condition is false.

    """
    def __init__(self, condition, true_value, false_value):
        self.__init_handle_by_constructor__(
            _make.Select, condition, true_value, false_value)


@register_node
class Load(Expr):
    """Load node.

    Parameters
    ----------
    dtype : str
        The data type.

    buffer_var : Var
        The buffer variable in the load expression.

    index : Expr
        The index in the load.

    predicate : Expr
        The load predicate.
    """
    def __init__(self, dtype, buffer_var, index, predicate):
        self.__init_handle_by_constructor__(
            _make.Load, dtype, buffer_var, index, predicate)


@register_node
class Ramp(Expr):
    """Ramp node.

    Parameters
    ----------
    base : Expr
        The base expression.

    stride : ramp stride
        The stride of the ramp.

    lanes : int
        The lanes of the expression.
    """
    def __init__(self, base, stride, lanes):
        self.__init_handle_by_constructor__(
            _make.Ramp, base, stride, lanes)


@register_node
class Broadcast(Expr):
    """Broadcast node.

    Parameters
    ----------
    value : Expr
        The value of the expression.

    lanes : int
        The lanes of the expression.
    """
    def __init__(self, value, lanes):
        self.__init_handle_by_constructor__(
            _make.Broadcast, value, lanes)


@register_node
class Shuffle(Expr):
    """Shuffle node.

    Parameters
    ----------
    vectors : Array of Expr
        The vectors

    indices : Array of indices
        The indices
    """
    def __init__(self, vectors, indices):
        self.__init_handle_by_constructor__(
            _make.Shuffle, vectors, indices)


@register_node
class Call(Expr):
    """Call node.

    Parameters
    ----------
    dtype : str
        The return data type

    name : str
        The name of the function

    args : list of Expr
        The input arguments to the call

    call_type : int
        The type of the call

    func : Operation, optional
        Operation if call_type is Halide

    value_index : int
        The output value index
    """
    Extern = 0
    ExternCPlusPlus = 1
    PureExtern = 2
    Halide = 3
    Intrinsic = 4
    PureIntrinsic = 5
    def __init__(self, dtype, name, args, call_type, func, value_index):
        self.__init_handle_by_constructor__(
            _make.Call, dtype, name, args, call_type, func, value_index)


@register_node
class Let(Expr):
    """Let node.

    Parameters
    ----------
    var : Var
        The variable in the binding.

    value : Expr
        The value in to be binded.

    body : Expr
        The body expression.
    """
    def __init__(self, var, value, body):
        self.__init_handle_by_constructor__(
            _make.Let, var, value, body)