import logging
import os
from pathlib import Path
from typing import List, Literal, Optional, Union
import yaml
from pydantic import (
BaseModel,
ConfigDict,
Field,
NonNegativeFloat,
PositiveFloat,
PositiveInt,
create_model,
model_validator,
)
from typing_extensions import Annotated
from apax.config.lr_config import CyclicCosineLR, LinearLR
from apax.data.statistics import scale_method_list, shift_method_list
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class DatasetConfig(BaseModel, extra="forbid"):
    """Common base of all dataset pipeline configurations.

    Parameters
    ----------
    processing : str
        | Keyword identifying the data pipeline. Subclasses narrow this
        | to a single literal value.
    """

    processing: str
class CachedDataset(DatasetConfig, extra="forbid"):
    """Dataset that pads atoms and neighbors to the largest system
    in the dataset.

    The neighbor list is computed on the fly during the first epoch and
    written to disk through tf.data's cache.
    Most performant option for datasets whose samples are of very similar size.

    Parameters
    ----------
    shuffle_buffer_size : int
        | Size of the buffer that is shuffled by tf.data.
        | Larger values require more RAM.
    """

    processing: Literal["cached"] = "cached"
    shuffle_buffer_size: PositiveInt = 1000
class OTFDataset(DatasetConfig, extra="forbid"):
    """Dataset that pads atoms and neighbors to the largest system
    in the dataset.

    The neighbor list is computed on the fly and fed into a tf.data generator.
    Mostly for internal purposes.

    Parameters
    ----------
    shuffle_buffer_size : int
        | Size of the buffer that is shuffled by tf.data.
        | Larger values require more RAM.
    """

    processing: Literal["otf"] = "otf"
    shuffle_buffer_size: PositiveInt = 1000
class PBPDatset(DatasetConfig, extra="forbid"):
    """Dataset that pads atoms and neighbors to the next largest
    power of two.

    This limits the compute wasted due to padding at the (negligible)
    cost of some recompilations.
    The neighbor list is computed on the fly, in parallel for `num_workers`
    batches. Does not use tf.data.
    Most performant option for datasets with significantly differently sized
    systems (e.g. MP, SPICE).

    Parameters
    ----------
    num_workers : int
        | Number of batches to be processed in parallel.
    reset_every : int
        | Number of epochs before reinitializing the ProcessPoolExecutor.
        | Avoids memory leaks.
    """

    processing: Literal["pbp"] = "pbp"
    num_workers: PositiveInt = 10
    reset_every: PositiveInt = 10
