xgboost_cost_model.py 20.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
# pylint: disable=invalid-name
"""XGBoost as cost model"""

import multiprocessing
import logging
import time

import numpy as np
try:
    import xgboost as xgb
except ImportError:
    xgb = None

from .. import feature
from ..util import get_rank
from .metric import max_curve, recall_curve, cover_curve
from .model_based_tuner import CostModel, FeatureCache

35 36
# Every diagnostic emitted by this module goes to the common 'autotvm' logger.
logger = logging.getLogger(name='autotvm')

37 38 39 40 41 42 43 44 45 46 47 48 49
class XGBoostCostModel(CostModel):
    """XGBoost as cost model

    Parameters
    ----------
    task: Task
        The tuning task
    feature_type: str
        If is 'itervar', use features extracted from IterVar (loop variable).
        If is 'knob', use flatten ConfigEntity directly.
        If is 'curve', use sampled curve feature (relation feature).

        Note on choosing feature type:
        For single task tuning, 'itervar' and 'knob' are good.
        'itervar' is more accurate but 'knob' is much faster.
        There are some constraints on 'itervar', if you meet
        problems with feature extraction when using 'itervar',
        you can switch to 'knob'.

        For cross-shape tuning (e.g. many convolutions with different shapes),
        'itervar' and 'curve' have better transferability,
        'knob' is faster.
        For cross-device or cross-operator tuning, you can use 'curve' only.
    loss_type: str
        If is 'reg', use regression loss to train cost model.
        The cost model predicts the normalized flops.
        If is 'rank', use pairwise rank loss to train cost model.
        The cost model predicts relative rank score.
    num_threads: int, optional
        The number of threads. Passed to xgboost as ``nthread`` (when truthy)
        and used as the worker count of the feature-extraction process pool
        (None lets ``multiprocessing.Pool`` pick its default).
    log_interval: int, optional
        If is not none, the cost model will print training log every `log_interval` iterations.
    upper_model: XGBoostCostModel, optional
        The upper model used in transfer learning. When given, this model
        shares the upper model's feature cache and process pool.
    """

    def __init__(self, task, feature_type, loss_type, num_threads=None, log_interval=25,
                 upper_model=None):
        super(XGBoostCostModel, self).__init__()

        if xgb is None:
            raise RuntimeError("XGBoost is required for XGBoostCostModel. "
                               "Please install its python package first. "
                               "Help: (https://xgboost.readthedocs.io/en/latest/) ")

        self.task = task
        self.target = task.target
        self.space = task.config_space

        self.fea_type = feature_type
        self.loss_type = loss_type
        self.num_threads = num_threads
        self.log_interval = log_interval

        # Both loss types share every hyper-parameter except the objective.
        if loss_type == 'reg':
            objective = 'reg:linear'
        elif loss_type == 'rank':
            objective = 'rank:pairwise'
        else:
            raise RuntimeError("Invalid loss type: " + loss_type)

        self.xgb_params = {
            'max_depth': 3,
            'gamma': 0.0001,
            'min_child_weight': 1,

            'subsample': 1.0,

            'eta': 0.3,
            'lambda': 1.00,
            'alpha': 0,

            'objective': objective,
        }
        self.xgb_params['silent'] = 1
        if num_threads:
            self.xgb_params['nthread'] = num_threads
        self.bst = None

        if feature_type == 'itervar':
            self.feature_extract_func = _extract_itervar_feature_index
        elif feature_type == 'knob':
            self.feature_extract_func = _extract_knob_feature_index
        elif feature_type == 'curve':
            self.feature_extract_func = _extract_curve_feature_index
        else:
            raise RuntimeError("Invalid feature type " + feature_type)

        if upper_model:  # share a same feature cache with upper model
            self.feature_cache = upper_model.feature_cache
        else:
            self.feature_cache = FeatureCache()
        self.upper_model = upper_model
        self.feature_extra_ct = 0
        self.pool = None
        self.base_model = None

        # number of samples used in the last `fit`; drives the base-model discount
        self._sample_size = 0
        self._reset_pool(self.space, self.target, self.task)

    def _reset_pool(self, space, target, task):
        """reset processing pool for feature extraction

        A model with an upper model does not own a pool; it resets the upper
        model's pool instead.
        """
        if self.upper_model:  # base model will reuse upper model's pool,
            self.upper_model._reset_pool(space, target, task)
            return

        self._close_pool()

        # use global variable to pass common arguments
        # (the *_feature_index extractors read these module globals)
        global _extract_space, _extract_target, _extract_task
        _extract_space = space
        _extract_target = target
        _extract_task = task

        # created after the globals are set; presumably relies on fork-style
        # worker start so the workers see them -- TODO confirm on spawn platforms
        self.pool = multiprocessing.Pool(self.num_threads)

    def _close_pool(self):
        """Terminate and join this model's own pool, if it has one."""
        if self.pool:
            self.pool.terminate()
            self.pool.join()
            self.pool = None

    def _get_pool(self):
        """Return the pool to use: the upper model's pool when there is one."""
        if self.upper_model:
            return self.upper_model._get_pool()
        return self.pool

    def _base_model_discount(self):
        """Weight applied to the base model's margin; halves every 64 samples."""
        return 1.0 / (2 ** (self._sample_size / 64.0))

    def fit(self, xs, ys, plan_size):
        """Train the booster on measured configurations.

        Parameters
        ----------
        xs: list of int
            Indexes into the task's config space (features are looked up or
            extracted through `_get_feature`).
        ys: list of float
            Target value for each index; normalized by the maximum before
            training. Values <= 1e-6 after normalization are counted as errors
            in the debug log.
        plan_size: int
            N used by the early-stopping metric ``tr-a-recall@N``.
        """
        tic = time.time()
        self._reset_pool(self.space, self.target, self.task)

        x_train = self._get_feature(xs)
        y_train = np.array(ys)
        y_max = np.max(y_train)
        y_train = y_train / max(y_max, 1e-8)

        valid_index = y_train > 1e-6
        index = np.random.permutation(len(x_train))
        dtrain = xgb.DMatrix(x_train[index], y_train[index])
        self._sample_size = len(x_train)

        if self.base_model:
            discount = self._base_model_discount()
            if discount < 0.05:  # discard base model
                self.base_model.upper_model = None
                self.base_model = None
            else:
                # The base model predicts in the order of `xs`, but the rows of
                # `dtrain` were shuffled by `index`: apply the same permutation
                # so every margin lines up with its own sample.
                base_margin = self.base_model.predict(xs, output_margin=True)
                dtrain.set_base_margin(discount * np.asarray(base_margin)[index])

        self.bst = xgb.train(self.xgb_params, dtrain,
                             num_boost_round=8000,
                             callbacks=[custom_callback(
                                 stopping_rounds=20,
                                 metric='tr-a-recall@%d' % plan_size,
                                 evals=[(dtrain, 'tr')],
                                 maximize=True,
                                 fevals=[
                                     xgb_average_recalln_curve_score(plan_size),
                                 ],
                                 verbose_eval=self.log_interval)])

        logger.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
                     time.time() - tic, len(xs),
                     len(xs) - np.sum(valid_index),
                     self.feature_cache.size(self.fea_type))

    def fit_log(self, records, plan_size):
        """Train the booster from historical tuning records.

        Parameters
        ----------
        records: iterable of (input, result) pairs
            Only records whose task name and template key match this model's
            task are used.
        plan_size: int
            Base N for the early-stopping metric (doubled internally).

        Returns
        -------
        trained: bool
            False if fewer than 500 usable samples were found (no training
            happens), True otherwise.
        """
        tic = time.time()

        # filter data, only pick the data with a same task
        data = []
        for inp, res in records:
            if inp.task.name == self.task.name and \
                    inp.config.template_key == self.task.config_space.template_key:
                data.append((inp, res))

        logger.debug("XGB load %d entries from history log file", len(data))

        # extract feature
        self._reset_pool(self.space, self.target, self.task)
        pool = self._get_pool()
        if self.fea_type == 'itervar':
            feature_extract_func = _extract_itervar_feature_log
        elif self.fea_type == 'knob':
            feature_extract_func = _extract_knob_feature_log
        elif self.fea_type == 'curve':
            feature_extract_func = _extract_curve_feature_log
        else:
            raise RuntimeError("Invalid feature type: " + self.fea_type)
        res = pool.map(feature_extract_func, data)

        # The *_log extractors return None for records they cannot process:
        # skip those (unpacking None would raise). Also filter out features
        # whose length differs from this search space's feature length.
        fea_len = len(self._get_feature([0])[0])

        xs, ys = [], []
        for item in res:
            if item is None:
                continue
            x, y = item
            if len(x) == fea_len:
                xs.append(x)
                ys.append(y)

        if len(xs) < 500:  # no enough samples
            return False

        x_train, y_train = np.array(xs), np.array(ys)
        y_max = np.max(y_train)
        y_train = y_train / max(y_max, 1e-8)

        index = np.random.permutation(len(x_train))
        dtrain = xgb.DMatrix(x_train[index], y_train[index])

        plan_size *= 2
        self.bst = xgb.train(self.xgb_params, dtrain,
                             num_boost_round=400,
                             callbacks=[custom_callback(
                                 stopping_rounds=100,
                                 metric='tr-a-recall@%d' % plan_size,
                                 evals=[(dtrain, 'tr')],
                                 maximize=True,
                                 fevals=[
                                     xgb_average_recalln_curve_score(plan_size),
                                 ],
                                 verbose_eval=self.log_interval)])

        logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))

        return True

    def predict(self, xs, output_margin=False):
        """Predict scores for config indexes `xs`.

        When a base model is attached, its margin (scaled by the current
        discount) is used as the base margin of the prediction.
        """
        feas = self._get_feature(xs)
        dtest = xgb.DMatrix(feas)

        if self.base_model:
            dtest.set_base_margin(self._base_model_discount() *
                                  self.base_model.predict(xs, output_margin=True))

        return self.bst.predict(dtest, output_margin=output_margin)

    def load_basemodel(self, base_model):
        """Attach `base_model`; it closes its own pool and delegates to ours."""
        self.base_model = base_model
        self.base_model._close_pool()
        self.base_model.upper_model = self

    def spawn_base_model(self):
        """Create a model with the same settings that uses this one as upper model."""
        return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
                                self.num_threads, self.log_interval, self)

    def _get_feature(self, indexes):
        """get features for indexes, run extraction if we do not have cache for them

        Returns a float32 array with one row per index; rows whose extraction
        failed (cached as None) are filled with zeros.
        """
        # free feature cache
        if self.feature_cache.size(self.fea_type) >= 100000:
            self.feature_cache.clear(self.fea_type)

        fea_cache = self.feature_cache.get(self.fea_type)

        indexes = np.array(indexes)
        need_extract = [x for x in indexes if x not in fea_cache]

        if need_extract:
            pool = self._get_pool()
            feas = pool.map(self.feature_extract_func, need_extract)
            for i, fea in zip(need_extract, feas):
                fea_cache[i] = fea

        # take the feature length from the first successful extraction
        feature_len = None
        for idx in indexes:
            if fea_cache[idx] is not None:
                feature_len = fea_cache[idx].shape[-1]
                break

        ret = np.empty((len(indexes), feature_len), dtype=np.float32)
        for i, ii in enumerate(indexes):
            t = fea_cache[ii]
            ret[i, :] = t if t is not None else 0
        return ret

    def __del__(self):
        # __init__ can raise (missing xgboost, invalid loss/feature type)
        # before `self.pool` is assigned; avoid an AttributeError at GC time.
        if hasattr(self, 'pool'):
            self._close_pool()

