xgboost_cost_model.py 18.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# pylint: disable=invalid-name
"""XGBoost as cost model"""

import multiprocessing
import logging
import time

import numpy as np
try:
    import xgboost as xgb
except ImportError:
    xgb = None

from .. import feature
from ..util import get_rank
from .metric import max_curve, recall_curve, cover_curve
from .model_based_tuner import CostModel, FeatureCache

19 20
# Module-wide logger; all messages from this file are emitted under the 'autotvm' name.
logger = logging.getLogger('autotvm')

21 22 23 24 25 26 27 28 29 30 31 32 33
class XGBoostCostModel(CostModel):
    """XGBoost as cost model

    Parameters
    ----------
    task: Task
        The tuning task
    feature_type: str, optional
        If it is 'itervar', use features extracted from IterVar (loop variable).
        If it is 'knob', use the flattened ConfigEntity directly.
        If it is 'curve', use sampled curve feature (relation feature).

        Note on choosing feature type:
34
        For single task tuning, 'itervar' and 'knob' are good.
35
                                'itervar' is more accurate but 'knob' is much faster.
36 37 38 39
                                There are some constraints on 'itervar'. If you run into
                                problems with feature extraction when using 'itervar',
                                you can switch to 'knob'.

40 41 42 43 44 45 46 47 48 49 50
        For cross-shape tuning (e.g. many convolutions with different shapes),
                               'itervar' and 'curve' have better transferability,
                               'knob' is faster.
        For cross-device or cross-operator tuning, you can use 'curve' only.
    loss_type: str
        If it is 'reg', use regression loss to train the cost model.
                     The cost model predicts the normalized flops.
        If it is 'rank', use pairwise rank loss to train the cost model.
                     The cost model predicts a relative rank score.
    num_threads: int, optional
        The number of threads.
51 52
    log_interval: int, optional
        If it is not None, the cost model will print a training log every `log_interval` iterations.
53 54
    upper_model: XGBoostCostModel, optional
        The upper model used in transfer learning
55
    """
56
    def __init__(self, task, feature_type, loss_type, num_threads=None, log_interval=25,
                 upper_model=None):
        """Store the task, choose booster parameters and feature extractor, start the pool.

        The arguments are described in the class docstring.

        Raises
        ------
        RuntimeError
            If xgboost is not installed, or if `loss_type` / `feature_type`
            is not one of the supported names.
        """
        super(XGBoostCostModel, self).__init__()

        if xgb is None:
            raise RuntimeError("XGBoost is required for XGBoostCostModel. "
                               "Please install its python package first. "
                               "Help: (https://xgboost.readthedocs.io/en/latest/) ")

        self.task = task
        self.target = task.target
        self.space = task.config_space

        self.fea_type = feature_type
        self.loss_type = loss_type
        self.num_threads = num_threads
        self.log_interval = log_interval

        # the two loss types use the same booster settings and differ only in the objective
        objectives = {'reg': 'reg:linear', 'rank': 'rank:pairwise'}
        if loss_type not in objectives:
            raise RuntimeError("Invalid loss type: " + loss_type)
        self.xgb_params = {
            'max_depth': 3,
            'gamma': 0.0001,
            'min_child_weight': 1,

            'subsample': 1.0,

            'eta': 0.3,
            'lambda': 1.00,
            'alpha': 0,

            'objective': objectives[loss_type],
        }

        self.xgb_params['silent'] = 1
        if num_threads:
            self.xgb_params['nthread'] = num_threads
        self.bst = None

        # per-index feature extractor that is mapped over the worker pool
        index_extractors = {
            'itervar': _extract_itervar_feature_index,
            'knob': _extract_knob_feature_index,
            'curve': _extract_curve_feature_index,
        }
        if feature_type not in index_extractors:
            raise RuntimeError("Invalid feature type " + feature_type)
        self.feature_extract_func = index_extractors[feature_type]

        # a model created with an upper model shares that model's feature cache
        self.feature_cache = upper_model.feature_cache if upper_model else FeatureCache()
        self.upper_model = upper_model
        self.feature_extra_ct = 0
        self.pool = None
        self.base_model = None

        self._sample_size = 0
        self._reset_pool(self.space, self.target, self.task)
130