# [docs] -- Sphinx source-link residue; not part of the module
class DataConfig(BaseModel, extra="forbid"):
    """
    Configuration for data loading, preprocessing and training.

    Parameters
    ----------
    directory : str, required
        | Path to directory where training results and checkpoints are written.
    experiment : str, required
        | Model name distinguishing from others in directory.
    dataset : CachedDataset, OTFDataset or PBPDatset, default = CachedDataset
        | Data pipeline settings; the concrete class is selected by the
        | `processing` key. Pipeline-specific options (for example
        | `shuffle_buffer_size`) are set here.
    data_path : str, required if train_data_path and val_data_path is not specified
        | Path to single dataset file.
    train_data_path : str, required if data_path is not specified
        | Path to training dataset.
    val_data_path : str, required if data_path is not specified
        | Path to validation dataset.
    test_data_path : str, optional
        | Path to test dataset.
    n_train : int, default = 1000
        | Number of training datapoints from `data_path`.
    n_valid : int, default = 100
        | Number of validation datapoints from `data_path`.
    batch_size : int, default = 32
        | Number of training examples to be evaluated at once.
    valid_batch_size : int, default = 100
        | Number of validation examples to be evaluated at once.
    additional_properties_info : dict, optional
        | dict of property name, shape (ragged or fixed) pairs. Currently unused.
    shift_method : str, default = "per_element_regression_shift"
        | Name of one of the methods in `shift_method_list`.
    shift_options : dict, default = {"energy_regularisation": 1.0}
        | Parameters of the shift method. `energy_regularisation` is the
        | magnitude of the regularization in the per-element energy regression.
    scale_method : str, default = "per_element_force_rms_scale"
        | Name of one of the methods in `scale_method_list`.
    scale_options : dict, optional
        | Parameters of the scale method. `None` is treated like an empty dict.
    pos_unit : str, default = "Ang"
        | Unit label of the positions. Not validated by this class.
    energy_unit : str, default = "eV"
        | Unit label of the energies. Not validated by this class.
    """

    directory: str
    experiment: str
    dataset: Union[CachedDataset, OTFDataset, PBPDatset] = Field(
        CachedDataset(processing="cached"), discriminator="processing"
    )

    data_path: Optional[str] = None
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    test_data_path: Optional[str] = None

    n_train: PositiveInt = 1000
    n_valid: PositiveInt = 100
    batch_size: PositiveInt = 32
    valid_batch_size: PositiveInt = 100
    additional_properties_info: dict[str, str] = {}

    shift_method: str = "per_element_regression_shift"
    shift_options: dict = {"energy_regularisation": 1.0}

    scale_method: str = "per_element_force_rms_scale"
    scale_options: Optional[dict] = {}

    pos_unit: Optional[str] = "Ang"
    energy_unit: Optional[str] = "eV"

    @model_validator(mode="after")
    def set_data_or_train_val_path(self):
        """Require exactly one of `data_path` and `train_data_path`.

        Raises
        ------
        ValueError
            If both or neither of the two paths are given.
        """
        has_data_path = self.data_path is not None
        has_train_path = self.train_data_path is not None
        # Equal flags mean "both set" or "neither set"; both are invalid.
        if has_data_path == has_train_path:
            raise ValueError("Please specify either data_path or train_data_path")
        return self

    @model_validator(mode="after")
    def validate_shift_scale_methods(self):
        """Check the shift/scale method names and their option dicts.

        For each of the two methods, the name must match an entry of the
        corresponding method list, and the options must supply exactly the
        parameters that method declares (extra keys are rejected).

        Raises
        ------
        KeyError
            If a requested method name is not implemented.
        """
        method_lists = [shift_method_list, scale_method_list]
        requested_methods = [self.shift_method, self.scale_method]
        requested_options = [self.shift_options, self.scale_options]

        cases = zip(method_lists, requested_methods, requested_options)
        for method_list, requested_method, requested_params in cases:
            methods = {method.name: method for method in method_list}

            # check if method exists
            if requested_method not in methods:
                raise KeyError(
                    f"The initialization method '{requested_method}' is not among the"
                    f" implemented methods. Choose from {methods.keys()}"
                )

            # check if parameters names are complete and correct
            method = methods[requested_method]
            fields = {
                name: (dtype, ...)
                for name, dtype in zip(method.parameters, method.dtypes)
            }
            # Throwaway model built only to validate the options dict.
            MethodConfig = create_model(
                f"{method.name}Config", __config__=ConfigDict(extra="forbid"), **fields
            )

            # `scale_options` is Optional: unpacking None would raise a
            # TypeError, so an absent dict is validated as "no options".
            _ = MethodConfig(**(requested_params or {}))

        return self

    @property
    def model_version_path(self):
        """Path object for `directory / experiment`."""
        return Path(self.directory) / self.experiment

    @property
    def best_model_path(self):
        """Path of the `best` entry inside `model_version_path`."""
        return self.model_version_path / "best"
class GaussianBasisConfig(BaseModel, extra="forbid"):
    """
    Settings for the Gaussian primitive basis functions.

    Parameters
    ----------
    n_basis : PositiveInt, default = 7
        Number of uncontracted basis functions.
    r_min : NonNegativeFloat, default = 0.5
        Position of the first uncontracted basis function's mean.
    r_max : PositiveFloat, default = 6.0
        Cutoff radius of the descriptor.
    """

    name: Literal["gaussian"] = "gaussian"
    n_basis: PositiveInt = 7
    r_min: NonNegativeFloat = 0.5
    r_max: PositiveFloat = 6.0
class BesselBasisConfig(BaseModel, extra="forbid"):
    """
    Settings for the Bessel primitive basis functions.

    Parameters
    ----------
    n_basis : PositiveInt, default = 7
        Number of uncontracted basis functions.
    r_max : PositiveFloat, default = 6.0
        Cutoff radius of the descriptor.
    """

    name: Literal["bessel"] = "bessel"
    n_basis: PositiveInt = 7
    r_max: PositiveFloat = 6.0