330 331 332 333 334 335 336

# Common arguments for the *_feature_index extractors, published as module
# globals by XGBoostCostModel._reset_pool right before it creates the worker
# pool (presumably so forked workers inherit them -- TODO confirm).
_extract_space = _extract_target = _extract_task = None

def _extract_itervar_feature_index(index):
    """extract iteration var feature for an index in extract_space

    Parameters
    ----------
    index: int
        Index of the config in the module-level ``_extract_space``.

    Returns
    -------
    fea: numpy.ndarray or None
        The flattened IterVar feature (extracted with ``take_log=True``)
        concatenated with the values of ``config.get_other_option()``,
        or None if any step raised.
    """
    try:
        config = _extract_space.get(index)
        with _extract_target:
            sch, args = _extract_task.instantiate(config)
        fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
        fea = np.concatenate((fea, list(config.get_other_option().values())))
        return fea
    except Exception:  # pylint: disable=broad-except
        # best effort: a failed extraction is reported as None
        return None
346 347 348

def _extract_itervar_feature_log(arg):
    """extract iteration var feature for log items

    Parameters
    ----------
    arg: tuple
        An ``(inp, res)`` record pair.

    Returns
    -------
    item: (numpy.ndarray, float) or None
        ``(x, y)``: the IterVar feature plus ``config.get_other_option()``
        values, and ``task.flop / mean(costs)`` for a successful measurement
        (``error_no == 0``) or 0.0 otherwise. None if any step raised.
    """
    try:
        inp, res = arg
        config = inp.config
        with inp.target:
            sch, args = inp.task.instantiate(config)
        fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
        x = np.concatenate((fea, list(config.get_other_option().values())))

        if res.error_no == 0:
            y = inp.task.flop / np.mean(res.costs)
        else:
            y = 0.0
        return x, y
    except Exception:  # pylint: disable=broad-except
        # best effort: callers must be prepared to receive None
        return None
