import logging
import os
from pathlib import Path
from typing import List, Literal, Optional, Union
import yaml
from pydantic import (
BaseModel,
ConfigDict,
Field,
NonNegativeFloat,
PositiveFloat,
PositiveInt,
create_model,
model_validator,
)
from typing_extensions import Annotated
from apax.data.statistics import scale_method_list, shift_method_list
log = logging.getLogger(__name__)
class DataConfig(BaseModel, extra="forbid"):
    """
    Configuration for data loading, preprocessing and training.

    Parameters
    ----------
    directory : str, required
        | Path to directory where training results and checkpoints are written.
    experiment : str, required
        | Model name distinguishing from others in directory.
    ds_type : Literal["cached", "otf"], default = "cached"
        | Selects which dataset type is used.
    data_path : str, required if train_data_path and val_data_path is not specified
        | Path to single dataset file.
    train_data_path : str, required if data_path is not specified
        | Path to training dataset.
    val_data_path : str, required if data_path is not specified
        | Path to validation dataset.
    test_data_path : str, optional
        | Path to test dataset.
    n_train : int, default = 1000
        | Number of training datapoints from `data_path`.
    n_valid : int, default = 100
        | Number of validation datapoints from `data_path`.
    batch_size : int, default = 32
        | Number of training examples to be evaluated at once.
    valid_batch_size : int, default = 100
        | Number of validation examples to be evaluated at once.
    shuffle_buffer_size : int, default = 1000
        | Size of the `tf.data` shuffle buffer.
    additional_properties_info : dict, optional
        | dict of property name, shape (ragged or fixed) pairs
    shift_method : str, default = "per_element_regression_shift"
        | Name of the shift method. Has to match the name of one of the
        | methods in `shift_method_list`.
    shift_options : dict, default = {"energy_regularisation": 1.0}
        | Parameters of the chosen shift method. The keys have to match the
        | parameters declared by that method exactly.
        | For the default method, `energy_regularisation` is the magnitude of
        | the regularization in the per-element energy regression.
    scale_method : str, default = "per_element_force_rms_scale"
        | Name of the scale method. Has to match the name of one of the
        | methods in `scale_method_list`.
    scale_options : dict, optional
        | Parameters of the chosen scale method. `None` is treated like an
        | empty dict during validation.
    pos_unit : str, default = "Ang"
        | Unit label of the positions.
    energy_unit : str, default = "eV"
        | Unit label of the energies.
    """

    directory: str
    experiment: str
    ds_type: Literal["cached", "otf"] = "cached"
    data_path: Optional[str] = None
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    test_data_path: Optional[str] = None

    n_train: PositiveInt = 1000
    n_valid: PositiveInt = 100
    batch_size: PositiveInt = 32
    valid_batch_size: PositiveInt = 100
    shuffle_buffer_size: PositiveInt = 1000
    additional_properties_info: dict[str, str] = {}

    shift_method: str = "per_element_regression_shift"
    shift_options: dict = {"energy_regularisation": 1.0}
    scale_method: str = "per_element_force_rms_scale"
    scale_options: Optional[dict] = {}

    pos_unit: Optional[str] = "Ang"
    energy_unit: Optional[str] = "eV"

    @model_validator(mode="after")
    def set_data_or_train_val_path(self):
        """
        Ensure that exactly one of `data_path` and `train_data_path` is given.

        Raises
        ------
        ValueError
            If both or neither of the two paths are set.
        """
        has_data_path = self.data_path is not None
        has_train_path = self.train_data_path is not None

        # Equal flags mean "both set" or "neither set"; both are invalid.
        if has_data_path == has_train_path:
            raise ValueError("Please specify either data_path or train_data_path")

        return self

    @model_validator(mode="after")
    def validate_shift_scale_methods(self):
        """
        Check the requested shift/scale methods and their options.

        For each of the two methods, the name is looked up among the
        implemented methods and the supplied options are validated against a
        model generated from the method's declared parameters and dtypes.

        Raises
        ------
        KeyError
            If the requested method name is not among the implemented ones.
        """
        cases = (
            (shift_method_list, self.shift_method, self.shift_options),
            (scale_method_list, self.scale_method, self.scale_options),
        )

        for method_list, requested_method, requested_params in cases:
            methods = {method.name: method for method in method_list}

            # check if method exists
            if requested_method not in methods:
                raise KeyError(
                    f"The initialization method '{requested_method}' is not among the"
                    f" implemented methods. Choose from {methods.keys()}"
                )

            # check if parameters names are complete and correct:
            # every declared parameter is required and unknown keys are rejected.
            method = methods[requested_method]
            fields = {
                name: (dtype, ...)
                for name, dtype in zip(method.parameters, method.dtypes)
            }
            MethodConfig = create_model(
                f"{method.name}Config", __config__=ConfigDict(extra="forbid"), **fields
            )

            # `scale_options` is Optional, so the options may be None here.
            # Unpacking None would raise a bare TypeError; treat it as "no options".
            _ = MethodConfig(**(requested_params or {}))

        return self

    @property
    def model_version_path(self):
        """Path object for `directory` / `experiment`."""
        return Path(self.directory) / self.experiment

    @property
    def best_model_path(self):
        """Path of the "best" entry inside `model_version_path`."""
        return self.model_version_path / "best"