# Any supported basis configuration; the members differ in their `name` literal.
BasisConfig = Union[GaussianBasisConfig, BesselBasisConfig]
class FullEnsembleConfig(BaseModel, extra="forbid"):
    """
    Settings for full model ensembles.

    Usage can improve accuracy and stability at the cost of slower inference.
    Uncertainties will generally not be calibrated.

    Parameters
    ----------
    n_members : int
        Number of ensemble members.
    """

    kind: Literal["full"] = "full"
    n_members: int
class ShallowEnsembleConfig(BaseModel, extra="forbid"):
    """
    Settings for shallow (last layer) ensembles.

    Allows use of probabilistic loss functions.
    The predicted uncertainties should be well calibrated.
    See 10.1088/2632-2153/ad594a for details.

    Parameters
    ----------
    n_members : int
        Number of ensemble members.
    force_variance : bool, default = True
        Whether or not to compute force uncertainties.
        Required for probabilistic force loss and calibration of force
        uncertainties. Can lead to better force metrics, but enabling it
        introduces some non-negligible cost.
    chunk_size : Optional[int], default = None
        If set to an integer, the jacobian of ensemble energies wrt. positions
        is computed in chunks of that size. This sacrifices some performance
        for the possibility to use relatively large ensemble sizes.
    """

    kind: Literal["shallow"] = "shallow"
    n_members: int
    force_variance: bool = True
    chunk_size: Optional[int] = None
# Either a full or a shallow ensemble configuration.
EnsembleConfig = Union[FullEnsembleConfig, ShallowEnsembleConfig]
# [docs] -- Sphinx source-link residue; not part of the module
class ModelConfig(BaseModel, extra="forbid"):
    """
    Settings describing the model.

    Parameters
    ----------
    basis : BasisConfig, default = GaussianBasisConfig()
        Configuration for primitive basis functions.
    n_radial : PositiveInt, default = 5
        Number of contracted basis functions.
    n_contr : int, default = 8
        How many gaussian moment contractions to use.
    emb_init : Optional[str], default = "uniform"
        Initialization scheme for embedding layer weights.
    nn : List[PositiveInt], default = [512, 512]
        Number of hidden layers and units in those layers.
    w_init : Literal["normal", "lecun"], default = "normal"
        Initialization scheme for the neural network weights.
    b_init : Literal["normal", "zeros"], default = "normal"
        Initialization scheme for the neural network biases.
    use_ntk : bool, default = True
        Whether or not to use NTK parametrization.
    ensemble : Optional[EnsembleConfig], default = None
        What kind of model ensemble to use (optional).
    use_zbl : bool, default = False
        Whether to include the ZBL correction.
    calc_stress : bool, default = False
        Whether to calculate stress during model evaluation.
    descriptor_dtype : Literal["fp32", "fp64"], default = "fp64"
        Data type for descriptor calculations.
    readout_dtype : Literal["fp32", "fp64"], default = "fp32"
        Data type for readout calculations.
    scale_shift_dtype : Literal["fp32", "fp64"], default = "fp32"
        Data type for scale and shift parameters.
    """

    basis: BasisConfig = Field(GaussianBasisConfig(name="gaussian"), discriminator="name")
    n_radial: PositiveInt = 5
    n_contr: int = 8
    emb_init: Optional[str] = "uniform"

    nn: List[PositiveInt] = [512, 512]
    w_init: Literal["normal", "lecun"] = "normal"
    b_init: Literal["normal", "zeros"] = "normal"
    use_ntk: bool = True

    ensemble: Optional[EnsembleConfig] = None

    # corrections
    use_zbl: bool = False

    calc_stress: bool = False

    descriptor_dtype: Literal["fp32", "fp64"] = "fp64"
    readout_dtype: Literal["fp32", "fp64"] = "fp32"
    scale_shift_dtype: Literal["fp32", "fp64"] = "fp32"

    def get_dict(self):
        """Dump the model settings with dtype keywords replaced by jnp dtypes.

        Returns
        -------
        dict
            The `model_dump()` dictionary in which the values of
            `descriptor_dtype`, `readout_dtype` and `scale_shift_dtype`
            are `jnp.float32` / `jnp.float64` instead of "fp32" / "fp64".
        """
        # Imported here so that jax is only needed when this method is called.
        import jax.numpy as jnp

        dtype_lookup = {"fp32": jnp.float32, "fp64": jnp.float64}
        model_dict = self.model_dump()
        for key in ("descriptor_dtype", "readout_dtype", "scale_shift_dtype"):
            model_dict[key] = dtype_lookup[model_dict[key]]
        return model_dict
