# Source code for hypernets.tabular.estimator_detector

# -*- coding:utf-8 -*-
"""

"""

import pandas as pd
from sklearn.datasets import make_classification, make_regression

from hypernets.core import randint
from hypernets.utils import logging, load_module, const

logger = logging.get_logger(__name__)


class EstimatorDetector:
    """Probe how far an estimator gets: resolving its class, constructing it, and fitting it.

    Calling an instance returns a set of the stages that succeeded:

    * ``'installed'``   -- the estimator class could be obtained (no ImportError);
    * ``'initialized'`` -- it could be instantiated with ``init_kwargs``;
    * ``'fitted'``      -- the instance could be fitted on a small synthetic dataset
      with ``fit_kwargs``.

    Probing stops at the first failing stage, so the returned set is always a prefix
    of that sequence (possibly empty). Errors raised while generating the synthetic
    data are not caught.

    Args:
        name_or_cls: The estimator class itself, or a string resolved with
            ``hypernets.utils.load_module``.
        task: Task type. ``const.TASK_BINARY`` and ``const.TASK_MULTICLASS`` produce
            classification data; any other value produces regression data.
        init_kwargs: Keyword arguments for the estimator constructor (default: none).
        fit_kwargs: Extra keyword arguments for ``estimator.fit`` (default: none).
        n_samples: Number of rows in the synthetic training data.
        n_features: Number of feature columns in the synthetic training data.
    """

    def __init__(self, name_or_cls, task, *,
                 init_kwargs=None, fit_kwargs=None, n_samples=100, n_features=5):
        # Raises AssertionError (when assertions are enabled) for anything that is
        # neither a string nor a class.
        assert isinstance(name_or_cls, (str, type))

        if init_kwargs is None:
            init_kwargs = {}
        if fit_kwargs is None:
            fit_kwargs = {}

        self.name_or_cls = name_or_cls
        self.task = task
        self.init_kwargs = init_kwargs
        self.fit_kwargs = fit_kwargs
        self.n_samples = n_samples
        self.n_features = n_features

    def prepare_data(self):
        """Generate a small random training set matching ``self.task``.

        Returns:
            ``(X, y)`` where ``X`` is a DataFrame with columns ``c0 .. c{k-1}``
            (``k == n_features``) and ``y`` is the target array from scikit-learn.
            ``random_state`` is taken from ``randint()`` on every call, so the data
            presumably differs between calls.
        """
        if self.task == const.TASK_BINARY:
            X, y = make_classification(n_samples=self.n_samples, n_features=self.n_features,
                                       n_classes=2, random_state=randint())
        elif self.task == const.TASK_MULTICLASS:
            # make_classification requires n_classes * n_clusters_per_class <= 2 ** n_informative.
            # With its defaults (n_informative=2, n_clusters_per_class=2) five classes need
            # 10 <= 4 and sklearn raises ValueError. Use one cluster per class and make every
            # feature informative (5 <= 2 ** n_features holds for n_features >= 3);
            # n_redundant=0 keeps informative + redundant + repeated <= n_features.
            X, y = make_classification(n_samples=self.n_samples, n_features=self.n_features,
                                       n_informative=self.n_features, n_redundant=0,
                                       n_clusters_per_class=1, n_classes=5,
                                       random_state=randint())
        else:
            X, y = make_regression(n_samples=self.n_samples, n_features=self.n_features,
                                   random_state=randint())

        X = pd.DataFrame(X, columns=[f'c{i}' for i in range(X.shape[1])])
        return X, y

    def get_estimator_cls(self):
        """Return the estimator class.

        A string ``name_or_cls`` is resolved with ``load_module`` (presumably a
        dotted import path -- confirm against ``hypernets.utils``). A class is
        returned as-is. Import failures propagate to the caller.
        """
        if isinstance(self.name_or_cls, str):
            return load_module(self.name_or_cls)
        return self.name_or_cls

    def create_estimator(self, estimator_cls):
        """Instantiate ``estimator_cls`` with ``self.init_kwargs``."""
        return estimator_cls(**self.init_kwargs)

    def fit_estimator(self, estimator, X, y):
        """Call ``estimator.fit(X, y, **self.fit_kwargs)`` and return its result."""
        return estimator.fit(X, y, **self.fit_kwargs)

    def __call__(self, *args, **kwargs):
        """Run the detection stages and return the set of stages that succeeded.

        Positional and keyword arguments are accepted but ignored.
        """
        result = set()

        # detect: installed
        try:
            estimator_cls = self.get_estimator_cls()
        except ImportError:
            return result
        result.add('installed')

        # detect: create estimator instance
        try:
            estimator = self.create_estimator(estimator_cls)
        except Exception as e:
            logger.info(e)
            return result
        result.add('initialized')

        # make training data (not guarded: a failure here is a detector bug, not an
        # estimator incompatibility)
        X, y = self.prepare_data()

        # detect: fit
        try:
            self.fit_estimator(estimator, X, y)
        except Exception as e:
            logger.info(f'EstimatorDetector error: {e}, ')
            if logger.is_debug_enabled():
                logger.debug(e)
        else:
            result.add('fitted')

        return result