131 132 133 134 135 136 137 138
    def _reset_pool(self, space, target, task):
        """Recreate the process pool used for feature extraction.

        A model that has an upper model does not own a pool; the request is
        forwarded to the upper model instead.
        """
        global _extract_space, _extract_target, _extract_task

        owner = self.upper_model
        if owner:  # base model will reuse upper model's pool,
            owner._reset_pool(space, target, task)
            return

        self._close_pool()

        # The extraction helpers read their common arguments from module globals,
        # so these are assigned before the new pool is created.
        # NOTE(review): presumably this relies on worker processes inheriting the
        # globals at pool creation time (fork start method) -- confirm.
        _extract_space, _extract_target, _extract_task = space, target, task
        self.pool = multiprocessing.Pool(self.num_threads)

147 148 149 150 151 152 153 154 155 156 157
    def _close_pool(self):
        if self.pool:
            self.pool.terminate()
            self.pool.join()
            self.pool = None

    def _get_pool(self):
        if self.upper_model:
            return self.upper_model._get_pool()
        return self.pool

158
    def _base_model_discount(self):
159
        return 1.0 / (2 ** (self._sample_size / 64.0))
160

161 162
    def fit(self, xs, ys, plan_size):
        """Train the booster on measured configurations.

        Parameters
        ----------
        xs: list of int
            Indexes of the measured configs; they are turned into features
            through `_get_feature`.
        ys: list of float
            Score of every entry in `xs`. The scores are divided by their
            maximum before training.
        plan_size: int
            The N of the 'a-recall@N' metric that drives early stopping.
        """
        tic = time.time()
        self._reset_pool(self.space, self.target, self.task)

        x_train = self._get_feature(xs)
        y_train = np.array(ys)
        y_max = np.max(y_train)
        y_train = y_train / max(y_max, 1e-8)

        # only used for the error count reported in the debug log below
        valid_index = y_train > 1e-6
        index = np.random.permutation(len(x_train))
        dtrain = xgb.DMatrix(x_train[index], y_train[index])
        self._sample_size = len(x_train)

        if self.base_model:
            discount = self._base_model_discount()
            if discount < 0.05:  # discard base model
                self.base_model.upper_model = None
                self.base_model = None
            else:
                # The rows of dtrain were shuffled with `index`; the base margins
                # are predicted in the order of `xs`, so they must be shuffled
                # the same way to stay attached to the right rows.
                base_margin = self.base_model.predict(xs, output_margin=True)
                dtrain.set_base_margin(discount * base_margin[index])

        self.bst = xgb.train(self.xgb_params, dtrain,
                             num_boost_round=8000,
                             callbacks=[custom_callback(
                                 stopping_rounds=20,
                                 metric='tr-a-recall@%d' % plan_size,
                                 evals=[(dtrain, 'tr')],
                                 maximize=True,
                                 fevals=[
                                     xgb_average_recalln_curve_score(plan_size),
                                 ],
                                 verbose_eval=self.log_interval)])

        logger.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
                     time.time() - tic, len(xs),
                     len(xs) - np.sum(valid_index),
                     self.feature_cache.size(self.fea_type))
199 200 201 202

    def fit_log(self, records, plan_size):
        """Train the booster from history records instead of fresh measurements.

        Parameters
        ----------
        records: iterable of (input, result) pairs
            History entries; only those whose task name and template key match
            this model's task are used.
        plan_size: int
            Doubled and then used as the N of the 'a-recall@N' stopping metric.

        Returns
        -------
        bool
            False (without training) when fewer than 500 usable samples are
            left after filtering, True once training has finished.
        """
        tic = time.time()

        # filter data, only pick the data with a same task
        data = [(inp, res) for inp, res in records
                if inp.task.name == self.task.name
                and inp.config.template_key == self.task.config_space.template_key]

        logger.debug("XGB load %d entries from history log file", len(data))

        # extract feature
        self._reset_pool(self.space, self.target, self.task)
        pool = self._get_pool()
        log_extractors = {
            'itervar': _extract_itervar_feature_log,
            'knob': _extract_knob_feature_log,
            'curve': _extract_curve_feature_log,
        }
        if self.fea_type not in log_extractors:
            raise RuntimeError("Invalid feature type: " + self.fea_type)
        extracted = pool.map(log_extractors[self.fea_type], data)

        # filter out feature with different shapes
        fea_len = len(self._get_feature([0])[0])
        kept = [(x, y) for x, y in extracted if len(x) == fea_len]

        if len(kept) < 500:  # not enough samples
            return False

        xs = np.array([x for x, _ in kept])
        ys = np.array([y for _, y in kept])

        # normalise the targets by their maximum
        x_train = xs
        y_train = ys / max(np.max(ys), 1e-8)

        index = np.random.permutation(len(x_train))
        dtrain = xgb.DMatrix(x_train[index], y_train[index])

        recall_n = plan_size * 2
        self.bst = xgb.train(self.xgb_params, dtrain,
                             num_boost_round=400,
                             callbacks=[custom_callback(
                                 stopping_rounds=100,
                                 metric='tr-a-recall@%d' % recall_n,
                                 evals=[(dtrain, 'tr')],
                                 maximize=True,
                                 fevals=[
                                     xgb_average_recalln_curve_score(recall_n),
                                 ],
                                 verbose_eval=self.log_interval)])

        logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))

        return True

