Commit 8ad36a17 by Lianmin Zheng Committed by Wuwei Lin

[AutoTVM] Fix hang/crash issues on feature extraction (#3689)

* [AutoTVM] Fix hang/crash issues on feature extraction

* Update xgboost_cost_model.py

* fix lint
parent f9ba0db3
......@@ -312,9 +312,16 @@ class XGBoostCostModel(CostModel):
for i, fea in zip(need_extract, feas):
fea_cache[i] = fea
ret = np.empty((len(indexes), fea_cache[indexes[0]].shape[-1]), dtype=np.float32)
feature_len = None
for idx in indexes:
if fea_cache[idx] is not None:
feature_len = fea_cache[idx].shape[-1]
break
ret = np.empty((len(indexes), feature_len), dtype=np.float32)
for i, ii in enumerate(indexes):
ret[i, :] = fea_cache[ii]
t = fea_cache[ii]
ret[i, :] = t if t is not None else 0
return ret
def __del__(self):
......@@ -327,15 +334,19 @@ _extract_task = None
def _extract_itervar_feature_index(index):
"""extract iteration var feature for an index in extract_space"""
try:
config = _extract_space.get(index)
with _extract_target:
sch, args = _extract_task.instantiate(config)
fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
fea = np.concatenate((fea, list(config.get_other_option().values())))
return fea
except Exception: # pylint: disable=broad-except
return None
def _extract_itervar_feature_log(arg):
"""extract iteration var feature for log items"""
try:
inp, res = arg
config = inp.config
with inp.target:
......@@ -348,14 +359,20 @@ def _extract_itervar_feature_log(arg):
else:
y = 0.0
return x, y
except Exception: # pylint: disable=broad-except
return None
def _extract_knob_feature_index(index):
"""extract knob feature for an index in extract_space"""
try:
config = _extract_space.get(index)
return config.get_flatten_feature()
except Exception: # pylint: disable=broad-except
return None
def _extract_knob_feature_log(arg):
"""extract knob feature for log items"""
try:
inp, res = arg
config = inp.config
x = config.get_flatten_feature()
......@@ -367,18 +384,24 @@ def _extract_knob_feature_log(arg):
else:
y = 0.0
return x, y
except Exception: # pylint: disable=broad-except
return None
def _extract_curve_feature_index(index):
"""extract sampled curve feature for an index in extract_space"""
try:
config = _extract_space.get(index)
with _extract_target:
sch, args = _extract_task.instantiate(config)
fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
fea = np.concatenate((fea, list(config.get_other_option().values())))
return np.array(fea)
except Exception: # pylint: disable=broad-except
return None
def _extract_curve_feature_log(arg):
"""extract sampled curve feature for log items"""
try:
inp, res = arg
config = inp.config
with inp.target:
......@@ -391,7 +414,8 @@ def _extract_curve_feature_log(arg):
else:
y = 0.0
return x, y
except Exception: # pylint: disable=broad-except
return None
def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
maximize=False, verbose_eval=True):
......
......@@ -131,7 +131,9 @@ void TouchExtractor::ExitItervar_() {
}
itervar_stack_.pop_back();
topdown_product_ /= itervar_map[var].length;
int64_t length = itervar_map[var].length;
if (length != 0)
topdown_product_ /= length;
int64_t bottomup_product = -1;
for (auto kv : itervar_map[var].touch_feature) {
bottomup_product = std::max(bottomup_product, kv.second.count * kv.second.reuse);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment