# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Template dispatcher module.

A dispatcher is a function that can contain multiple behaviors.
Its specific behavior is controlled by a DispatchContext.

DispatchContext is used in two ways, usually via different implementations
of the DispatchContext base class.

- During search, we can use it to pass the current proposal from the tuner.
- During evaluation, we can use it to pick the best policy.
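
For example, during evaluation one typically applies tuned configurations
recorded in a log file. This is a minimal sketch; the log file name and the
`matmul_template` function are hypothetical:

.. code-block:: python

    # queries issued inside this scope return the best configs from the log;
    # workloads missing from the log fall through to the root FallbackContext
    with autotvm.apply_history_best("matmul_tuning.log"):
        with tvm.target.create("llvm"):
            s, arg_bufs = matmul_template(N, L, M, "float32")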
"""
# pylint: disable=invalid-name

from __future__ import absolute_import as _abs

import logging

import numpy as np
from decorator import decorate

from tvm import target as _target

from .space import FallbackConfigEntity

logger = logging.getLogger('autotvm')


class DispatchContext(object):
    """
    Base class of dispatch context.

    DispatchContext enables the target and workload
    specific dispatch mechanism for templates.
    """
    current = None

    def __init__(self):
        self._old_ctx = DispatchContext.current

    def query(self, target, workload):
        """
        Query the context to get the specific config for a template.
        If the result cannot be found in this context, this function will query it
        from the upper contexts.

        Parameters
        ----------
        target: Target
            The current target
        workload : Workload
            The current workload.

        Returns
        -------
        cfg : ConfigSpace
            The specific configuration.
        """
        ret = self._query_inside(target, workload)
        if ret is None:
            ret = self._old_ctx.query(target, workload)
        return ret

    def update(self, target, workload, cfg):
        """
        Update context with a specific config.

        Parameters
        ----------
        target: Target
            The current target
        workload : Workload
            The current workload.
        cfg : ConfigSpace
            The specific configuration.

        Note
        ----
        This interface is for cases when TVM decides to replace an operator in the graph.
        For example, the `AlterOpLayout` pass (enabled when `opt_level = 3`) replaces `NCHW`
        convolution with an `NCHW[x]c` implementation on x86 CPUs.
        Thus in TOPI, we first query the schedule using the original `NCHW` workload,
        then update the dispatcher with the new `NCHW[x]c` workload,
        so that later on the `NCHW[x]c` convolution can get its schedule from the dispatcher
        directly using its own workload.

        .. code-block:: python

            @conv2d_alter_layout.register("cpu")
            def _alter_conv2d_layout(attrs, inputs, tinfo):
                workload = get_conv2d_workload(...)
                dispatch_ctx = autotvm.task.DispatchContext.current
                target = tvm.target.current_target()
                config = dispatch_ctx.query(target, workload)

                # Get conv2d_NCHWc workload from config
                # new_workload = ...
                # new_inputs = ...
                # new_attrs = ...

                # Store altered operator's config
                dispatch_ctx.update(target, new_workload, config)
                return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs)

        We directly store `config` back because `conv2d_NCHW` and `conv2d_NCHWc`
        share the same schedule parameters.
        One can construct a new `ConfigEntity` if this is not the case.
        """
        raise NotImplementedError()

    def _query_inside(self, target, workload):
        """
        Query the context to get the specific config for a template.
        This function only queries the config inside this context.

        Parameters
        ----------
        target: Target
            The current target
        workload : Workload
            The current workload.

        Returns
        -------
        cfg : ConfigSpace
            The specific configuration.
        """
        raise NotImplementedError()

    def __enter__(self):
        self._old_ctx = DispatchContext.current
        DispatchContext.current = self
        return self

    def __exit__(self, ptype, value, trace):
        DispatchContext.current = self._old_ctx


def dispatcher(fworkload):
    """Wrap a workload dispatcher function.

    Parameters
    ----------
    fworkload : function
        The workload extraction function from arguments.

    Returns
    -------
    fdispatcher : function
        A wrapped dispatcher function, which will
        dispatch based on DispatchContext and
        the current workload.
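
    Examples
    --------
    A minimal sketch of declaring a dispatcher and registering a template
    implementation against it (all names below are hypothetical):

    .. code-block:: python

        @dispatcher
        def my_matmul(N, L, M, dtype):
            # the return value is the workload: a hashable tuple
            # derived from the call arguments
            return ('my_matmul', N, L, M, dtype)

        @my_matmul.register(['direct'])
        def _my_matmul_direct(cfg, N, L, M, dtype):
            # cfg is supplied by the active DispatchContext;
            # define the compute and schedule using cfg here
            # s, arg_bufs = ...
            pass

        # calling my_matmul(N, L, M, dtype) now extracts the workload,
        # queries DispatchContext.current, and dispatches to the template
        # registered under cfg.template_key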
    """
    dispatch_dict = {}
    func_name = fworkload.__name__

    def register(key, func=None, override=False):
        """Register template function.

        Parameters
        ----------
        key : str or List of str
            The template key to identify the template
            under this dispatcher.
        func : function
            The function to be registered.
            The first argument of the function is always
            the cfg returned by DispatchContext;
            the remaining arguments are the same as those of fworkload.
        override : bool
            Whether to override an existing registration.

        Returns
        -------
        The registration decorator if func is not provided; otherwise the registered function.
        """
        if isinstance(key, str):
            key = [key]

        def _do_reg(myf):
            for x in key:
                if x in dispatch_dict and not override:
                    raise ValueError(
                        "Key %s is already registered for %s" % (x, func_name))
                dispatch_dict[x] = myf
            return myf

        if func:
            return _do_reg(func)
        return _do_reg

    def dispatch_func(func, *args, **kwargs):
        """The wrapped dispatch function"""
        tgt = _target.current_target()
        workload = func(*args, **kwargs)
        cfg = DispatchContext.current.query(tgt, workload)
        if cfg.is_fallback and not cfg.template_key:
            # first try 'direct' template
            if 'direct' in dispatch_dict:
                return dispatch_dict['direct'](cfg, *args, **kwargs)
            # otherwise pick the first registered template (dict order is arbitrary)
            for v in dispatch_dict.values():
                return v(cfg, *args, **kwargs)
        else:
            return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)

    fdecorate = decorate(fworkload, dispatch_func)
    fdecorate.register = register
    return fdecorate


class ApplyConfig(DispatchContext):
    """Apply a deterministic config entity for all queries.

    Parameters
    ----------
    config : ConfigSpace or ConfigEntity
        The specific configuration we care about.
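
    Examples
    --------
    A minimal sketch, assuming `task` is an autotvm task and `config` is one
    of its config entities:

    .. code-block:: python

        with tvm.target.create("llvm"):
            with ApplyConfig(config):
                # every query inside this scope returns `config`
                sch, arg_bufs = task.instantiate(config)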
    """
    def __init__(self, config):
        super(ApplyConfig, self).__init__()
        self._config = config
        self.workload = None

    def _query_inside(self, target, workload):
        """Override query"""
        self.workload = workload
        return self._config

    def update(self, target, workload, cfg):
        """Override update"""
        self.workload = workload
        self._config = cfg


class ApplyHistoryBest(DispatchContext):
    """
    Apply the history best config

    Parameters
    ----------
    records : str or iterator of (MeasureInput, MeasureResult)
        Collection of tuning records.
        If it is a str, it should be the filename of a record log file.
        Each row of this file is an encoded record pair.
        Otherwise, it should be an iterator of records.
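
    Examples
    --------
    A minimal sketch; the log file names and the template function are
    hypothetical. Additional record files can be merged into the same
    context with `load`:

    .. code-block:: python

        ctx = ApplyHistoryBest("conv2d_tuning.log")
        ctx.load("dense_tuning.log")
        with ctx:
            with tvm.target.create("llvm"):
                # queries issued in this scope return the best configs
                # found in the loaded logs
                s, arg_bufs = conv2d_template(data, kernel, strides, padding)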
    """
    def __init__(self, records):
        super(ApplyHistoryBest, self).__init__()

        self.best_by_targetkey = {}
        self.best_by_model = {}
        self._best_user_defined = {}

        if records:
            self.load(records)

    def load(self, records):
        """Load records to this dispatch context

        Parameters
        ----------
        records : str or iterator of (MeasureInput, MeasureResult)
            Collection of tuning records.
            If it is a str, it should be the filename of a record log file.
            Each row of this file is an encoded record pair.
            Otherwise, it should be an iterator of records.
        """
        from pathlib import Path
        from ..record import load_from_file

        if isinstance(records, Path):
            records = str(records)

        if isinstance(records, str):
            records = load_from_file(records)
        if not records:
            return

        best_by_targetkey = self.best_by_targetkey
        best_by_model = self.best_by_model

        counter = 0
        for inp, res in records:
            counter += 1
            if res.error_no != 0:
                continue

            # use target keys in the tvm target system as keys to build the best map
            for k in inp.target.keys:
                key = (k, inp.task.workload)
                if key not in best_by_targetkey:
                    best_by_targetkey[key] = (inp, res)
                else:
                    _, other_res = best_by_targetkey[key]
                    if np.mean(other_res.costs) > np.mean(res.costs):
                        best_by_targetkey[key] = (inp, res)

            # use the model as the key to build the best map
            key = (inp.target.model, inp.task.workload)
            if key not in best_by_model:
                if inp.target.model != 'unknown':
                    best_by_model[key] = (inp, res)
            else:
                _, other_res = best_by_model[key]
                if np.mean(other_res.costs) > np.mean(res.costs):
                    best_by_model[key] = (inp, res)

        logger.debug("Finish loading %d records", counter)

    def _query_inside(self, target, workload):
        if target is None:
            raise RuntimeError("Need a target context to find the history best. "
                               "Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
                               " above the dispatcher call. So does other target. ")

        # first try matching by model
        key = (target.model, workload)
        if key in self._best_user_defined:
            return self._best_user_defined[key]
        if key in self.best_by_model:
            return self.best_by_model[key][0].config

        # then try matching by target key
        for k in target.keys:
            key = (k, workload)
            if key in self._best_user_defined:
                return self._best_user_defined[key]
            if key in self.best_by_targetkey:
                return self.best_by_targetkey[key][0].config

        return None

    def update(self, target, workload, cfg):
        model = target.model
        key = (model, workload)
        self._best_user_defined[key] = cfg

        for k in target.keys:
            key = (k, workload)
            self._best_user_defined[key] = cfg


class FallbackContext(DispatchContext):
    """
    A fallback dispatch context.

    Any tunable template can be called under this context.
    This is the root context.
    """

    def __init__(self):
        super(FallbackContext, self).__init__()
        self.memory = {}
        self.silent = False

        # a set to avoid printing duplicated warning messages
        self.messages = set()

    def _query_inside(self, target, workload):
        key = (str(target), workload)
        if key in self.memory:
            return self.memory[key]

        if not self.silent:
            msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\
                  "is used, which may bring great performance regression." % (target, workload)
            if msg not in self.messages:
                self.messages.add(msg)
                logger.warning(msg)
        cfg = FallbackConfigEntity()

        # cache this config
        self.memory[key] = cfg
        return cfg

    def clear_cache(self, target, workload):
        """Clear fallback cache. Pass the same argument as _query_inside to this function
        to clean the cache.

        Parameters
        ----------
        target: Target
            The current target
        workload : Workload
            The current workload.
        """
        key = (str(target), workload)
        if key in self.memory:
            del self.memory[key]

    def update(self, target, workload, cfg):
        key = (str(target), workload)
        self.memory[key] = cfg


DispatchContext.current = FallbackContext()


def clear_fallback_cache(target, workload):
    """Clear fallback cache. Pass the same argument as _query_inside to this function
    to clean the cache.

    Parameters
    ----------
    target: Target
        The current target
    workload : Workload
        The current workload.

    Note
    ----
    This is used in alter_op_layout to clear the stale cache entry created before calling
    the TOPI compute function.
    """
    context = DispatchContext.current
    while not isinstance(context, FallbackContext):
        context = context._old_ctx
    context.clear_cache(target, workload)


class ApplyGraphBest(DispatchContext):
    """Load the graph level tuning optimal schedules.

    The input records should be in the ascending order of
    node index for target operator. Usually this can be obtained
    with graph tuner.

    This context maintains an internal counter to indicate the current
    node index.
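
    Examples
    --------
    A minimal sketch, assuming "graph_opt.log" holds records ordered by node
    index as produced by the graph tuner, and `sym`/`shape_dict` describe an
    NNVM graph:

    .. code-block:: python

        with ApplyGraphBest("graph_opt.log"):
            # each operator query consumes the next record in order
            graph, lib, params = nnvm.compiler.build(
                sym, target="llvm", shape=shape_dict)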
    """
    def __init__(self, records):
        """
        Parameters
        ----------
        records : str or iterator of (MeasureInput, MeasureResult)
            Collection of tuning records.
            If it is a str, it should be the filename of a record log file.
            Each row of this file is an encoded record pair.
            Otherwise, it should be an iterator of records.
        """
        from ..record import load_from_file

        super(ApplyGraphBest, self).__init__()
        if isinstance(records, str):
            records = load_from_file(records)
        self._records = list(records)
        self._counter = 0
        self._global_cfg_dict = {}

    def _query_inside(self, target, workload):
        """
        Query the context to get config from records.

        Parameters
        ----------
        target : Target
            The current target
        workload : Workload
            The current workload.

        Returns
        -------
        cfg : ConfigSpace
            The specific configuration.
        """
        if self._counter < len(self._records):
            cfg = self._records[self._counter][0].config
            self._counter += 1
            self.update(target, workload, cfg)
            return cfg
        key = (str(target), workload)
        if key not in self._global_cfg_dict:
            msg = "Config for target=%s, workload=%s is missing in ApplyGraphBest context. " \
                  "A fallback configuration is used, which may bring great performance " \
                  "regression." % (target, workload)
            logger.warning(msg)
            cfg = FallbackConfigEntity()
            self._global_cfg_dict[key] = cfg
        else:
            cfg = self._global_cfg_dict[key]
        return cfg

    def update(self, target, workload, cfg):
        key = (str(target), workload)
        self._global_cfg_dict[key] = cfg