# Source code for apax.config.optuna_config

from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Type

from pydantic import BaseModel, PositiveInt

if TYPE_CHECKING:
    import optuna


def get_pruner(name: str) -> Type["optuna.pruners.BasePruner"]:
    """Look up an optuna pruner class by its name.

    optuna is imported lazily so that it is only required when
    hyperparameter optimisation is actually used.

    Args:
        name (str): name of the pruner, as exported by ``optuna.pruners``

    Returns:
        Type[optuna.pruners.BasePruner]: the pruner class (not an instance)

    Raises:
        ImportError: if optuna is not installed
        ValueError: if ``name`` is not listed in ``optuna.pruners.__all__``
    """
    try:
        import optuna
    except ImportError as import_error:
        message = (
            "optuna is required for hyperparameter optimisation. "
            "Install it via `pip install optuna`."
        )
        raise ImportError(message) from import_error

    pruners = optuna.pruners
    # Only names that optuna publicly exports are accepted.
    if name in pruners.__all__:
        return getattr(pruners, name)
    raise ValueError(f"pruner with name {name} not in optuna.pruners")


def get_sampler(name: str) -> Type["optuna.samplers.BaseSampler"]:
    """Get the sampler class from the name.

    Args:
        name (str): name of sampler. Either a name exported by
            ``optuna.samplers`` or ``"AutoSampler"``.

    Returns:
        Type[optuna.samplers.BaseSampler]: uninstantiated sampler

    Raises:
        ImportError: if optuna is not installed, or if ``name`` is
            ``"AutoSampler"`` and optunahub is not installed or the sampler
            module cannot be imported from it.
        ValueError: if ``name`` is neither ``"AutoSampler"`` nor listed in
            ``optuna.samplers.__all__``.

    Notes:
        ``"AutoSampler"`` is loaded through optunahub. It automatically infers
        the "best" sampler based on the study parameters, see
        https://hub.optuna.org/samplers/auto_sampler
    """
    try:
        import optuna
    except ImportError as e:
        raise ImportError(
            "optuna is required for hyperparameter optimisation. "
            "Install it via `pip install optuna`."
        ) from e

    if name == "AutoSampler":
        # Keep the two failure modes apart: a missing optunahub and a failure
        # while loading the sampler module need different fixes.
        try:
            import optunahub
        except ImportError as e:
            raise ImportError(
                f"sampler {name} requires optunahub. Set sampler to other or install optunahub"
            ) from e

        try:
            # Requires optunahub, scipy, cmaes and torch.
            return optunahub.load_module("samplers/auto_sampler").AutoSampler
        except ImportError as e:
            raise ImportError(
                f"sampler {name} could not be loaded from optunahub. "
                "It additionally requires scipy, cmaes and torch. "
                "Set sampler to other or install the missing packages"
            ) from e

    if name not in optuna.samplers.__all__:
        raise ValueError(f"sampler with name {name} not in optuna.samplers")
    return getattr(optuna.samplers, name)


class OptunaPrunerConfig(BaseModel, extra="forbid"):
    """Settings for the optuna pruner of a study.

    Fields not declared here are rejected (``extra="forbid"``).

    Attributes:
        name (str): Name of the pruner.
        interval (PositiveInt): Interval to check whether a trial should be
            pruned. Default: 1
        kwargs (dict[str, Any]): Keyword arguments passed to the pruner.
            Default: empty dict
    """

    name: str
    interval: PositiveInt = 1
    # pydantic copies field defaults per instance, so the literal is not shared.
    kwargs: dict[str, Any] = {}


class OptunaSamplerConfig(BaseModel, extra="forbid"):
    """Settings for the optuna sampler of a study.

    Fields not declared here are rejected (``extra="forbid"``).

    Attributes:
        name (str): Name of the sampler. Default: "AutoSampler"
        kwargs (dict[str, Any]): Keyword arguments passed to the sampler.
            Default: empty dict
    """

    name: str = "AutoSampler"
    # pydantic copies field defaults per instance, so the literal is not shared.
    kwargs: dict[str, Any] = {}


class OptunaConfig(BaseModel, extra="forbid"):
    """Configuration of an optuna study.

    Fields not declared here are rejected (``extra="forbid"``).

    Attributes:
        n_trials (PositiveInt): number of trials
        search_space (dict[str, dict[str, Any]]): dictionary indicating the
            search space to use for the hyperparameter optimization. Keys
            should be keys in the training configuration, and if the keys are
            nested, they should be prefixed with the parent key, and an
            underscore, see the example below.
        seed (int): seed to use for sampling. Default: 1
        monitor (str): metric to monitor. Default: "val_loss"
        study_name (str): name of study. Default: "study"
        study_log_file (str | Path): path to study log file.
            Default: "study.log"
        sampler_config (OptunaSamplerConfig): sampler configuration.
            Default: ``OptunaSamplerConfig()``
        pruner_config (Optional[OptunaPrunerConfig]): pruner configuration.
            Default: None

    Example:
        For a search space with varying radius for the environment between
        3 and 8 Angstrom, the number of tensor contractions between 5 and 8,
        and number of epochs between 100 and 200.

        .. code-block:: python

            search_space = {
                'model_basis_r_max': {'type': 'float', 'low': 3, 'high': 8},
                'model_n_contr': {'type': 'int', 'low': 5, 'high': 8},
                'n_epochs': {'type': 'int', 'low': 100, 'high': 200},
            }
    """

    # Required fields
    n_trials: PositiveInt
    search_space: dict[str, dict[str, Any]]

    # Optional study settings
    seed: int = 1
    monitor: str = "val_loss"
    study_name: str = "study"
    study_log_file: str | Path = "study.log"

    # Sampler / pruner selection
    sampler_config: OptunaSamplerConfig = OptunaSamplerConfig()
    pruner_config: Optional[OptunaPrunerConfig] = None


def get_pruner_from_config(
    optuna_config: OptunaConfig,
) -> Optional["optuna.pruners.BasePruner"]:
    """Build the pruner described by the optuna configuration.

    Only ``pruner_config.name`` and ``pruner_config.kwargs`` are used here;
    ``pruner_config.interval`` is not passed to the pruner.

    Args:
        optuna_config (OptunaConfig): configuration for study

    Returns:
        Optional[optuna.pruners.BasePruner]: instantiated pruner, or None if
            ``optuna_config.pruner_config`` is None
    """
    pruner_settings = optuna_config.pruner_config
    if pruner_settings is None:
        return None

    # Work on a copy so the stored configuration is never handed out directly.
    init_kwargs = dict(pruner_settings.kwargs)
    return get_pruner(pruner_settings.name)(**init_kwargs)


def get_sampler_from_config(optuna_config: OptunaConfig) -> "optuna.samplers.BaseSampler":
    """Build the sampler described by the optuna configuration.

    The sampler is always constructed with ``seed=optuna_config.seed``; a
    ``"seed"`` entry in ``sampler_config.kwargs`` is overridden by it.

    Args:
        optuna_config (OptunaConfig): configuration for study

    Returns:
        optuna.samplers.BaseSampler: instantiated sampler
    """
    sampler_settings = optuna_config.sampler_config
    # See https://optuna.readthedocs.io/en/stable/faq.html#how-can-i-obtain-reproducible-optimization-results
    init_kwargs = {**sampler_settings.kwargs, "seed": optuna_config.seed}
    return get_sampler(sampler_settings.name)(**init_kwargs)