# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=consider-using-enumerate,invalid-name
"""Namespace of callback utilities of AutoTVM"""
import sys
import time
import logging

import numpy as np

from .. import record

logger = logging.getLogger('autotvm')


def log_to_file(file_out, protocol='json'):
    """Build a callback that appends tuning records to a log file.

    Every (input, result) pair is serialized with autotvm.record.encode and
    written as one line.

    Parameters
    ----------
    file_out : File or str
        Destination of the log. A ``str`` (or ``pathlib.Path``, which is
        converted to ``str``) is opened in append mode on every callback
        invocation; any other object is used directly through its
        ``write`` method.
    protocol: str, optional
        The log protocol passed to the encoder. Can be 'json' or 'pickle'

    Returns
    -------
    callback : callable
        Function with signature ``(tuner, inputs, results)`` that does the
        logging.
    """
    # pylint: disable=import-outside-toplevel
    from pathlib import Path

    # Normalise Path objects up front so the callback only has to
    # distinguish "file name" from "file-like object".
    if isinstance(file_out, Path):
        file_out = str(file_out)

    def _write_rows(stream, inputs, results):
        """Encode each pair and write it to `stream`, one record per line."""
        for inp, res in zip(inputs, results):
            stream.write(record.encode(inp, res, protocol) + "\n")

    def _callback(_, inputs, results):
        """Callback implementation"""
        if not isinstance(file_out, str):
            # Caller supplied an already opened file-like object.
            _write_rows(file_out, inputs, results)
            return

        with open(file_out, "a") as handle:
            _write_rows(handle, inputs, results)

    return _callback


def log_to_database(db):
    """Build a callback that stores tuning records in a database object.

    Parameters
    ----------
    db: Database
        The database; its ``save(inp, result)`` method is called once per
        (input, result) pair.

    Returns
    -------
    callback : callable
        Function with signature ``(tuner, inputs, results)`` that saves
        every pair of the batch, in order.
    """
    def _callback(_, inputs, results):
        """Callback implementation"""
        for pair in zip(inputs, results):
            db.save(*pair)

    return _callback


class Monitor(object):
    """Callback object that collects per-trial statistics during tuning.

    Calling the instance with a batch of inputs and results appends one
    score and one timestamp for every trial in the batch.
    """
    def __init__(self):
        self.scores, self.timestamps = [], []

    def __call__(self, tuner, inputs, results):
        """Record score and timestamp for each (input, result) pair.

        The score of a trial whose ``error_no`` is 0 is ``task.flop``
        divided by the mean of its costs; any other trial scores 0.
        The `tuner` argument is not used.
        """
        for inp, res in zip(inputs, results):
            succeeded = res.error_no == 0
            score = inp.task.flop / np.mean(res.costs) if succeeded else 0
            self.scores.append(score)
            self.timestamps.append(res.timestamp)

    def reset(self):
        """Drop everything recorded so far by starting fresh lists."""
        self.scores, self.timestamps = [], []

    def trial_scores(self):
        """get scores (currently is flops) of all trials"""
        collected = self.scores
        return np.array(collected)

    def trial_timestamps(self):
        """get wall clock time stamp of all trials"""
        collected = self.timestamps
        return np.array(collected)


def progress_bar(total, prefix=''):
    """Display progress bar for tuning

    The bar is written to ``sys.stdout`` only while the level of the
    'autotvm' logger is below ``logging.DEBUG``. An initial line is
    emitted when this function is called, and the line is redrawn each
    time the returned callback runs.

    Parameters
    ----------
    total: int
        The total number of trials
    prefix: str
        The prefix of output message

    Returns
    -------
    callback : callable
        Function with signature ``(tuner, inputs, results)`` that updates
        the progress bar.
    """
    class _Context(object):
        """Context to store local variables"""
        def __init__(self):
            self.best_flops = 0
            self.cur_flops = 0
            self.ct = 0
            self.total = total

        def __del__(self):
            # Finish the bar's line once the context is garbage collected.
            if logger.level < logging.DEBUG:  # only print progress bar in non-debug mode
                sys.stdout.write(' Done.\n')

    line_format = ('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
                   '| %.2f s')

    ctx = _Context()
    tic = time.time()

    def _render(cur_gflops, best_gflops, done, out_of):
        """Redraw the status line in place and flush it immediately."""
        elapsed = time.time() - tic
        sys.stdout.write(line_format %
                         (prefix, cur_gflops, best_gflops, done, out_of, elapsed))
        sys.stdout.flush()

    if logger.level < logging.DEBUG:  # only print progress bar in non-debug mode
        _render(0, 0, 0, total)

    def _callback(tuner, inputs, results):
        """Count the finished trials and refresh the bar."""
        ctx.ct += len(inputs)

        # Throughput of the last successful trial in this batch; stays 0
        # when every trial in the batch reported an error.
        latest = 0
        for inp, res in zip(inputs, results):
            if res.error_no != 0:
                continue
            latest = inp.task.flop / np.mean(res.costs)

        if logger.level >= logging.DEBUG:  # only print progress bar in non-debug mode
            return

        ctx.cur_flops = latest
        ctx.best_flops = tuner.best_flops
        _render(ctx.cur_flops / 1e9, ctx.best_flops / 1e9, ctx.ct, ctx.total)

    return _callback