263 264 265 266 267
    def predict(self, xs, output_margin=False):
        """Predict scores for the configs at indexes `xs`.

        When a base model is attached, its discounted margin is used as the
        base margin of the prediction. `output_margin` is forwarded to the
        booster's `predict`.
        """
        dtest = xgb.DMatrix(self._get_feature(xs))

        if self.base_model:
            discount = self._base_model_discount()
            base_margin = self.base_model.predict(xs, output_margin=True)
            dtest.set_base_margin(discount * base_margin)

        return self.bst.predict(dtest, output_margin=output_margin)

    def load_basemodel(self, base_model):
        self.base_model = base_model
275 276 277 278
        self.base_model._close_pool()
        self.base_model.upper_model = self

    def spawn_base_model(self):
        """Create a new model with this model's settings and this model as its upper model."""
        return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
                                num_threads=self.num_threads,
                                log_interval=self.log_interval,
                                upper_model=self)
281 282 283 284 285 286 287 288 289 290 291 292 293

    def _get_feature(self, indexes):
        """get features for indexes, run extraction if we do not have cache for them"""
        # free feature cache
        if self.feature_cache.size(self.fea_type) >= 100000:
            self.feature_cache.clear(self.fea_type)

        fea_cache = self.feature_cache.get(self.fea_type)

        indexes = np.array(indexes)
        need_extract = [x for x in indexes if x not in fea_cache]

        if need_extract:
294
            pool = self._get_pool()
295
            feas = pool.map(self.feature_extract_func, need_extract)
296 297 298 299 300 301 302 303
            for i, fea in zip(need_extract, feas):
                fea_cache[i] = fea

        ret = np.empty((len(indexes), fea_cache[indexes[0]].shape[-1]), dtype=np.float32)
        for i, ii in enumerate(indexes):
            ret[i, :] = fea_cache[ii]
        return ret

304 305 306
    def __del__(self):
        # release the worker pool (if any) when the model object is destroyed
        self._close_pool()

307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332

# Common arguments read by the `_extract_*_feature_index` helpers below.
# They are assigned by `XGBoostCostModel._reset_pool` just before it creates
# the worker pool.
_extract_space = None
_extract_target = None
_extract_task = None

def _extract_itervar_feature_index(index):
    """Build the itervar feature vector for the config at `index` of the shared space.

    The config's other options are appended to the flattened itervar feature.
    """
    cfg = _extract_space.get(index)
    with _extract_target:
        sch, args = _extract_task.instantiate(cfg)
    itervar_fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
    other_options = list(cfg.get_other_option().values())
    return np.concatenate((itervar_fea, other_options))

def _extract_itervar_feature_log(arg):
    """Build an (itervar feature, score) pair from one (input, result) log item.

    The score is the task's flop divided by the mean cost, or 0.0 when the
    result carries a non-zero error number.
    """
    inp, res = arg
    cfg = inp.config
    with inp.target:
        sch, args = inp.task.instantiate(cfg)
    itervar_fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
    x = np.concatenate((itervar_fea, list(cfg.get_other_option().values())))

    y = inp.task.flop / np.mean(res.costs) if res.error_no == 0 else 0.0
    return x, y