364 365 366

def _extract_knob_feature_index(index):
    """extract knob feature for an index in extract_space

    Returns ``config.get_flatten_feature()`` for the config at `index` of the
    module-level ``_extract_space``, or None if that raised.
    """
    try:
        config = _extract_space.get(index)
        return config.get_flatten_feature()
    except Exception:  # pylint: disable=broad-except
        # best effort: a failed extraction is reported as None
        return None
372 373 374

def _extract_knob_feature_log(arg):
    """extract knob feature for log items

    Parameters
    ----------
    arg: tuple
        An ``(inp, res)`` record pair.

    Returns
    -------
    item: tuple or None
        ``(x, y)``: ``config.get_flatten_feature()`` and
        ``task.flop / mean(costs)`` for a successful measurement
        (``error_no == 0``) or 0.0 otherwise. None if any step raised.
    """
    try:
        inp, res = arg
        config = inp.config
        x = config.get_flatten_feature()

        if res.error_no == 0:
            with inp.target:  # necessary, for calculating flops of this task
                inp.task.instantiate(config)
            y = inp.task.flop / np.mean(res.costs)
        else:
            y = 0.0
        return x, y
    except Exception:  # pylint: disable=broad-except
        # best effort: callers must be prepared to receive None
        return None
389 390 391

