# pylint: disable=consider-using-enumerate, invalid-name
"""
Cost model optimizer based on simulated annealing
"""

import heapq
import logging
import time

import numpy as np

from ..util import sample_ints
from .model_based_tuner import ModelOptimizer, knob2point, point2knob

logger = logging.getLogger('autotvm')

class SimulatedAnnealingOptimizer(ModelOptimizer):
    """parallel simulated annealing optimization algorithm

    Parameters
    ----------
    task: Task
        The tuning task
    n_iter: int
        The number of iterations of simulated annealing
    temp: float or Array of float
        If it is a single float, then use a constant temperature.
        If it is an Array, then perform linear cooling from temp[0] to temp[1]
    persistent: bool, optional
        Whether to keep the current points between calls to `find_maximums`,
        so that a later call resumes from where the previous run stopped
    parallel_size: int, optional
        Number of simulated annealing chains (points) to run in parallel
    early_stop: int, optional
        Stop iteration if the optimal set does not change in `early_stop` rounds
    log_interval: int, optional
        Print a log message every `log_interval` iterations
    """
    def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128,
                 early_stop=50, log_interval=50):
        super(SimulatedAnnealingOptimizer, self).__init__()

        self.task = task
        self.dims = [len(x) for x in self.task.config_space.space_map.values()]

        self.n_iter = n_iter
        self.temp = temp
        self.persistent = persistent
        self.parallel_size = min(parallel_size, len(self.task.config_space))
        self.early_stop = early_stop or 1e9
        self.log_interval = log_interval
        self.points = None

    def find_maximums(self, model, num, exclusive):
        tic = time.time()
        temp, n_iter, early_stop, log_interval = \
                self.temp, self.n_iter, self.early_stop, self.log_interval

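        # start the parallel chains either from the points kept by a previous
        # call (persistent mode) or from a fresh uniform sample of
        # `parallel_size` indices of the config space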
        if self.persistent and self.points is not None:
            points = self.points
        else:
            points = np.array(sample_ints(0, len(self.task.config_space), self.parallel_size))

        scores = model.predict(points)

        # build heap and insert initial points
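        # `heap_items` is a min-heap of (score, point) pairs holding the best
        # `num` points seen so far; the -inf placeholders are evicted as real
        # candidates arrive.  `in_heap` tracks which points are already in the
        # heap (plus the caller's `exclusive` set) so no point is returned twice.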
        heap_items = [(float('-inf'), -1 - i) for i in range(num)]
        heapq.heapify(heap_items)
        in_heap = set(exclusive)
        in_heap.update([x[1] for x in heap_items])

        for s, p in zip(scores, points):
            if s > heap_items[0][0] and p not in in_heap:
                pop = heapq.heapreplace(heap_items, (s, p))
                in_heap.remove(pop[1])
                in_heap.add(p)

        k = 0
        k_last_modify = 0

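        # temperature schedule: either a constant temperature, or linear
        # cooling from temp[0] down to temp[1] over the course of the run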
        if isinstance(temp, (tuple, list, np.ndarray)):
            t = temp[0]
            cool = 1.0 * (temp[0] - temp[1]) / (n_iter + 1)
        else:
            t = temp
            cool = 0

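        # main annealing loop: every chain proposes a random-walk neighbor each
        # round; stop after `n_iter` rounds, or earlier if the set of maximums
        # has not improved for `early_stop` rounds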
        while k < n_iter and k < k_last_modify + early_stop:
            new_points = np.empty_like(points)
            for i, p in enumerate(points):
                new_points[i] = random_walk(p, self.dims)

            new_scores = model.predict(new_points)

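            # Metropolis-style acceptance: moves that improve the predicted
            # score are always taken (probability >= 1), worse moves are taken
            # with probability exp(delta / t); the 1e-5 term guards against
            # division by zero once the temperature reaches 0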
            ac_prob = np.exp(np.minimum((new_scores - scores) / (t + 1e-5), 1))
            ac_index = np.random.random(len(ac_prob)) < ac_prob

            points[ac_index] = new_points[ac_index]
            scores[ac_index] = new_scores[ac_index]

            for s, p in zip(new_scores, new_points):
                if s > heap_items[0][0] and p not in in_heap:
                    pop = heapq.heapreplace(heap_items, (s, p))
                    in_heap.remove(pop[1])
                    in_heap.add(p)
                    k_last_modify = k

            k += 1
            t -= cool

            if log_interval and k % log_interval == 0:
                t_str = "%.2f" % t
                logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
                             "elapsed: %.2f",
                             k, k_last_modify, heap_items[0][0],
                             np.max([v for v, _ in heap_items]), t_str,
                             time.time() - tic)

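        # sort the kept points by predicted score (best first) and drop any
        # placeholder entries that were never replaced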
        heap_items.sort(key=lambda item: -item[0])
        heap_items = [x for x in heap_items if x[0] != float('-inf')]
        logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
                     k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
        logger.debug("SA Maximums: %s", heap_items)

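        # remember the chain states so a later call can resume from them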
        if self.persistent:
            self.points = points

        return [x[1] for x in heap_items]

def random_walk(p, dims):
    """random walk as local transition

    Parameters
    ----------
    p: int
        index of the ConfigEntity
    dims: Array of int
        sizes of each dimension

    Returns
    -------
    new_p: int
        new neighborhood index
    """
    # transform to knob form
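    # (a point is a single integer index into the flattened config space; its
    # knob form is the mixed-radix decomposition with one coordinate per tuning
    # dimension, e.g. with dims = [4, 8] point 11 maps to knob [3, 2], assuming
    # the first dimension varies fastest)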
    old = point2knob(p, dims)
    new = list(old)

    # mutate
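    # resample one knob at random, retrying until the new configuration actually
    # differs from the old one, so exactly one knob ends up changed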
    while new == old:
        from_i = np.random.randint(len(old))
        to_v = np.random.randint(dims[from_i])
        new[from_i] = to_v

    # transform to index form
    return knob2point(new, dims)
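

# Illustrative usage sketch (a non-authoritative example, not part of the
# module itself): this optimizer is normally driven by a model-based tuner,
# which owns the cost model.  Assuming `task` is an autotvm Task and
# `cost_model` is any object with a `predict(points)` method that returns one
# score per point (such as the XGBoost-based cost model used elsewhere in
# autotvm), a standalone call might look like:
#
#     optimizer = SimulatedAnnealingOptimizer(task, n_iter=500, temp=(1, 0))
#     best_indices = optimizer.find_maximums(cost_model, num=64, exclusive=set())
#
# `best_indices` are indices into task.config_space, ordered so the point the
# model scores highest comes first; indices listed in `exclusive` are skipped.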