callback.py 4.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18
# pylint: disable=consider-using-enumerate,invalid-name
"""Namespace of callback utilities of AutoTVM"""
19 20
import sys
import time
21
import logging
22 23 24 25 26

import numpy as np

from .. import record

27
# Logger shared by AutoTVM; the progress-bar helpers in this module read its level.
logger = logging.getLogger(name='autotvm')
28

29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
def log_to_file(file_out, protocol='json'):
    """Log the tuning records into file.
    The rows of the log are stored in the format of autotvm.record.encode.

    Parameters
    ----------
    file_out : File or str or path-like
        The file to log to. An object with a ``write`` method is written to
        directly and left open for the caller. Anything else (``str``,
        ``pathlib.Path``, ...) is treated as a path; it is opened in append
        mode on every callback invocation and closed again afterwards.
    protocol: str, optional
        The log protocol. Can be 'json' or 'pickle'

    Returns
    -------
    callback : callable
        Callback function to do the logging.
    """
    def _write_records(fout, inputs, results):
        """Write one encoded record per line for every (input, result) pair."""
        for inp, result in zip(inputs, results):
            fout.write(record.encode(inp, result, protocol) + "\n")

    def _callback(_, inputs, results):
        """Callback implementation"""
        if hasattr(file_out, "write"):
            # Caller-owned file-like object: write in place, never close it.
            _write_records(file_out, inputs, results)
        else:
            # A path (str or path-like). Append mode keeps records from earlier
            # batches; the file is closed again when this batch is written.
            with open(file_out, "a") as fout:
                _write_records(fout, inputs, results)
    return _callback


57 58
def log_to_database(db):
    """Save the tuning records to a database object.

    Parameters
    ----------
    db: Database
        The database. Every (input, result) pair is passed to ``db.save``.

    Returns
    -------
    callback : callable
        Callback function that saves each measured record into ``db``.
    """
    def _callback(_, inputs, results):
        """Callback implementation"""
        for inp, result in zip(inputs, results):
            db.save(inp, result)
    return _callback

71

72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
class Monitor(object):
    """A monitor to collect statistic during tuning"""
    def __init__(self):
        self.scores = []
        self.timestamps = []

    def __call__(self, tuner, inputs, results):
        for inp, res in zip(inputs, results):
            if res.error_no == 0:
                flops = inp.task.flop / np.mean(res.costs)
                self.scores.append(flops)
            else:
                self.scores.append(0)

            self.timestamps.append(res.timestamp)

    def reset(self):
        self.scores = []
        self.timestamps = []

    def trial_scores(self):
        """get scores (currently is flops) of all trials"""
        return np.array(self.scores)

    def trial_timestamps(self):
        """get wall clock time stamp of all trials"""
        return np.array(self.timestamps)
99 100 101 102 103 104 105 106 107 108 109 110


def progress_bar(total, prefix=''):
    """Display progress bar for tuning

    The bar is written to stdout only while the level set directly on the
    'autotvm' logger is below ``logging.DEBUG`` (e.g. the default ``NOTSET``).
    Setting that logger to DEBUG, or to any higher level such as INFO,
    suppresses the bar.

    Parameters
    ----------
    total: int
        The total number of trials
    prefix: str
        The prefix of output message

    Returns
    -------
    callback : callable
        Callback ``(tuner, inputs, results)`` that advances the progress bar.
    """
    def _show_bar():
        # Reads the logger's own ``level`` attribute (not its effective level)
        # and is re-evaluated on every call.
        return logger.level < logging.DEBUG

    def _write_status(cur_flops, best_flops, count):
        # '\r' returns to the start of the line so each update overwrites the
        # previous one. FLOPS values are displayed as GFLOPS.
        sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
                         '| %.2f s' % (prefix, cur_flops / 1e9, best_flops / 1e9,
                                       count, total, time.time() - tic))
        sys.stdout.flush()

    class _Context(object):
        """Context to store local variables"""
        def __init__(self):
            self.best_flops = 0
            self.cur_flops = 0
            self.ct = 0
            self.total = total

        def __del__(self):
            # Ends the bar line when this context is finalized. ``ctx`` is held
            # only by the returned callback, so this typically runs once that
            # callback is dropped.
            if _show_bar():
                sys.stdout.write(' Done.\n')

    ctx = _Context()
    tic = time.time()

    if _show_bar():
        _write_status(0, 0, 0)

    def _callback(tuner, inputs, results):
        ctx.ct += len(inputs)

        # FLOPS of the last successful measurement in this batch; 0 if none succeeded.
        flops = 0
        for inp, res in zip(inputs, results):
            if res.error_no == 0:
                flops = inp.task.flop / np.mean(res.costs)

        if _show_bar():
            ctx.cur_flops = flops
            ctx.best_flops = tuner.best_flops
            _write_status(ctx.cur_flops, ctx.best_flops, ctx.ct)

    return _callback