class ModelConfig(BaseModel, extra="forbid"):
    """
    Configuration for the model.

    Parameters
    ----------
    n_basis : PositiveInt, default = 7
        Number of uncontracted gaussian basis functions.
    n_radial : PositiveInt, default = 5
        Number of contracted basis functions.
    r_min : NonNegativeFloat, default = 0.5
        Position of the first uncontracted basis function's mean.
    r_max : PositiveFloat, default = 6.0
        Cutoff radius of the descriptor.
    n_contr : int, default = -1
        NOTE(review): meaning is not visible in this file; -1 looks like a
        sentinel value - confirm against the model implementation.
    emb_init : Optional[str], default = "uniform"
        Initialization scheme for embedding layer weights.
    nn : List[PositiveInt], default = [512, 512]
        Number of hidden layers and units in those layers.
    b_init : Literal["normal", "zeros"], default = "normal"
        Initialization scheme for the neural network biases.
    use_zbl : bool, default = False
        Whether to include the ZBL correction.
    calc_stress : bool, default = False
        Whether to calculate stress during model evaluation.
    descriptor_dtype : Literal["fp32", "fp64"], default = "fp64"
        Data type for descriptor calculations.
    readout_dtype : Literal["fp32", "fp64"], default = "fp32"
        Data type for readout calculations.
    scale_shift_dtype : Literal["fp32", "fp64"], default = "fp32"
        Data type for scale and shift parameters.
    """

    n_basis: PositiveInt = 7
    n_radial: PositiveInt = 5
    r_min: NonNegativeFloat = 0.5
    r_max: PositiveFloat = 6.0

    n_contr: int = -1
    emb_init: Optional[str] = "uniform"

    nn: List[PositiveInt] = [512, 512]
    b_init: Literal["normal", "zeros"] = "normal"

    # corrections
    use_zbl: bool = False

    calc_stress: bool = False

    descriptor_dtype: Literal["fp32", "fp64"] = "fp64"
    readout_dtype: Literal["fp32", "fp64"] = "fp32"
    scale_shift_dtype: Literal["fp32", "fp64"] = "fp32"

    def get_dict(self):
        """
        Return the dumped model settings with the three ``*_dtype`` entries
        converted from their "fp32"/"fp64" labels to `jax.numpy` dtypes.
        """
        # Imported here rather than at module level, as in the original code.
        import jax.numpy as jnp

        precisions = {"fp32": jnp.float32, "fp64": jnp.float64}
        settings = self.model_dump()

        for key in ("descriptor_dtype", "readout_dtype", "scale_shift_dtype"):
            settings[key] = precisions[settings[key]]

        return settings