def _extract_curve_feature_index(index):
    """extract sampled curve feature for an index in extract_space

    Parameters
    ----------
    index: int
        Index of the config in the module-level ``_extract_space``.

    Returns
    -------
    fea: numpy.ndarray or None
        The flattened buffer-curve sample feature (``sample_n=20``)
        concatenated with the values of ``config.get_other_option()``,
        or None if any step raised.
    """
    try:
        config = _extract_space.get(index)
        with _extract_target:
            sch, args = _extract_task.instantiate(config)
        fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
        fea = np.concatenate((fea, list(config.get_other_option().values())))
        return np.array(fea)
    except Exception:  # pylint: disable=broad-except
        # best effort: a failed extraction is reported as None
        return None
401 402 403

def _extract_curve_feature_log(arg):
    """extract sampled curve feature for log items

    Parameters
    ----------
    arg: tuple
        An ``(inp, res)`` record pair.

    Returns
    -------
    item: (numpy.ndarray, float) or None
        ``(x, y)``: the buffer-curve sample feature (``sample_n=20``) plus
        ``config.get_other_option()`` values, and ``task.flop / mean(costs)``
        for a successful measurement (``error_no == 0``) or 0.0 otherwise.
        None if any step raised.
    """
    try:
        inp, res = arg
        config = inp.config
        with inp.target:
            sch, args = inp.task.instantiate(config)
        fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
        x = np.concatenate((fea, list(config.get_other_option().values())))

        if res.error_no == 0:
            y = inp.task.flop / np.mean(res.costs)
        else:
            y = 0.0
        return x, y
    except Exception:  # pylint: disable=broad-except
        # best effort: callers must be prepared to receive None
        return None
