Source code for hypernets.dispatchers.in_process_dispatcher

# -*- coding:utf-8 -*-
import gc

from .cfg import DispatchCfg as c
from ..core.callbacks import EarlyStoppingError
from ..core.dispatcher import Dispatcher
from ..core.trial import Trial
from ..tabular import get_tool_box
from ..utils import logging, fs, const

# Module-level logger, obtained from the project's logging helper imported from ..utils.
logger = logging.get_logger(__name__)


class InProcessDispatcher(Dispatcher):
    """Dispatcher that runs the trials one after another in the current process.

    Model files of the trials are written below ``models_dir``; the directory
    is created on construction if it does not exist.
    """

    def __init__(self, models_dir):
        super().__init__()

        self.models_dir = models_dir
        fs.makedirs(models_dir, exist_ok=True)

    @staticmethod
    def _fire_callbacks(callbacks, event, *args):
        """Invoke ``callback.<event>(*args)`` on each callback in order.

        ``EarlyStoppingError`` is re-raised so the search loop can stop; any
        other exception raised by a callback is logged as a warning and the
        remaining callbacks are still invoked.
        """
        for callback in callbacks:
            try:
                getattr(callback, event)(*args)
            except EarlyStoppingError:
                raise
            except Exception as e:
                logger.warn(e)

    def dispatch(self, hyper_model, X, y, X_eval, y_eval, X_test, cv, num_folds, max_trials, dataset_id, trial_store,
                 **fit_kwargs):
        """Sample search spaces and run trials until ``max_trials`` is reached.

        Each loop iteration draws one sample from ``hyper_model.searcher`` and
        then takes one of three paths:

        * the sample already exists in ``hyper_model.history``: the trial is
          skipped (``on_skip_trial`` with reason ``'trial_existed'``) and the
          consecutive-retry counter is increased; once the counter reaches
          ``DispatchCfg.trial_retry_limit`` the loop stops;
        * ``trial_store`` holds a result for the sample whose model file still
          exists: that result is copied into a new ``Trial``, appended to the
          history and reported to the searcher (``on_skip_trial`` with reason
          ``'hit_trial_store'``);
        * otherwise the trial is executed through ``hyper_model._run_trial``
          and stored into ``trial_store`` when one is given.

        An ``EarlyStoppingError`` (raised by a callback or by the trial itself)
        ends the loop; any other exception is logged with its traceback and
        the loop moves on to the next trial number.

        :param hyper_model: object providing ``searcher``, ``history``,
            ``callbacks``, ``task`` and ``_run_trial``.
        :param X, y, X_eval, y_eval, X_test, cv, num_folds: forwarded unchanged
            to ``hyper_model._run_trial``; ``X`` and ``y`` are also used to fit
            a general estimator when feature importances are required.
        :param max_trials: upper bound (inclusive) for the trial number.
        :param dataset_id: key used together with the space sample to look up
            and store trials in ``trial_store``.
        :param trial_store: optional trial cache; ``None`` disables caching.
        :param fit_kwargs: forwarded to ``hyper_model._run_trial``.
        :return: the trial counter after the loop, i.e. one more than the
            number of loop iterations that were started.
        """
        retry_limit = c.trial_retry_limit

        trial_no = 1
        retry_counter = 0

        # A multi-objective searcher with a 'feature_usage' objective receives
        # (feature name, importance) pairs from a general estimator fitted on X, y.
        space_options = {}
        if hyper_model.searcher.kind() == const.SEARCHER_MOO \
                and any(objective.name == 'feature_usage' for objective in hyper_model.searcher.objectives):
            tb = get_tool_box(X, y)
            preprocessor = tb.general_preprocessor(X)
            estimator = tb.general_estimator(X, y, task=hyper_model.task)
            estimator.fit(preprocessor.fit_transform(X, y), y)
            importances = list(zip(estimator.feature_name_, estimator.feature_importances_))
            space_options['importances'] = importances

        callbacks_of = self._fire_callbacks

        while trial_no <= max_trials:
            gc.collect()
            # Set only on the duplicate-sample path, so that `finally` does not
            # wipe the retry counter that path has just increased.
            is_retry = False
            try:
                space_sample = hyper_model.searcher.sample(space_options=space_options)

                if hyper_model.history.is_existed(space_sample):
                    if retry_counter >= retry_limit:
                        logger.info(f'Unable to take valid sample and exceed the retry limit {retry_limit}.')
                        break
                    trial = hyper_model.history.get_trial(space_sample)
                    callbacks_of(hyper_model.callbacks, 'on_skip_trial',
                                 hyper_model, space_sample, trial_no, 'trial_existed',
                                 trial.reward, False, trial.elapsed)
                    retry_counter += 1
                    is_retry = True
                    continue

                if trial_store is not None:
                    trial_hit = trial_store.get(dataset_id, space_sample)
                    # A cached trial is only reused while its model file is still present.
                    if trial_hit is not None and fs.exists(trial_hit.model_file):
                        reward = trial_hit.reward
                        elapsed = trial_hit.elapsed
                        trial = Trial(space_sample, trial_no, reward=reward, elapsed=elapsed,
                                      model_file=trial_hit.model_file, succeeded=trial_hit.succeeded)
                        trial.memo = trial_hit.memo.copy()
                        trial.iteration_scores = trial_hit.iteration_scores.copy()

                        improved = hyper_model.history.append(trial)
                        hyper_model.searcher.update_result(space_sample, reward)
                        callbacks_of(hyper_model.callbacks, 'on_skip_trial',
                                     hyper_model, space_sample, trial_no, 'hit_trial_store',
                                     reward, improved, elapsed)
                        continue

                callbacks_of(hyper_model.callbacks, 'on_trial_begin',
                             hyper_model, space_sample, trial_no)

                model_file = '%s/%05d_%s.pkl' % (self.models_dir, trial_no, space_sample.space_id)

                trial = hyper_model._run_trial(space_sample, trial_no, X, y, X_eval, y_eval, X_test, cv, num_folds,
                                               model_file, **fit_kwargs)

                if trial.succeeded:
                    improved = hyper_model.history.append(trial)
                    callbacks_of(hyper_model.callbacks, 'on_trial_end',
                                 hyper_model, space_sample, trial_no, trial.reward, improved, trial.elapsed)
                else:
                    # Failed trials are recorded in the history as well.
                    hyper_model.history.append(trial)
                    callbacks_of(hyper_model.callbacks, 'on_trial_error',
                                 hyper_model, space_sample, trial_no)

                if logger.is_info_enabled():
                    msg = f'Trial {trial_no} done, reward: {trial.reward}, ' \
                          f'best_trial_no:{hyper_model.best_trial_no}, best_reward:{hyper_model.best_reward}\n'
                    logger.info(msg)

                if trial_store is not None:
                    trial_store.put(dataset_id, trial)
            except EarlyStoppingError:
                break
            except Exception as e:
                import traceback
                msg = f'{">" * 20} Trial {trial_no} failed! {"<" * 20}\n' \
                      + f'{e.__class__.__name__}: {e}\n' \
                      + traceback.format_exc() \
                      + '*' * 50
                logger.error(msg)
            finally:
                # The trial number advances on every iteration (also on skip,
                # failure and break), which keeps the loop bounded by max_trials.
                trial_no += 1
                # The retry counter tracks *consecutive* duplicate samples, so
                # it is reset by anything except a duplicate-sample retry.
                if not is_retry:
                    retry_counter = 0

        return trial_no