# [docs] -- Sphinx source-link residue; not part of the module
class OptimizerConfig(BaseModel, frozen=True, extra="forbid"):
    """
    Settings of the optimizer.

    Learning rates of 0 will freeze the respective parameters.

    Parameters
    ----------
    name : str, default = "adam"
        Name of the optimizer. Can be any `optax` optimizer.
    emb_lr : NonNegativeFloat, default = 0.02
        Learning rate of the elemental embedding contraction coefficients.
    nn_lr : NonNegativeFloat, default = 0.03
        Learning rate of the neural network parameters.
    scale_lr : NonNegativeFloat, default = 0.001
        Learning rate of the elemental output scaling factors.
    shift_lr : NonNegativeFloat, default = 0.05
        Learning rate of the elemental output shifts.
    zbl_lr : NonNegativeFloat, default = 0.001
        Learning rate of the ZBL correction parameters.
    schedule : LinearLR or CyclicCosineLR, default = LinearLR
        Learning rate schedule, selected by its `name` key.
    kwargs : dict, default = {}
        Optimizer keyword arguments. Passed to the `optax` optimizer.
    """

    name: str = "adam"
    emb_lr: NonNegativeFloat = 0.02
    nn_lr: NonNegativeFloat = 0.03
    scale_lr: NonNegativeFloat = 0.001
    shift_lr: NonNegativeFloat = 0.05
    zbl_lr: NonNegativeFloat = 0.001
    schedule: Union[LinearLR, CyclicCosineLR] = Field(
        LinearLR(name="linear"), discriminator="name"
    )
    kwargs: dict = {}
# [docs] -- Sphinx source-link residue; not part of the module
class MetricsConfig(BaseModel, extra="forbid"):
    """
    Settings for one metric collected during training.

    Parameters
    ----------
    name : str
        Keyword of the quantity, e.g., 'energy'.
    reductions : List[str]
        List of reductions performed on the difference between target and
        predictions. Can be 'mae', 'mse', 'rmse' for energies and forces.
        For forces, 'angle' can also be used.
    """

    name: str
    reductions: List[str]
# [docs] -- Sphinx source-link residue; not part of the module
class LossConfig(BaseModel, extra="forbid"):
    """
    Settings of one loss function used during training.

    Parameters
    ----------
    name : str
        Keyword of the quantity, e.g., 'energy'.
    loss_type : str, optional
        Weighting scheme for atomic contributions. See the MLIP package
        for reference 10.1088/2632-2153/abc9fe for details, by default "mse".
    weight : NonNegativeFloat, optional
        Weighting factor in the overall loss function, by default 1.0.
    atoms_exponent : NonNegativeFloat, optional
        Exponent for atomic contributions weighting, by default 1.
    parameters : dict, optional
        Additional parameters for configuring the loss function, by default {}.
    """

    name: str
    loss_type: str = "mse"
    weight: NonNegativeFloat = 1.0
    atoms_exponent: NonNegativeFloat = 1
    parameters: dict = {}
# [docs] -- Sphinx source-link residue; not part of the module
class CSVCallback(BaseModel, frozen=True, extra="forbid"):
    """
    Settings of the CSV callback.

    Parameters
    ----------
    name: Keyword of the callback used; must be "csv".
    """

    name: Literal["csv"]
# [docs] -- Sphinx source-link residue; not part of the module
class TBCallback(BaseModel, frozen=True, extra="forbid"):
    """
    Settings of the TensorBoard callback.

    Parameters
    ----------
    name: Keyword of the callback used; must be "tensorboard".
    """

    name: Literal["tensorboard"]
# [docs] -- Sphinx source-link residue; not part of the module
class MLFlowCallback(BaseModel, frozen=True, extra="forbid"):
    """
    Settings of the MLFlow callback.

    Parameters
    ----------
    name: Keyword of the callback used; must be "mlflow".
    experiment: Path to the MLFlow experiment, e.g. /Users/<user>/<my_experiment>
    """

    name: Literal["mlflow"]
    experiment: str
