# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Statement AST Node in TVM.

Each statement node has subfields that can be visited from the Python side.

.. code-block:: python

    x = tvm.tir.Var("n", "int32")
    a = tvm.tir.Var("array", "handle")
    st = tvm.tir.stmt.Store(a, x + 1, 1)
    assert isinstance(st, tvm.tir.stmt.Store)
    assert(st.buffer_var == a)
"""
import tvm._ffi

from tvm.runtime import Object
from . import _ffi_api


class Stmt(Object):
    """Common base class from which every statement node in this module derives."""


@tvm._ffi.register_object
class LetStmt(Stmt):
    """LetStmt node, binding a variable to a value around a body statement.

    Parameters
    ----------
    var : Var
        The variable being bound.

    value : PrimExpr
        The value to bind to ``var``.

    body : Stmt
        The body statement.
    """
    def __init__(self, var, value, body):
        ctor_args = (var, value, body)
        self.__init_handle_by_constructor__(_ffi_api.LetStmt, *ctor_args)


@tvm._ffi.register_object
class AssertStmt(Stmt):
    """AssertStmt node, pairing a condition and error message with a body.

    Parameters
    ----------
    condition : PrimExpr
        The condition being asserted.

    message : PrimExpr
        The error message associated with the assertion.

    body : Stmt
        The body statement.
    """
    def __init__(self, condition, message, body):
        ctor_args = (condition, message, body)
        self.__init_handle_by_constructor__(_ffi_api.AssertStmt, *ctor_args)


@tvm._ffi.register_object
class ProducerConsumer(Stmt):
    """ProducerConsumer node, tagging a body with an operation and a role.

    Parameters
    ----------
    func : Operation
        The operation this node refers to.

    is_producer : bool
        Whether this node is the producer.

    body : Stmt
        The body statement.
    """
    def __init__(self, func, is_producer, body):
        ctor_args = (func, is_producer, body)
        self.__init_handle_by_constructor__(_ffi_api.ProducerConsumer, *ctor_args)


@tvm._ffi.register_object
class For(Stmt):
    """For node, describing a loop over ``loop_var``.

    Parameters
    ----------
    loop_var : Var
        The loop variable.

    min_val : PrimExpr
        The beginning value of the loop.

    extent : PrimExpr
        The length of the loop.

    for_type : int
        The for type.

    device_api : int
        The device api type.

    body : Stmt
        The loop body statement.
    """
    # Integer loop-kind constants (presumably the values accepted as ``for_type``).
    Serial = 0
    Parallel = 1
    Vectorized = 2
    Unrolled = 3

    def __init__(self, loop_var, min_val, extent, for_type, device_api, body):
        ctor_args = (loop_var, min_val, extent, for_type, device_api, body)
        self.__init_handle_by_constructor__(_ffi_api.For, *ctor_args)


@tvm._ffi.register_object
class Store(Stmt):
    """Store node, writing a value into a buffer at an index.

    Parameters
    ----------
    buffer_var : Var
        The buffer variable.

    value : PrimExpr
        The value to store.

    index : PrimExpr
        The index in the store expression.

    predicate : PrimExpr, optional
        The store predicate. When omitted (``None``), the FFI constructor
        is called without a predicate argument.
    """
    def __init__(self, buffer_var, value, index, predicate=None):
        ctor_args = (buffer_var, value, index)
        if predicate is not None:
            ctor_args += (predicate,)
        self.__init_handle_by_constructor__(_ffi_api.Store, *ctor_args)


@tvm._ffi.register_object
class BufferStore(Stmt):
    """Buffer store node, writing a value into a buffer at multi-dimensional indices.

    Parameters
    ----------
    buffer : Buffer
        The buffer being written to.

    value : PrimExpr
        The value to be stored.

    indices : List[PrimExpr]
        The indices of the location to store into.
    """
    def __init__(self, buffer, value, indices):
        ctor_args = (buffer, value, indices)
        self.__init_handle_by_constructor__(_ffi_api.BufferStore, *ctor_args)


@tvm._ffi.register_object
class Provide(Stmt):
    """Provide node, supplying a value for one output of an operation.

    Parameters
    ----------
    func : Operation
        The operation to create the function.

    value_index : int
        The output value index.

    value : PrimExpr
        The value to be stored.

    args : list of Expr
        The index arguments of the Provide.
    """
    def __init__(self, func, value_index, value, args):
        ctor_args = (func, value_index, value, args)
        self.__init_handle_by_constructor__(_ffi_api.Provide, *ctor_args)


@tvm._ffi.register_object
class Allocate(Stmt):
    """Allocate node, declaring a buffer of a given dtype and extents.

    Parameters
    ----------
    buffer_var : Var
        The buffer variable.

    dtype : str
        The data type of the buffer.

    extents : list of Expr
        The extents of the allocation.

    condition : PrimExpr
        The condition.

    body : Stmt
        The body statement.
    """
    def __init__(self, buffer_var, dtype, extents, condition, body):
        ctor_args = (buffer_var, dtype, extents, condition, body)
        self.__init_handle_by_constructor__(_ffi_api.Allocate, *ctor_args)


@tvm._ffi.register_object
class AttrStmt(Stmt):
    """AttrStmt node, annotating a node with a keyed attribute around a body.

    Parameters
    ----------
    node : Node
        The node to annotate with the attribute.

    attr_key : str
        The attribute type key.

    value : PrimExpr
        The value of the attribute.

    body : Stmt
        The body statement.
    """
    def __init__(self, node, attr_key, value, body):
        ctor_args = (node, attr_key, value, body)
        self.__init_handle_by_constructor__(_ffi_api.AttrStmt, *ctor_args)


@tvm._ffi.register_object
class Free(Stmt):
    """Free node, referring to the buffer variable being freed.

    Parameters
    ----------
    buffer_var : Var
        The buffer variable.
    """
    def __init__(self, buffer_var):
        self.__init_handle_by_constructor__(_ffi_api.Free, buffer_var)


@tvm._ffi.register_object
class Realize(Stmt):
    """Realize node, describing the region of an operation output to realize.

    Parameters
    ----------
    func : Operation
        The operation to create the function.

    value_index : int
        The output value index.

    dtype : str
        The data type of the operation.

    bounds : list of range
        The bounds of the realize.

    condition : PrimExpr
        The realize condition.

    body : Stmt
        The realize body.
    """
    def __init__(self, func, value_index, dtype, bounds, condition, body):
        ctor_args = (func, value_index, dtype, bounds, condition, body)
        self.__init_handle_by_constructor__(_ffi_api.Realize, *ctor_args)


@tvm._ffi.register_object
class SeqStmt(Stmt):
    """Sequence of statements, supporting ``len()`` and integer indexing.

    Parameters
    ----------
    seq : List[Stmt]
        The statements that make up the sequence.
    """
    def __init__(self, seq):
        self.__init_handle_by_constructor__(_ffi_api.SeqStmt, seq)

    def __getitem__(self, i):
        """Return element ``i`` of the underlying ``seq`` field."""
        stmts = self.seq
        return stmts[i]

    def __len__(self):
        """Return the number of statements in the ``seq`` field."""
        stmts = self.seq
        return len(stmts)


@tvm._ffi.register_object
class IfThenElse(Stmt):
    """IfThenElse node, choosing between two statements on a condition.

    Parameters
    ----------
    condition : PrimExpr
        The condition expression.

    then_case : Stmt
        The statement to execute if the condition is true.

    else_case : Stmt
        The statement to execute if the condition is false.
    """
    def __init__(self, condition, then_case, else_case):
        ctor_args = (condition, then_case, else_case)
        self.__init_handle_by_constructor__(_ffi_api.IfThenElse, *ctor_args)


@tvm._ffi.register_object
class Evaluate(Stmt):
    """Evaluate node, wrapping an expression so it can appear as a statement.

    Parameters
    ----------
    value : PrimExpr
        The expression to be evaluated.
    """
    def __init__(self, value):
        self.__init_handle_by_constructor__(_ffi_api.Evaluate, value)


@tvm._ffi.register_object
class Prefetch(Stmt):
    """Prefetch node, naming an operation output and the bounds to prefetch.

    Parameters
    ----------
    func : Operation
        The operation to create the function.

    value_index : int
        The output value index.

    dtype : str
        The data type to be prefetched.

    bounds : list of Range
        The bounds to be prefetched.
    """
    def __init__(self, func, value_index, dtype, bounds):
        ctor_args = (func, value_index, dtype, bounds)
        self.__init_handle_by_constructor__(_ffi_api.Prefetch, *ctor_args)


@tvm._ffi.register_object
class LoweredFunc(Object):
    """Python-side handle for a LoweredFunc object in TVM."""
    # Integer function-kind constants; their exact use is defined outside this module.
    MixedFunc = 0
    HostFunc = 1
    DeviceFunc = 2


def stmt_seq(*args):
    """Combine the given items into a single statement.

    Parameters
    ----------
    args : Stmt or expression
        Items to combine. Any item that is not already a ``Stmt`` is
        wrapped in an ``Evaluate`` node.

    Returns
    -------
    stmt : Stmt
        The sole item when exactly one is given; otherwise a ``SeqStmt``
        holding all items in order.
    """
    stmts = [arg if isinstance(arg, Stmt) else Evaluate(arg) for arg in args]
    # A single statement is returned as-is rather than wrapped in a SeqStmt.
    return stmts[0] if len(stmts) == 1 else SeqStmt(stmts)


def stmt_list(stmt):
    """Flatten a statement into a list of its leaf statements.

    Nested ``SeqStmt`` nodes are expanded in order, and ``ProducerConsumer``
    wrappers are replaced by their ``body``. Any other statement is kept
    as a single element.

    Parameters
    ----------
    stmt : Stmt
        The statement to unpack.

    Returns
    -------
    stmt_list : list of Stmt
         The unpacked list of statements.
    """
    flat = []
    # Explicit work stack. Children are pushed in reverse so that popping
    # from the end visits them in their original order.
    pending = [stmt]
    while pending:
        node = pending.pop()
        if isinstance(node, SeqStmt):
            pending.extend(reversed(list(node)))
        elif isinstance(node, ProducerConsumer):
            pending.append(node.body)
        else:
            flat.append(node)
    return flat