419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489

def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
                    maximize=False, verbose_eval=True):
    """callback function for xgboost to support multiple custom evaluation functions

    After every boosting round the returned callback evaluates each function
    in `fevals`, optionally logs the results, and raises xgboost's
    ``EarlyStopException`` once `metric` has not improved for
    `stopping_rounds` rounds.

    Parameters
    ----------
    stopping_rounds: int
        Rounds without improvement of `metric` before training stops.
    metric: str
        Full name of the early-stopping metric, "<eval-name>-<feval-name>",
        e.g. "tr-a-recall@64". Must contain at least one '-'.
    fevals: list of callable
        Custom evaluation functions passed to ``Booster.eval_set`` (or to
        ``eval`` of each cv fold).
    evals: list of (DMatrix, str)
        Evaluation sets.
    log_file: str, optional
        If given, each round's evaluation line is appended to this file.
    maximize: bool
        Whether a larger `metric` is better.
    verbose_eval: bool or int
        An int N logs the evaluation line at DEBUG level every N rounds; a
        bool never triggers per-round logging. A truthy value also logs the
        early-stopping message.

    Returns
    -------
    callback: callable
        Function taking xgboost's callback environment (uses its ``model``,
        ``iteration``, ``cvfolds`` and ``rank`` attributes).
    """
    from xgboost.core import EarlyStopException
    from xgboost.callback import _fmt_metric
    from xgboost.training import aggcv

    state = {}
    # Split only on the first '-': "tr-a-recall@64" -> "a-recall@64".
    # Splitting on every '-' would yield just "a", which matches almost any
    # key when ordering results below.
    metric_shortname = metric.split("-", 1)[1]

    def init(env):
        """internal function: initialise early-stopping state"""
        bst = env.model

        state['maximize_score'] = maximize
        state['best_iteration'] = 0
        if maximize:
            state['best_score'] = float('-inf')
        else:
            state['best_score'] = float('inf')

        if bst is not None:
            if bst.attr('best_score') is not None:
                # the booster already recorded a best score: resume from it
                state['best_score'] = float(bst.attr('best_score'))
                state['best_iteration'] = int(bst.attr('best_iteration'))
                state['best_msg'] = bst.attr('best_msg')
            else:
                bst.set_attr(best_iteration=str(state['best_iteration']))
                bst.set_attr(best_score=str(state['best_score']))
        else:
            assert env.cvfolds is not None

    def callback(env):
        """internal function: evaluate, log, and apply early stopping"""
        if not state:
            init(env)

        bst = env.model
        i = env.iteration
        cvfolds = env.cvfolds

        res_dict = {}

        ##### evaluation #####
        if cvfolds is not None:
            for feval in fevals:
                tmp = aggcv([f.eval(i, feval) for f in cvfolds])
                for k, mean, std in tmp:
                    res_dict[k] = [mean, std]
        else:
            for feval in fevals:
                bst_eval = bst.eval_set(evals, i, feval)
                # tokens after the first are "name:value" pairs
                res = [x.split(':') for x in bst_eval.split()]
                for kv in res[1:]:
                    res_dict[kv[0]] = [float(kv[1])]

        # keys containing the target metric get an "a" prefix so they sort
        # ahead of most other keys
        eval_res = []
        keys = list(res_dict.keys())
        keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
        for key in keys:
            v = res_dict[key]
            eval_res.append([key] + v)

        ##### print eval result #####
        infos = ["XGB iter: %3d" % i]
        for item in eval_res:
            if 'null' in item[0]:
                continue
            infos.append("%s: %.6f" % (item[0], item[1]))

        if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
            logger.debug("\t".join(infos))
        if log_file:
            with open(log_file, "a") as fout:
                fout.write("\t".join(infos) + '\n')

        ##### choose score and do early stopping #####
        score = None
        for item in eval_res:
            if item[0] == metric:
                score = item[1]
                break
        assert score is not None

        best_score = state['best_score']
        best_iteration = state['best_iteration']
        maximize_score = state['maximize_score']
        if (maximize_score and score > best_score) or \
                (not maximize_score and score < best_score):
            msg = '[%d] %s' % (
                env.iteration,
                '\t'.join([_fmt_metric(x) for x in eval_res]))
            state['best_msg'] = msg
            state['best_score'] = score
            state['best_iteration'] = env.iteration
            # save the property to attributes, so they will occur in checkpoint.
            if env.model is not None:
                env.model.set_attr(best_score=str(state['best_score']),
                                   best_iteration=str(state['best_iteration']),
                                   best_msg=state['best_msg'])
        elif env.iteration - best_iteration >= stopping_rounds:
            # 'best_msg' is missing if the score never improved (e.g. it is NaN)
            best_msg = state.get('best_msg')
            if verbose_eval and env.rank == 0:
                logger.debug("XGB stopped. Best iteration: %s ", best_msg)
            raise EarlyStopException(best_iteration)

    return callback