# Any supported callback configuration; the concrete class is chosen
# from the value of the `name` field.
CallBack = Annotated[
    Union[CSVCallback, TBCallback, MLFlowCallback],
    Field(discriminator="name"),
]
# [docs] -- Sphinx source-link residue; not part of the module
class TrainProgressbarConfig(BaseModel, extra="forbid"):
    """
    Settings of the training progress bars.

    Parameters
    ----------
    disable_epoch_pbar: Set to True to disable the epoch progress bar.
    disable_batch_pbar: Set to True to disable the batch progress bar.
    """

    disable_epoch_pbar: bool = False
    disable_batch_pbar: bool = True
# [docs] -- Sphinx source-link residue; not part of the module
class CheckpointConfig(BaseModel, extra="forbid"):
    """
    Settings for checkpointing.

    Parameters
    ----------
    ckpt_interval: Number of epochs between checkpoints.
    base_model_checkpoint: Path to the folder containing a pre-trained model ckpt.
    reset_layers: List of layer names for which the parameters will be reinitialized.
    """

    ckpt_interval: PositiveInt = 1
    base_model_checkpoint: Optional[str] = None
    reset_layers: List[str] = []
class WeightAverage(BaseModel, extra="forbid"):
    """Applies an exponential moving average to model parameters.

    Parameters
    ----------
    ema_start : int, default = 0
        Epoch at which to start averaging models.
    alpha : float, default = 0.9
        How much of the new model to use. 1.0 would mean no averaging,
        0.0 no updates.
    """

    ema_start: int = 0
    alpha: float = 0.9
# [docs] -- Sphinx source-link residue; not part of the module
class Config(BaseModel, frozen=True, extra="forbid"):
    """
    Main configuration of an apax training run. Parameters that are config
    classes will be generated by parsing the config.yaml file and are
    specified as shown :ref:`here <train_config>`:

    Example
    -------
    .. code-block:: yaml

        data:
            directory: models/
            experiment: apax
                .
                .

    Parameters
    ----------
    n_epochs : int, required
        | Number of training epochs.
    patience : int, optional
        | Number of epochs without improvement before training gets terminated.
    seed : int, default = 1
        | Random seed.
    n_jitted_steps : int, default = 1
        | Number of train batches to be processed in a compiled loop.
        | Can yield significant speedups for small structures or small batch sizes.
    data : :class:`.DataConfig`
        | Data configuration.
    model : :class:`.ModelConfig`
        | Model configuration.
    metrics : List of :class:`.MetricsConfig`
        | Metrics configuration.
    loss : List of :class:`.LossConfig`
        | Loss configuration.
    optimizer : :class:`.OptimizerConfig`
        | Loss optimizer configuration.
    weight_average : :class:`.WeightAverage`, optional
        | Options for averaging weights between epochs.
    callbacks : List of various CallBack classes
        | Possible callbacks are :class:`.CSVCallback`,
        | :class:`.TBCallback`, :class:`.MLFlowCallback`
    progress_bar : :class:`.TrainProgressbarConfig`
        | Progressbar configuration.
    checkpoints : :class:`.CheckpointConfig`
        | Checkpoint configuration.
    data_parallel : bool, default = True
        | Automatically uses all available GPUs for data parallel training.
        | Set to false to force single device training.
    """

    n_epochs: PositiveInt
    patience: Optional[PositiveInt] = None
    seed: int = 1
    n_jitted_steps: int = 1
    data_parallel: bool = True

    data: DataConfig
    model: ModelConfig = ModelConfig()
    metrics: List[MetricsConfig] = []
    loss: List[LossConfig]
    optimizer: OptimizerConfig = OptimizerConfig()
    weight_average: Optional[WeightAverage] = None
    callbacks: List[CallBack] = [CSVCallback(name="csv")]
    progress_bar: TrainProgressbarConfig = TrainProgressbarConfig()
    checkpoints: CheckpointConfig = CheckpointConfig()

    def dump_config(self, save_path):
        """
        Writes the current config to `config.yaml` in the specified directory.

        Parameters
        ----------
        save_path: Path to the directory.
        """
        config_file = os.path.join(save_path, "config.yaml")
        with open(config_file, "w") as conf:
            yaml.dump(self.model_dump(), conf, default_flow_style=False)