class OptimizerConfig(BaseModel, frozen=True, extra="forbid"):
    """
    Configuration of the optimizer.
    Setting a learning rate to 0 freezes the respective parameters.

    Parameters
    ----------
    opt_name : str, default = "adam"
        Name of the optimizer. Can be any `optax` optimizer.
    emb_lr : NonNegativeFloat, default = 0.02
        Learning rate of the elemental embedding contraction coefficients.
    nn_lr : NonNegativeFloat, default = 0.03
        Learning rate of the neural network parameters.
    scale_lr : NonNegativeFloat, default = 0.001
        Learning rate of the elemental output scaling factors.
    shift_lr : NonNegativeFloat, default = 0.05
        Learning rate of the elemental output shifts.
    zbl_lr : NonNegativeFloat, default = 0.001
        Learning rate of the ZBL correction parameters.
    transition_begin : int, default = 0
        Number of training steps (not epochs) before the start of the linear
        learning rate schedule.
    opt_kwargs : dict, default = {}
        Optimizer keyword arguments. Passed to the `optax` optimizer.
    sam_rho : NonNegativeFloat, default = 0.0
        Rho parameter for Sharpness-Aware Minimization.
    """

    opt_name: str = "adam"

    # per-parameter-group learning rates
    emb_lr: NonNegativeFloat = 0.02
    nn_lr: NonNegativeFloat = 0.03
    scale_lr: NonNegativeFloat = 0.001
    shift_lr: NonNegativeFloat = 0.05
    zbl_lr: NonNegativeFloat = 0.001

    transition_begin: int = 0
    opt_kwargs: dict = {}
    sam_rho: NonNegativeFloat = 0.0
class MetricsConfig(BaseModel, extra="forbid"):
    """
    Configuration for the metrics collected during training.

    Parameters
    ----------
    name : str
        Keyword of the quantity, e.g., 'energy'.
    reductions : List[str]
        Reductions applied to the difference between target and predictions.
        Can be 'mae', 'mse', 'rmse' for energies and forces.
        For forces, 'angle' can also be used.
    """

    name: str
    reductions: List[str]
class LossConfig(BaseModel, extra="forbid"):
    """
    Configuration of a loss function used during training.

    Parameters
    ----------
    name : str
        Keyword of the quantity, e.g., 'energy'.
    loss_type : str, default = "mse"
        Weighting scheme for atomic contributions. See the MLIP package
        reference 10.1088/2632-2153/abc9fe for details.
    weight : NonNegativeFloat, default = 1.0
        Weighting factor in the overall loss function.
    atoms_exponent : NonNegativeFloat, default = 1
        Exponent for atomic contributions weighting.
    parameters : dict, default = {}
        Additional parameters for configuring the loss function.
    """

    name: str
    loss_type: str = "mse"
    weight: NonNegativeFloat = 1.0
    atoms_exponent: NonNegativeFloat = 1
    parameters: dict = {}
class CSVCallback(BaseModel, frozen=True, extra="forbid"):
    """
    Configuration of the CSVCallback.

    Parameters
    ----------
    name : Literal["csv"]
        Keyword of the callback used. Only "csv" is accepted.
    """

    name: Literal["csv"]
class TBCallback(BaseModel, frozen=True, extra="forbid"):
    """
    Configuration of the TensorBoard callback.

    Parameters
    ----------
    name : Literal["tensorboard"]
        Keyword of the callback used. Only "tensorboard" is accepted.
    """

    name: Literal["tensorboard"]
class MLFlowCallback(BaseModel, frozen=True, extra="forbid"):
    """
    Configuration of the MLFlow callback.

    Parameters
    ----------
    name : Literal["mlflow"]
        Keyword of the callback used. Only "mlflow" is accepted.
    experiment : str
        Path to the MLFlow experiment, e.g. /Users/<user>/<my_experiment>
    """

    name: Literal["mlflow"]
    experiment: str