# feval wrapper for xgboost
def xgb_max_curve_score(N):
    """Build an xgboost feval reporting the normalized max-curve score at N.

    Labels are reordered by descending prediction, ``max_curve`` is applied
    to them, and its entry at position N is divided by the largest label.
    The metric is named "Smax@N".
    """
    def feval(preds, labels):
        truth = labels.get_label()
        visit_order = np.argsort(preds)[::-1]  # highest prediction first
        curve = max_curve(truth[visit_order])
        normalized = curve[N] / np.max(truth)
        return "Smax@%d" % N, normalized
    return feval

def xgb_recalln_curve_score(N):
    """Build an xgboost feval returning ``recall_curve`` at position N.

    Labels are reordered by descending prediction and converted to ranks
    with ``get_rank``; the metric is named "recall@N".
    """
    def feval(preds, labels):
        truth = labels.get_label()
        ranks = get_rank(truth[np.argsort(preds)[::-1]])
        return "recall@%d" % N, recall_curve(ranks)[N]
    return feval

def xgb_average_recalln_curve_score(N):
    """Build an xgboost feval averaging the first N ``recall_curve`` entries.

    Labels are reordered by descending prediction and ranked with
    ``get_rank``; the sum of ``recall_curve(ranks)[:N]`` is divided by N.
    The metric is named "a-recall@N".
    """
    def feval(preds, labels):
        truth = labels.get_label()
        ranks = get_rank(truth[np.argsort(preds)[::-1]])
        head = recall_curve(ranks)[:N]
        return "a-recall@%d" % N, np.sum(head) / N
    return feval

def xgb_recallk_curve_score(N, topk):
    """Build an xgboost feval returning ``recall_curve(ranks, topk)`` at N.

    Labels are reordered by descending prediction and ranked with
    ``get_rank``; the metric is named "recall@<topk>".
    """
    def feval(preds, labels):
        truth = labels.get_label()
        ranked = get_rank(truth[np.argsort(preds)[::-1]])
        curve = recall_curve(ranked, topk)
        return "recall@%d" % topk, curve[N]
    return feval

def xgb_cover_curve_score(N):
    """Build an xgboost feval returning ``cover_curve`` at position N.

    Labels are reordered by descending prediction and ranked with
    ``get_rank``; the metric is named "cover@N".
    """
    def feval(preds, labels):
        truth = labels.get_label()
        ranked = get_rank(truth[np.argsort(preds)[::-1]])
        return "cover@%d" % N, cover_curve(ranked)[N]
    return feval

def xgb_null_score(_):
    """Build a placeholder xgboost feval that always reports ("null", 0).

    The argument is ignored; it only keeps the same call shape as the other
    feval factories.
    """
    def feval(_preds, _labels):
        placeholder = ("null", 0)
        return placeholder
    return feval