# pylint: disable=unused-argument, no-self-use, invalid-name
"""Base class of tuner"""
import logging

import numpy as np

from ..measure import MeasureInput, create_measure_batch

from ..env import GLOBAL_SCOPE

logger = logging.getLogger('autotvm')

class Tuner(object):
    """Base class for tuners

    Parameters
    ----------
    task: autotvm.task.Task
        Tuning Task
    """

    def __init__(self, task, **kwargs):
        self.param = kwargs
        self.recorder = None

        self.task = task

        # keep the current best
        self.best_config = None
        self.best_flops = 0
        self.best_measure_pair = None
        self.best_iter = 0

        # time to leave (ttl): trials remaining before tuning stops
        self.ttl = None
        self.n_trial = None
        self.early_stopping = None

    def has_next(self):
        """Whether has next untried config in the space

        Returns
        -------
        has_next: bool
        """
        raise NotImplementedError()

    def next_batch(self, batch_size):
        """get the next batch of configs to be measure on real hardware

        Parameters
        ----------
        batch_size: int
            The size of the batch

        Returns
        -------
        a batch of configs
        """
        raise NotImplementedError()

    def update(self, inputs, results):
        """Update parameters of the tuner according to measurement results

        Parameters
        ----------
        inputs: Array of autotvm.measure.MeasureInput
            The input for measurement
        results: Array of autotvm.measure.MeasureResult
            The results of measurement
        """
        pass
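
    # A minimal sketch of a concrete subclass, illustrating the has_next /
    # next_batch / update contract. This is illustrative only; the real tuners
    # (grid search, random, model-based) live in sibling modules, and the
    # class name below is hypothetical.
    #
    #   class SequentialTuner(Tuner):
    #       def __init__(self, task, **kwargs):
    #           super(SequentialTuner, self).__init__(task, **kwargs)
    #           self.counter = 0  # next untried index in the config space
    #
    #       def has_next(self):
    #           return self.counter < len(self.task.config_space)
    #
    #       def next_batch(self, batch_size):
    #           batch = []
    #           while len(batch) < batch_size and self.has_next():
    #               batch.append(self.task.config_space.get(self.counter))
    #               self.counter += 1
    #           return batch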

    def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()):
        """Begin tuning

        Parameters
        ----------
        n_trial: int
            Maximum number of configs to try (measure on real hardware)
        measure_option: dict
            The options for how to measure generated code.
            You should use the return value of autotvm.measure_option for this argument.
        early_stopping: int, optional
            Stop tuning early if no better config is found within this number of trials
        callbacks: List of callable
            A list of callback functions. The signature of callback function is
            (Tuner, List of MeasureInput, List of MeasureResult)
            with no return value. These callback functions will be called on
            every measurement pair. See autotvm/tuner/callback.py for some examples.
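
        Examples
        --------
        A rough usage sketch with a hypothetical concrete subclass MyTuner;
        see autotvm.measure_option and autotvm.callback for the real helpers:

        >>> opt = autotvm.measure_option(builder=autotvm.LocalBuilder(),
        ...                              runner=autotvm.LocalRunner(number=5))
        >>> tuner = MyTuner(task)
        >>> tuner.tune(n_trial=100, measure_option=opt,
        ...            callbacks=[autotvm.callback.log_to_file('tuning.log')])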
        """
        measure_batch = create_measure_batch(self.task, measure_option)
        n_parallel = getattr(measure_batch, 'n_parallel', 1)
        early_stopping = early_stopping or 1e9
        self.n_trial = n_trial
        self.early_stopping = early_stopping

        old_level = logger.level

        GLOBAL_SCOPE.in_tuning = True
        i = error_ct = 0
        while i < n_trial:
            if not self.has_next():
                break

            configs = self.next_batch(min(n_parallel, n_trial - i))

            inputs = [MeasureInput(self.task.target, self.task, config) for config in configs]
            results = measure_batch(inputs)

            # keep best config
            for k, (inp, res) in enumerate(zip(inputs, results)):
                config = inp.config
                if res.error_no == 0:
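                    # res.costs holds the measured runtimes in seconds, so
                    # flop / mean(cost) is the achieved FLOPS of this config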
                    flops = inp.task.flop / np.mean(res.costs)
                    error_ct = 0
                else:
                    flops = 0
                    error_ct += 1

                if flops > self.best_flops:
                    self.best_flops = flops
                    self.best_config = config
                    self.best_measure_pair = (inp, res)
                    self.best_iter = i + k

                logger.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
                             i + k + 1, flops / 1e9, self.best_flops / 1e9,
                             res, config)

            i += len(results)
            # trials left before early stopping or the trial budget is exhausted
            self.ttl = min(early_stopping + self.best_iter, n_trial) - i

            self.update(inputs, results)
            for callback in callbacks:
                callback(self, inputs, results)

            if i >= self.best_iter + early_stopping:
                logger.debug("Early stopped. Best iter: %d.", self.best_iter)
                break

            if error_ct > 150:
                logging.basicConfig()
                logger.warning("Too many errors happen in the tuning. Now is in debug mode")
                logger.setLevel(logging.DEBUG)
            else:
                logger.setLevel(old_level)

        GLOBAL_SCOPE.in_tuning = False
        del measure_batch

    def reset(self):
        """reset the status of tuner"""
        self.best_config = None
        self.best_flops = 0
        self.best_measure_pair = None

    def load_history(self, data_set):
        """load history data for transfer learning

        Parameters
        ----------
        data_set: Array of (MeasureInput, MeasureResult) pair
            Previous tuning records
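
        Examples
        --------
        A sketch of intended use for transfer learning, assuming a concrete
        subclass that implements this method; load_from_file comes from
        autotvm.record and yields (MeasureInput, MeasureResult) pairs:

        >>> from tvm.autotvm.record import load_from_file
        >>> tuner.load_history(load_from_file('previous_tuning.log'))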
        """
        raise NotImplementedError()