def _extract_knob_feature_index(index):
    """Return the flattened knob feature of the config at `index` of the shared space."""
    return _extract_space.get(index).get_flatten_feature()

def _extract_knob_feature_log(arg):
    """Build a (knob feature, score) pair from one (input, result) log item.

    The score is the task's flop divided by the mean cost, or 0.0 when the
    result carries a non-zero error number.
    """
    inp, res = arg
    cfg = inp.config
    x = cfg.get_flatten_feature()

    if res.error_no != 0:
        return x, 0.0

    with inp.target:  # necessary, for calculating flops of this task
        inp.task.instantiate(cfg)
    return x, inp.task.flop / np.mean(res.costs)

def _extract_curve_feature_index(index):
    """Build the sampled curve feature vector for the config at `index` of the shared space.

    The config's other options are appended to the flattened curve samples.
    """
    cfg = _extract_space.get(index)
    with _extract_target:
        sch, args = _extract_task.instantiate(cfg)
    curve_fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
    other_options = list(cfg.get_other_option().values())
    combined = np.concatenate((curve_fea, other_options))
    return np.array(combined)

def _extract_curve_feature_log(arg):
    """Build a (curve feature, score) pair from one (input, result) log item.

    The score is the task's flop divided by the mean cost, or 0.0 when the
    result carries a non-zero error number.
    """
    inp, res = arg
    cfg = inp.config
    with inp.target:
        sch, args = inp.task.instantiate(cfg)
    curve_fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
    x = np.concatenate((curve_fea, list(cfg.get_other_option().values())))

    y = inp.task.flop / np.mean(res.costs) if res.error_no == 0 else 0.0
    return x, y


def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
                    maximize=False, verbose_eval=True):
    """Build an xgboost callback that evaluates custom metrics and stops early.

    Parameters
    ----------
    stopping_rounds: int
        Stop once `metric` has not improved for this many iterations.
    metric: str
        Full name of the metric that drives early stopping, written as
        '<eval name>-<metric name>' (e.g. 'tr-a-recall@64').
    fevals: list of callable
        Custom evaluation functions; every one of them is evaluated at each
        iteration.
    evals: sequence, optional
        Evaluation sets handed to the booster's `eval_set`.
    log_file: str, optional
        If given, the metrics of every iteration are appended to this file.
    maximize: bool, optional
        Whether a larger `metric` is better.
    verbose_eval: bool or int, optional
        If an int, metrics are logged every `verbose_eval` iterations. Any
        true value enables the message logged when training stops early.

    Returns
    -------
    callback: callable
        Function taking the xgboost callback environment. It raises
        `EarlyStopException` when the stopping condition is met.
    """
    from xgboost.core import EarlyStopException
    from xgboost.callback import _fmt_metric
    from xgboost.training import aggcv

    state = {}
    # Everything after the eval-set prefix, e.g. 'a-recall@64' for
    # 'tr-a-recall@64'. Splitting only once keeps metric names that contain
    # '-' intact, so the sort below puts exactly the chosen metric first.
    metric_shortname = metric.split("-", 1)[1]

    def init(env):
        """internal function"""
        bst = env.model

        state['maximize_score'] = maximize
        state['best_iteration'] = 0
        if maximize:
            state['best_score'] = float('-inf')
        else:
            state['best_score'] = float('inf')

        if bst is not None:
            if bst.attr('best_score') is not None:
                # resume from the values stored on the booster
                state['best_score'] = float(bst.attr('best_score'))
                state['best_iteration'] = int(bst.attr('best_iteration'))
                state['best_msg'] = bst.attr('best_msg')
            else:
                bst.set_attr(best_iteration=str(state['best_iteration']))
                bst.set_attr(best_score=str(state['best_score']))
        else:
            assert env.cvfolds is not None

    def callback(env):
        """internal function"""
        if not state:
            init(env)

        bst = env.model
        i = env.iteration
        cvfolds = env.cvfolds

        res_dict = {}

        ##### evaluation #####
        if cvfolds is not None:
            for feval in fevals:
                tmp = aggcv([f.eval(i, feval) for f in cvfolds])
                for k, mean, std in tmp:
                    res_dict[k] = [mean, std]
        else:
            for feval in fevals:
                bst_eval = bst.eval_set(evals, i, feval)
                # the first token is skipped; the remaining ones are 'name:value'
                res = [x.split(':') for x in bst_eval.split()]
                for kv in res[1:]:
                    res_dict[kv[0]] = [float(kv[1])]

        eval_res = []
        keys = list(res_dict.keys())
        # prefix the chosen metric with "a" so that it sorts ahead of the others
        keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
        for key in keys:
            v = res_dict[key]
            eval_res.append([key] + v)

        ##### print eval result #####
        infos = ["XGB iter: %3d" % i]
        for item in eval_res:
            if 'null' in item[0]:
                continue
            infos.append("%s: %.6f" % (item[0], item[1]))

        if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
            logger.debug("\t".join(infos))
        if log_file:
            with open(log_file, "a") as fout:
                fout.write("\t".join(infos) + '\n')

        ##### choose score and do early stopping #####
        score = None
        for item in eval_res:
            if item[0] == metric:
                score = item[1]
                break
        assert score is not None

        best_score = state['best_score']
        best_iteration = state['best_iteration']
        maximize_score = state['maximize_score']
        if (maximize_score and score > best_score) or \
                (not maximize_score and score < best_score):
            msg = '[%d] %s' % (
                env.iteration,
                '\t'.join([_fmt_metric(x) for x in eval_res]))
            state['best_msg'] = msg
            state['best_score'] = score
            state['best_iteration'] = env.iteration
            # save the property to attributes, so they will occur in checkpoint.
            if env.model is not None:
                env.model.set_attr(best_score=str(state['best_score']),
                                   best_iteration=str(state['best_iteration']),
                                   best_msg=state['best_msg'])
        elif env.iteration - best_iteration >= stopping_rounds:
            # 'best_msg' is missing when no iteration ever improved on the
            # initial score; stop cleanly instead of failing with KeyError
            best_msg = state.get('best_msg', '')
            if verbose_eval and env.rank == 0:
                logger.debug("XGB stopped. Best iteration: %s ", best_msg)
            raise EarlyStopException(best_iteration)

    return callback