# Union of the callback configurations defined above. The `name` field is
# declared as the discriminator, so its value ("csv", "tensorboard" or
# "mlflow") determines which callback model an entry is validated against.
CallBack = Annotated[
Union[CSVCallback, TBCallback, MLFlowCallback], Field(discriminator="name")
]
class TrainProgressbarConfig(BaseModel, extra="forbid"):
    """
    Configuration of progressbars.

    Parameters
    ----------
    disable_epoch_pbar : bool, default = False
        Set to True to disable the epoch progress bar.
    disable_batch_pbar : bool, default = True
        Set to True to disable the batch progress bar.
    """

    disable_epoch_pbar: bool = False
    disable_batch_pbar: bool = True
class CheckpointConfig(BaseModel, extra="forbid"):
    """
    Checkpoint configuration.

    Parameters
    ----------
    ckpt_interval : PositiveInt, default = 1
        Number of epochs between checkpoints.
    base_model_checkpoint : Optional[str], default = None
        Path to the folder containing a pre-trained model ckpt.
    reset_layers : List[str], default = []
        List of layer names for which the parameters will be reinitialized.
    """

    ckpt_interval: PositiveInt = 1
    base_model_checkpoint: Optional[str] = None
    reset_layers: List[str] = []
class Config(BaseModel, frozen=True, extra="forbid"):
    """
    Main configuration of an apax training run. Parameters that are config classes
    will be generated by parsing the config.yaml file and are specified
    as shown :ref:`here <train_config>`:

    Example
    -------
    .. code-block:: yaml

        data:
            directory: models/
            experiment: apax
                .
                .

    Parameters
    ----------
    n_epochs : int, required
        | Number of training epochs.
    patience : int, optional
        | Number of epochs without improvement before training gets terminated.
    seed : int, default = 1
        | Random seed.
    n_models : int, default = 1
        | Number of models to be trained at once.
    n_jitted_steps : int, default = 1
        | Number of train batches to be processed in a compiled loop.
        | Can yield significant speedups for small structures or small batch sizes.
    data : :class:`.DataConfig`
        | Data configuration.
    model : :class:`.ModelConfig`
        | Model configuration.
    metrics : List of :class:`.MetricsConfig`
        | Metrics configuration.
    loss : List of :class:`.LossConfig`
        | Loss configuration.
    optimizer : :class:`.OptimizerConfig`
        | Loss optimizer configuration.
    callbacks : List of various CallBack classes
        | Possible callbacks are :class:`.CSVCallback`,
        | :class:`.TBCallback`, :class:`.MLFlowCallback`
    progress_bar : :class:`.TrainProgressbarConfig`
        | Progressbar configuration.
    checkpoints : :class:`.CheckpointConfig`
        | Checkpoint configuration.
    data_parallel : bool, default = True
        | Automatically uses all available GPUs for data parallel training.
        | Set to false to force single device training.
    """

    n_epochs: PositiveInt
    patience: Optional[PositiveInt] = None
    seed: int = 1
    n_models: int = 1
    n_jitted_steps: int = 1
    data_parallel: bool = True

    data: DataConfig
    model: ModelConfig = ModelConfig()
    metrics: List[MetricsConfig] = []
    loss: List[LossConfig]
    optimizer: OptimizerConfig = OptimizerConfig()
    callbacks: List[CallBack] = [CSVCallback(name="csv")]
    progress_bar: TrainProgressbarConfig = TrainProgressbarConfig()
    checkpoints: CheckpointConfig = CheckpointConfig()

    def dump_config(self, save_path):
        """
        Writes the current config to a ``config.yaml`` file in the specified
        directory.

        Parameters
        ----------
        save_path: Path to the directory.
        """
        with open(os.path.join(save_path, "config.yaml"), "w") as conf:
            # Block style (default_flow_style=False) instead of inline flow style.
            yaml.dump(self.model_dump(), conf, default_flow_style=False)