Source code for hypernets.discriminators

# -*- coding:utf-8 -*-
__author__ = 'yangjian'
"""

"""

from ._base import get_previous_trials_scores, get_percentile_score, UnPromisingTrial, BaseDiscriminator
from .percentile import PercentileDiscriminator, ProgressivePercentileDiscriminator, OncePercentileDiscriminator

# Alias table consulted by _get_discriminator_cls: lower-case alias -> discriminator class.
# Several aliases (short name, longer name, name with a "_discriminator" suffix)
# resolve to the same class.
_discriminators = dict(
    percentile=PercentileDiscriminator,
    once_percentile=OncePercentileDiscriminator,
    percentile_discriminator=PercentileDiscriminator,
    progressive=ProgressivePercentileDiscriminator,
    progressive_percentile=ProgressivePercentileDiscriminator,
    progressive_percentile_discriminator=ProgressivePercentileDiscriminator,
)


def _get_discriminator_cls(identifier):
    if isinstance(identifier, str):
        cls = _discriminators.get(identifier.lower(), None)
        if cls is not None:
            return cls
    elif isinstance(identifier, type) and issubclass(identifier, BaseDiscriminator):
        return identifier

    raise ValueError(f'Illegal discriminator:{identifier}')


def make_discriminator(cls, optimize_direction='min', **kwargs):
    """Build a discriminator instance from a registered alias or a class.

    :param cls: an alias string or a class; it is resolved through
        ``_get_discriminator_cls`` (aliases are lower-cased before lookup,
        classes must derive from ``BaseDiscriminator``).
    :param optimize_direction: forwarded unchanged to the discriminator's
        constructor.
    :param kwargs: extra constructor arguments; they take precedence over the
        defaults listed below.
    :return: the constructed discriminator instance.
    :raises ValueError: if ``cls`` cannot be resolved (raised by
        ``_get_discriminator_cls``).

    Defaults used when the caller does not supply them:
    ``PercentileDiscriminator`` -> ``percentile=0``;
    ``ProgressivePercentileDiscriminator`` -> ``percentile_list=[0]``;
    any other class -> no defaults.
    """
    cls = _get_discriminator_cls(cls)

    # Classes are matched by identity, so only these exact classes get defaults.
    # NOTE(review): a subclass of either class (e.g. OncePercentileDiscriminator,
    # if it derives from PercentileDiscriminator) receives no defaults here --
    # confirm that callers of such aliases always pass the required args.
    if cls is PercentileDiscriminator:
        default_kwargs = dict(percentile=0)
    elif cls is ProgressivePercentileDiscriminator:
        # A new list is created on every call, so instances never share it.
        default_kwargs = dict(percentile_list=[0])
    else:
        default_kwargs = {}

    # Caller-supplied kwargs override the defaults.
    kwargs = {**default_kwargs, **kwargs}
    discriminator = cls(optimize_direction=optimize_direction, **kwargs)
    return discriminator