# feval wrapper for xgboost
def xgb_max_curve_score(N):
    """Return an xgboost feval that reports the max-curve score at position N.

    The score is the max-curve value at N divided by the largest label.
    """
    def feval(preds, labels):
        truth = labels.get_label()
        # visit samples from the highest prediction to the lowest
        order = np.argsort(preds)[::-1]
        curve = max_curve(truth[order])
        return "Smax@%d" % N, curve[N] / np.max(truth)
    return feval

def xgb_recalln_curve_score(N):
    """Return an xgboost feval that reports the recall curve value at position N."""
    def feval(preds, labels):
        truth = labels.get_label()
        # visit samples from the highest prediction to the lowest
        order = np.argsort(preds)[::-1]
        curve = recall_curve(get_rank(truth[order]))
        return "recall@%d" % N, curve[N]
    return feval

def xgb_average_recalln_curve_score(N):
    """Return an xgboost feval that reports the recall curve averaged over its first N entries."""
    def feval(preds, labels):
        truth = labels.get_label()
        # visit samples from the highest prediction to the lowest
        order = np.argsort(preds)[::-1]
        curve = recall_curve(get_rank(truth[order]))
        return "a-recall@%d" % N, np.sum(curve[:N]) / N
    return feval

def xgb_recallk_curve_score(N, topk):
    """Return an xgboost feval that reports the top-k recall curve value at position N."""
    def feval(preds, labels):
        truth = labels.get_label()
        # visit samples from the highest prediction to the lowest
        order = np.argsort(preds)[::-1]
        curve = recall_curve(get_rank(truth[order]), topk)
        return "recall@%d" % topk, curve[N]
    return feval

def xgb_cover_curve_score(N):
    """Return an xgboost feval that reports the cover curve value at position N."""
    def feval(preds, labels):
        truth = labels.get_label()
        # visit samples from the highest prediction to the lowest
        order = np.argsort(preds)[::-1]
        curve = cover_curve(get_rank(truth[order]))
        return "cover@%d" % N, curve[N]
    return feval

def xgb_null_score(_):
    """Return an xgboost feval that ignores its inputs and always reports ("null", 0)."""
    return lambda __, ___: ("null", 0)