# pylint: disable=unused-variable
"""Definition of task function.

Task can be constructed from tuple of func, args, and kwargs.
func is a state-less function, or a string that
registers the standard task.
"""

import numpy as np

from ... import tensor, expr, container, target as _target

from ..util import get_const_int, get_const_tuple, get_func_name
from .dispatcher import DispatchContext, ApplyConfig, dispatcher
from .space import ConfigSpace

def _raise_error(*args, **kwargs):  # pylint: disable=unused-argument
    raise RuntimeError("The function of this task is not found. Possibly the function "
                       "of this task is registered in another python file "
                       "which is not imported in this run")

class Task(object):
    """A Tunable Task

    Parameters
    ----------
    name: str
        The name of the task.
    args: Tuple
        Positional arguments of func.
    """
    def __init__(self, name, args):
        self.name = name
        self.args = args
        self.kwargs = {}  # currently unused

        # init null config space
        self.config_space = None
        self.func = TASK_TABLE.get(name, _raise_error)

        # auxiliary info, available after `create` is called
        self.workload = None
        self.flop = None
        self.target = None
        self.target_host = None

    def instantiate(self, config):
        """Instantiate this task function (template) with a config.
        Returns corresponding schedule.

        Parameters
        ----------
        config: template.ConfigEntity
            parameter config for this template

        Returns
        -------
        sch: tvm.schedule.Schedule
            The tvm schedule
        arg_bufs: Array of tvm.tensor.Tensor
            The input/output buffers
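
        Examples
        --------
        A minimal sketch, assuming ``tsk`` was built with :func:`create`
        and index 0 is an arbitrary point in the search space:

        .. code-block:: python

            config = tsk.config_space.get(0)  # pick one entry of the space
            sch, arg_bufs = tsk.instantiate(config)
            func = tvm.build(sch, arg_bufs, target=tsk.target)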
        """
        config.flop = 0
        with ApplyConfig(config):
            sch, arg_bufs = self.func(*self.args, **self.kwargs)
        if not self.flop:
            # a template can report its flop count via cfg.add_flop;
            # otherwise, fall back to a static estimation from the schedule
            config.flop = config.flop or compute_flop(sch)
            self.flop = config.flop
        return sch, arg_bufs

    def __getstate__(self):
        # custom pickle implementation is required for
        # some unpicklable local task functions.
        # So we only pickle the name of the function
        # and restore the function by name when unpickling.
        return {
            "name": self.name,
            "args": self.args,
            "kwargs": self.kwargs,
            "config_space": self.config_space,
            "workload": self.workload,
            "flop": self.flop,
            "target": self.target,
            "target_host": self.target_host
        }

    def __setstate__(self, state):
        self.name = state["name"]
        self.args = state["args"]
        self.kwargs = state["kwargs"]
        self.config_space = state["config_space"]
        self.func = TASK_TABLE.get(state["name"], _raise_error)
        self.workload = state["workload"]
        self.flop = state["flop"]
        self.target = state["target"]
        self.target_host = state["target_host"]

    def __repr__(self):
        return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
            self.name, self.args, self.kwargs, self.workload
        )

# global table that maps a task name to its task function
TASK_TABLE = {}

def register(name, func=None, override=False):
    """Register a task function.

    Parameters
    ----------
    name : str
        The name to identify the task.
    func : callable
        The function to be registered.
    override : bool
        Whether to override an existing registration.

    Returns
    -------
    func: callable
        The registered function
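
    Examples
    --------
    It can be used directly or as a decorator (a sketch; the task name
    and body here are hypothetical):

    .. code-block:: python

        @register("hypothetical/matmul")
        def matmul_task(N, L, M, dtype):
            # build a schedule and return (schedule, arg_bufs)
            ...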
    """
    def _do_reg(myf):
        if name in TASK_TABLE and not override:
            raise ValueError(
                "Key %s is already registered" % name)
        TASK_TABLE[name] = myf
        return myf
    if func:
        return _do_reg(func)
    return _do_reg

def create(func_name, args, target, target_host=None, template_key=None):
    """Create a tuning task and initialize its search space

    Parameters
    ----------
    func_name : str or callable
        The task function
    args : List
        Positional arguments
    target : Target
        The compilation target
    target_host: Target, optional
        The compilation target for host side
    template_key: str, optional
        The key used to select among multiple templates registered under
        the same task name (defaults to "")

    Returns
    -------
    tsk: Task
        a task object
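
    Examples
    --------
    A minimal sketch, assuming ``matmul`` is the template defined in the
    example of :func:`template`:

    .. code-block:: python

        tsk = create(matmul, args=(N, L, M, 'float32'), target='llvm')
        print(tsk.config_space)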
    """
    if callable(func_name):
        # register this function if it is not registered before
        func = func_name
        func_name = get_func_name(func)
        if func_name in TASK_TABLE:
            assert func == TASK_TABLE[func_name], "Found a name conflict in task registration. " \
                                                  "Consider choosing another name for this task"
        else:
            register(func_name, func=func)

    func = TASK_TABLE[func_name]
    ret = Task(func_name, args)

    if isinstance(target, str):
        target = _target.create(target)

    # init config space
    ret.config_space = ConfigSpace()
    ret.config_space.template_key = template_key or ""

    ctx = ApplyConfig(ret.config_space)
    with ctx:
        with target:
            sch, _ = func(*args)
            ret.config_space.code_hash = getattr(sch, 'code_hash', None)

    ret.workload = ctx.workload
    ret.flop = ret.config_space.flop or compute_flop(sch)
    ret.target = target
    ret.target_host = target_host

    return ret

def args_to_workload(x):
    """Convert argument list to hashable workload tuple.
    This function will convert list to tuple, tvm node to python value and
    flatten tvm.tensor.Tensor to a tuple

    Parameters
    ----------
    x: primitive hashable types or tensor.Tensor
        The original value

    Returns
    -------
    ret: hashable
        The hashable value
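
    Examples
    --------
    A sketch of the conversion (the shape here is arbitrary):

    .. code-block:: python

        A = tvm.placeholder((128, 64), name='A', dtype='float32')
        args_to_workload([A, 16, 'float32'])
        # -> ((128, 64, 'float32'), 16, 'float32')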
    """
    if isinstance(x, tensor.Tensor):
        return get_const_tuple(x.shape) + (x.dtype, )
    elif isinstance(x, (tuple, list, container.Array)):
        return tuple([args_to_workload(a) for a in x])
    elif isinstance(x, (str, int, float)):
        return x
    elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
        return x.value
    elif x is None:
        return None
    else:
        raise RuntimeError('Type "%s" is not supported in arguments. Consider using '
                           'primitive types only' % type(x))

def template(func):
    """
    Decorate a function as a tunable schedule template

    Parameters
    ----------
    func: callable
        A callable template function.
        Its arguments should be hashable values.
        Its return value should be a tuple of (Schedule, Array of Tensor)

    Returns
    -------
    func: callable
        The decorated function

    Examples
    --------
    The following code is a tunable template for a blocked matrix multiplication

    .. code-block:: python

        @autotvm.template
        def matmul(N, L, M, dtype):
            A = tvm.placeholder((N, L), name='A', dtype=dtype)
            B = tvm.placeholder((L, M), name='B', dtype=dtype)

            k = tvm.reduce_axis((0, L), name='k')
            C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
            s = tvm.create_schedule(C.op)

            # schedule
            y, x = s[C].op.axis
            k = s[C].op.reduce_axis[0]

            ##### define space begin #####
            cfg = autotvm.get_config()
            cfg.define_split("tile_y", y, num_outputs=2)
            cfg.define_split("tile_x", x, num_outputs=2)
            ##### define space end #####

            # schedule according to config
            yo, yi = cfg["tile_y"].apply(s, C, y)
            xo, xi = cfg["tile_x"].apply(s, C, x)

            s[C].reorder(yo, xo, k, yi, xi)

            return s, [A, B, C]
    """
    # pylint: disable=unused-variable

    fname = get_func_name(func)

    @register(fname)
    @dispatcher
    def config_dispatcher(*args, **kwargs):
        assert not kwargs, "kwargs are not supported in template function calls"
        return (fname, ) + args_to_workload(args)

    @config_dispatcher.register("")
    def template_call(cfg, *args, **kwargs):
        assert not kwargs, "kwargs are not supported in template function calls"
        with ApplyConfig(cfg):
            return func(*args, **kwargs)

    config_dispatcher.func_name = fname
    return config_dispatcher

def get_config():
    """Get current config object

    Returns
    -------
    cfg: ConfigSpace or ConfigEntity
        The current config
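
    Examples
    --------
    Typically called inside a template function to fetch the config to
    tune (a minimal sketch; the knob name is hypothetical):

    .. code-block:: python

        cfg = get_config()
        cfg.define_knob("unroll", [0, 1])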
    """
    return DispatchContext.current.query(None, None)

class FlopCalculationError(RuntimeError):
    """Error happens when estimating FLOP for a compute op"""
    pass

def compute_flop(sch):
    """Calculate number of FLOP (floating number operations) of the compute ops in a schedule

    Parameters
    ----------
    sch: tvm.schedule.Schedule
        schedule

    Returns
    -------
    flop: int
        number of FLOP in this schedule
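
    Examples
    --------
    For the matmul example shown in :func:`template`, which reduces over an
    axis of length ``L``, the estimate counts one multiply and one add per
    reduction step (a sketch, assuming ``s`` is the schedule built there):

    .. code-block:: python

        compute_flop(s)  # N * M * (2 * L) for the matmul example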
    """
    def _prod_length(axes):
        """compute product of the lengths of a list of axes"""
        try:
            num_iter = int(np.prod([get_const_int(axis.dom.extent) for axis in axes]))
        except ValueError:
            raise FlopCalculationError("The length of axis is not constant. ")
        return num_iter

    def _count_flop(exp):
        """compute flop for a single expression"""
        if isinstance(exp, expr.Reduce):
            num_iter = _prod_length(exp.axis)
            combiner = exp.combiner.result
            source = exp.source
            if len(combiner) != 1:
                raise FlopCalculationError("Found multiple outputs in the combiner of a reduce op")
            if len(source) != 1:
                raise FlopCalculationError("Found multiple outputs in the source of a reduce op")
            return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
        elif isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
            return 0
        elif isinstance(exp, expr.Cast):
            return _count_flop(exp.value)
        elif isinstance(exp, expr.Var):
            return 0
        elif isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
                              expr.Max, expr.Min,
                              expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
                              expr.And, expr.Or, expr.Not)):
            base = 1 if "float" in exp.a.dtype else 0

            if isinstance(exp, expr.Not):  # unary
                return base + _count_flop(exp.a)

            return base + _count_flop(exp.a) + _count_flop(exp.b)
        elif isinstance(exp, expr.Select):
            return _count_flop(exp.condition) + max(_count_flop(exp.true_value),
                                                    _count_flop(exp.false_value))
        elif isinstance(exp, expr.Call):
            return sum([_count_flop(x) for x in exp.args])
        else:
            raise FlopCalculationError("Found unsupported operator in the compute expr")

    def traverse(ops):
        """accumulate flops"""
        ret = 0
        for op in ops:
            if isinstance(op, tensor.ComputeOp):
                num_element = _prod_length(op.axis)

                body = op.body
                if len(body) != 1:
                    raise FlopCalculationError("Found multiple output in the compute")
                exp = body[0]

                ret += num_element * _count_flop(exp)
                ret += traverse([t.op for t in op.input_tensors])

            elif isinstance(op, tensor.PlaceholderOp):
                pass
            else:
                raise FlopCalculationError("Only support tvm.compute currently. "
                                           "Other ops like tvm.scan is not supported")
        return ret

    try:
        ret = traverse(sch.outputs)
    except FlopCalculationError as exc:
        raise RuntimeError("FLOP estimator fails for this operator. Error msg: "
                           + str(exc) + ". Please use `cfg.add_flop` to manually set "
                                        "FLOP for this operator")

    if ret == 0:
        raise RuntimeError("Cannot find float number operation in this operator. "
                           "Please use `cfg.add_flop` to manually set "
                           "FLOP for this operator")
    return ret