Configuration
apax.config
Training Configuration
- class apax.config.train_config.Config(*, n_epochs: int, patience: int | None = None, seed: int = 1, n_models: int = 1, n_jitted_steps: int = 1, data_parallel: bool = True, data: DataConfig, model: ModelConfig = ModelConfig(n_basis=7, n_radial=5, r_min=0.5, r_max=6.0, n_contr=-1, emb_init='uniform', nn=[512, 512], b_init='normal', use_zbl=False, calc_stress=False, descriptor_dtype='fp64', readout_dtype='fp32', scale_shift_dtype='fp32'), metrics: List[MetricsConfig] = [], loss: List[LossConfig], optimizer: OptimizerConfig = OptimizerConfig(opt_name='adam', emb_lr=0.02, nn_lr=0.03, scale_lr=0.001, shift_lr=0.05, zbl_lr=0.001, transition_begin=0, opt_kwargs={}, sam_rho=0.0), callbacks: List[CSVCallback | TBCallback | MLFlowCallback] = [CSVCallback(name='csv')], progress_bar: TrainProgressbarConfig = TrainProgressbarConfig(disable_epoch_pbar=False, disable_batch_pbar=True), checkpoints: CheckpointConfig = CheckpointConfig(ckpt_interval=1, base_model_checkpoint=None, reset_layers=[]))
Main configuration of an apax training run. Parameters that are config classes are generated by parsing the config.yaml file and are specified as shown here:
Example
    data:
        directory: models/
        experiment: apax
        .
        .
- Parameters:
n_epochs (int, required) – Number of training epochs.
patience (int, optional) – Number of epochs without improvement before training is terminated.
seed (int, default = 1) – Random seed.
n_models (int, default = 1) – Number of models to be trained at once.
n_jitted_steps (int, default = 1) – Number of training batches processed in a single compiled loop. Can yield significant speedups for small structures or small batch sizes.
data (DataConfig) – Data configuration.
model (ModelConfig) – Model configuration.
metrics (List of MetricsConfig) – Metrics configuration.
loss (List of LossConfig) – Loss configuration.
optimizer (OptimizerConfig) – Optimizer configuration.
callbacks (List of various CallBack classes) – Possible callbacks are CSVCallback, TBCallback and MLFlowCallback.
progress_bar (TrainProgressbarConfig) – Progress bar configuration.
checkpoints (CheckpointConfig) – Checkpoint configuration.
data_parallel (bool, default = True) – Automatically uses all available GPUs for data-parallel training. Set to False to force single-device training.
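The elided example above can be filled out as a minimal sketch. The dictionary below mirrors the nesting a parsed config.yaml would have, assuming that only the fields without defaults in the signature (n_epochs, data, data.directory, data.experiment and loss) must be supplied. The dataset file name and all values are hypothetical, and missing_required is an illustrative helper, not part of apax.

```python
# Hypothetical minimal training configuration, mirroring the nesting of config.yaml.
# File names and values are placeholders, not recommended settings.
config = {
    "n_epochs": 100,
    "data": {
        "directory": "models/",
        "experiment": "apax",
        "data_path": "dataset.extxyz",  # hypothetical dataset file
        "n_train": 1000,
        "n_valid": 100,
    },
    "loss": [
        {"name": "energy"},
        {"name": "forces", "weight": 4.0},  # weight chosen for illustration only
    ],
}

# Fields that have no default value in the documented Config/DataConfig signatures.
REQUIRED_TOP_LEVEL = ("n_epochs", "data", "loss")
REQUIRED_DATA = ("directory", "experiment")


def missing_required(cfg: dict) -> list:
    """Illustrative helper: list required keys that are absent from cfg."""
    missing = [key for key in REQUIRED_TOP_LEVEL if key not in cfg]
    data = cfg.get("data", {})
    missing += [f"data.{key}" for key in REQUIRED_DATA if key not in data]
    return missing


print(missing_required(config))  # []
```

Because Config takes keyword-only fields, passing such a dictionary as Config(**config) presumably lets the library perform the real validation; the helper above only illustrates which fields cannot be omitted.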
- class apax.config.train_config.DataConfig(*, directory: str, experiment: str, ds_type: Literal['cached', 'otf'] = 'cached', data_path: str | None = None, train_data_path: str | None = None, val_data_path: str | None = None, test_data_path: str | None = None, n_train: int = 1000, n_valid: int = 100, batch_size: int = 32, valid_batch_size: int = 100, shuffle_buffer_size: int = 1000, additional_properties_info: dict[str, str] = {}, shift_method: str = 'per_element_regression_shift', shift_options: dict = {'energy_regularisation': 1.0}, scale_method: str = 'per_element_force_rms_scale', scale_options: dict | None = {}, pos_unit: str | None = 'Ang', energy_unit: str | None = 'eV')
Configuration for data loading, preprocessing and training.
- Parameters:
directory (str, required) – Path to the directory where training results and checkpoints are written.
experiment (str, required) – Model name distinguishing this run from others in directory.
data_path (str, required if train_data_path and val_data_path are not specified) – Path to a single dataset file.
train_data_path (str, required if data_path is not specified) – Path to the training dataset.
val_data_path (str, required if data_path is not specified) – Path to the validation dataset.
test_data_path (str, optional) – Path to the test dataset.
n_train (int, default = 1000) – Number of training data points taken from data_path.
n_valid (int, default = 100) – Number of validation data points taken from data_path.
batch_size (int, default = 32) – Number of training examples evaluated at once.
valid_batch_size (int, default = 100) – Number of validation examples evaluated at once.
shuffle_buffer_size (int, default = 1000) – Size of the tf.data shuffle buffer.
additional_properties_info (dict, optional) – Dictionary of property name and shape ("ragged" or "fixed") pairs.
energy_regularisation (float, key in shift_options, default = 1.0) – Magnitude of the regularization in the per-element energy regression.
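The rule for choosing between a single data_path and separate training and validation files can be made explicit with a small sketch. resolve_data_source is a hypothetical helper, not an apax function; it assumes that a single file is split into n_train and n_valid structures, and it treats specifying neither or both variants as an error, which may differ from how apax actually handles the "both" case.

```python
def resolve_data_source(data: dict) -> dict:
    """Hypothetical helper encoding the documented rule: give either one
    data_path (split into n_train / n_valid structures) or both
    train_data_path and val_data_path."""
    single = data.get("data_path") is not None
    separate = (
        data.get("train_data_path") is not None
        and data.get("val_data_path") is not None
    )
    if single and not separate:
        return {
            "mode": "single",
            "path": data["data_path"],
            "n_train": data.get("n_train", 1000),  # documented default
            "n_valid": data.get("n_valid", 100),   # documented default
        }
    if separate and not single:
        return {
            "mode": "separate",
            "train": data["train_data_path"],
            "val": data["val_data_path"],
        }
    # Assumption: this sketch rejects "neither" and "both"; apax may treat "both" differently.
    raise ValueError("set either data_path or both train_data_path and val_data_path")


print(resolve_data_source({"data_path": "all.extxyz", "n_train": 800}))
```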
- class apax.config.train_config.ModelConfig(*, n_basis: int = 7, n_radial: int = 5, r_min: float = 0.5, r_max: float = 6.0, n_contr: int = -1, emb_init: str | None = 'uniform', nn: List[int] = [512, 512], b_init: Literal['normal', 'zeros'] = 'normal', use_zbl: bool = False, calc_stress: bool = False, descriptor_dtype: Literal['fp32', 'fp64'] = 'fp64', readout_dtype: Literal['fp32', 'fp64'] = 'fp32', scale_shift_dtype: Literal['fp32', 'fp64'] = 'fp32')
Configuration for the model.
- Parameters:
n_basis (PositiveInt, default = 7) – Number of uncontracted Gaussian basis functions.
n_radial (PositiveInt, default = 5) – Number of contracted basis functions.
r_min (NonNegativeFloat, default = 0.5) – Position of the first uncontracted basis function’s mean.
r_max (PositiveFloat, default = 6.0) – Cutoff radius of the descriptor.
nn (List[PositiveInt], default = [512, 512]) – Number of units in each hidden layer; the length of the list sets the number of hidden layers.
b_init (Literal["normal", "zeros"], default = "normal") – Initialization scheme for the neural network biases.
emb_init (Optional[str], default = "uniform") – Initialization scheme for embedding layer weights.
use_zbl (bool, default = False) – Whether to include the ZBL correction.
calc_stress (bool, default = False) – Whether to calculate stress during model evaluation.
descriptor_dtype (Literal["fp32", "fp64"], default = "fp64") – Data type for descriptor calculations.
readout_dtype (Literal["fp32", "fp64"], default = "fp32") – Data type for readout calculations.
scale_shift_dtype (Literal["fp32", "fp64"], default = "fp32") – Data type for scale and shift parameters.
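The relation between n_basis, n_radial, r_min and r_max can be illustrated with a simplified radial basis. This is a sketch under stated assumptions, not apax's exact formula: the Gaussian centres are assumed equally spaced starting at r_min, all functions share one hypothetical width tied to the spacing, and values beyond r_max are hard-truncated rather than smoothly switched off. The n_radial contracted functions are linear combinations of the n_basis primitives; in apax these contraction coefficients are element-dependent and learned (see emb_lr), whereas here they are fixed placeholders.

```python
import math


def gaussian_basis(r, n_basis=7, r_min=0.5, r_max=6.0):
    """Evaluate n_basis uncontracted Gaussians at interatomic distance r.

    Assumptions (illustrative only): equally spaced centres starting at r_min,
    one shared width, and hard truncation at the cutoff r_max.
    """
    if r >= r_max:
        return [0.0] * n_basis
    spacing = (r_max - r_min) / n_basis
    centres = [r_min + k * spacing for k in range(n_basis)]
    gamma = 1.0 / spacing**2  # hypothetical width choice
    return [math.exp(-gamma * (r - mu) ** 2) for mu in centres]


def contract(basis, coeffs):
    """Mix n_basis primitives into n_radial functions using an
    (n_radial x n_basis) coefficient matrix."""
    return [sum(c * g for c, g in zip(row, basis)) for row in coeffs]


n_basis, n_radial = 7, 5
coeffs = [[1.0 / n_basis] * n_basis for _ in range(n_radial)]  # hypothetical coefficients
radial = contract(gaussian_basis(1.2), coeffs)
print(len(radial))  # 5
```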
- class apax.config.train_config.OptimizerConfig(*, opt_name: str = 'adam', emb_lr: float = 0.02, nn_lr: float = 0.03, scale_lr: float = 0.001, shift_lr: float = 0.05, zbl_lr: float = 0.001, transition_begin: int = 0, opt_kwargs: dict = {}, sam_rho: float = 0.0)
Configuration of the optimizer. Learning rates of 0 will freeze the respective parameters.
- Parameters:
opt_name (str, default = "adam") – Name of the optimizer. Can be any optax optimizer.
emb_lr (NonNegativeFloat, default = 0.02) – Learning rate of the elemental embedding contraction coefficients.
nn_lr (NonNegativeFloat, default = 0.03) – Learning rate of the neural network parameters.
scale_lr (NonNegativeFloat, default = 0.001) – Learning rate of the elemental output scaling factors.
shift_lr (NonNegativeFloat, default = 0.05) – Learning rate of the elemental output shifts.
zbl_lr (NonNegativeFloat, default = 0.001) – Learning rate of the ZBL correction parameters.
transition_begin (int, default = 0) – Number of training steps (not epochs) before the start of the linear learning rate schedule.
opt_kwargs (dict, default = {}) – Optimizer keyword arguments. Passed to the optax optimizer.
sam_rho (NonNegativeFloat, default = 0.0) – Rho parameter for Sharpness-Aware Minimization.
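The effect of transition_begin and of a zero learning rate can be sketched with a plain linear schedule. The end learning rate and the number of transition steps below are hypothetical, since they are not specified here. The function holds the initial rate for transition_begin optimizer steps and then interpolates linearly. A learning rate of 0 stays 0 at every step, which is what freezes a parameter group.

```python
def linear_lr(step, init_lr, end_lr, transition_begin, transition_steps):
    """Learning rate at a given optimizer step (not epoch).

    Held at init_lr for the first transition_begin steps, then interpolated
    linearly to end_lr over transition_steps steps and held there.
    """
    if step <= transition_begin:
        return init_lr
    frac = min((step - transition_begin) / transition_steps, 1.0)
    return init_lr + frac * (end_lr - init_lr)


# nn_lr default of 0.03, hypothetically decayed to 0 over 1000 steps after a 100-step delay.
for step in (0, 100, 600, 1100, 5000):
    print(step, linear_lr(step, 0.03, 0.0, transition_begin=100, transition_steps=1000))

# A learning rate of 0 never changes, so the corresponding parameters are frozen.
print(linear_lr(600, 0.0, 0.0, transition_begin=100, transition_steps=1000))
```

optax offers an equivalent schedule, optax.linear_schedule(init_value, end_value, transition_steps, transition_begin). Whether apax builds its schedule with that function, and which end value and transition length it uses, is an assumption of this sketch.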
- class apax.config.train_config.MetricsConfig(*, name: str, reductions: List[str])
Configuration for the metrics collected during training.
- Parameters:
name (str) – Keyword of the quantity, e.g., ‘energy’.
reductions (List[str]) – List of reductions performed on the difference between target and predictions. Can be ‘mae’, ‘mse’, ‘rmse’ for energies and forces. For forces, ‘angle’ can also be used.
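The listed reductions can be sketched on plain Python lists. apax works on arrays, and its exact definition of the force "angle" metric (unit, averaging) is not given here. This sketch assumes degrees and a plain mean over atoms, and it does not handle zero-length force vectors.

```python
import math


def reduce_errors(pred, target, reduction):
    """Reduce the element-wise differences pred - target to one number."""
    diffs = [p - t for p, t in zip(pred, target)]
    if reduction == "mae":
        return sum(abs(d) for d in diffs) / len(diffs)
    if reduction == "mse":
        return sum(d * d for d in diffs) / len(diffs)
    if reduction == "rmse":
        return math.sqrt(sum(d * d for d in diffs) / len(diffs))
    raise ValueError(f"unknown reduction: {reduction!r}")


def mean_force_angle(pred_forces, target_forces):
    """Mean angle in degrees between predicted and reference force vectors.

    Assumption: degrees and a plain mean over atoms; apax may differ.
    """
    angles = []
    for p, t in zip(pred_forces, target_forces):
        dot = sum(a * b for a, b in zip(p, t))
        norm = math.sqrt(sum(a * a for a in p)) * math.sqrt(sum(b * b for b in t))
        cos = max(-1.0, min(1.0, dot / norm))  # clamp against rounding error
        angles.append(math.degrees(math.acos(cos)))
    return sum(angles) / len(angles)


print(reduce_errors([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], "mae"))  # 1.0
print(mean_force_angle([(1.0, 0.0, 0.0)], [(0.0, 1.0, 0.0)]))
```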
- class apax.config.train_config.LossConfig(*, name: str, loss_type: str = 'mse', weight: float = 1.0, atoms_exponent: float = 1, parameters: dict = {})
Configuration of the loss functions used during training.
- Parameters:
name (str) – Keyword of the quantity, e.g., ‘energy’.
loss_type (str, optional) – Weighting scheme for atomic contributions; see the MLIP package (DOI 10.1088/2632-2153/abc9fe) for details. By default "mse".
weight (NonNegativeFloat, optional) – Weighting factor in the overall loss function, by default 1.0.
atoms_exponent (NonNegativeFloat, optional) – Exponent for atomic contributions weighting, by default 1.
parameters (dict, optional) – Additional parameters for configuring the loss function, by default {}.
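One plausible reading of weight and atoms_exponent for an energy loss is sketched below. This is an assumption, not apax's code: each structure's squared energy error is divided by n_atoms**atoms_exponent, the result is averaged over the batch, and the average is scaled by weight. With atoms_exponent = 0 errors are left unnormalized, while 1 normalizes by system size.

```python
def weighted_energy_loss(pred, target, n_atoms, weight=1.0, atoms_exponent=1.0):
    """Hypothetical energy loss: squared error per structure divided by
    n_atoms**atoms_exponent, averaged over the batch, scaled by weight."""
    terms = [
        (p - t) ** 2 / n ** atoms_exponent
        for p, t, n in zip(pred, target, n_atoms)
    ]
    return weight * sum(terms) / len(terms)


# Two hypothetical structures with 4 and 1 atoms and energy errors of 2 and 1.
print(weighted_energy_loss([2.0, 1.0], [0.0, 0.0], [4, 1]))                     # 1.0
print(weighted_energy_loss([2.0, 1.0], [0.0, 0.0], [4, 1], atoms_exponent=0.0))  # 2.5
```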
- class apax.config.train_config.CheckpointConfig(*, ckpt_interval: int = 1, base_model_checkpoint: str | None = None, reset_layers: List[str] = [])
Checkpoint configuration.
- Parameters:
ckpt_interval (int, default = 1) – Number of epochs between checkpoints.
base_model_checkpoint (str, optional) – Path to the folder containing a pre-trained model checkpoint.
reset_layers (List[str], default = []) – List of layer names whose parameters will be reinitialized.
- class apax.config.train_config.TrainProgressbarConfig(*, disable_epoch_pbar: bool = False, disable_batch_pbar: bool = True)
Configuration of progress bars.
- Parameters:
disable_epoch_pbar (bool, default = False) – Set to True to disable the epoch progress bar.
disable_batch_pbar (bool, default = True) – Set to True to disable the batch progress bar.
- class apax.config.train_config.CSVCallback(*, name: Literal['csv'])
Configuration of the CSVCallback.
- Parameters:
name (Literal["csv"]) – Keyword of the callback used.
Molecular Dynamics Configuration
- class apax.config.md_config.MDConfig(*, seed: int = 1, ensemble: NVEOptions | NVTOptions | NPTOptions = NVTOptions(dt=0.5, name='nvt', temperature=298.15, thermostat_chain=NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=100)), duration: float, n_inner: int = 100, sampling_rate: int = 10, buffer_size: int = 100, dr_threshold: float = 0.5, extra_capacity: int = 0, initial_structure: str, load_momenta: bool = False, sim_dir: str = '.', traj_name: str = 'md.h5', restart: bool = True, checkpoint_interval: int = 50000, disable_pbar: bool = False)
Configuration for a Nosé-Hoover chain (NHC) molecular dynamics simulation.
- Parameters:
seed (int, default = 1) – Random seed for momentum initialization.
temperature (float, default = 298.15) – Temperature of the simulation in Kelvin. Set through ensemble.
dt (float, default = 0.5) – Time step in fs. Set through ensemble.
duration (float, required) – Total simulation time in fs.
n_inner (int, default = 100) – Number of compiled simulation steps, i.e. the number of iterations of the jax.lax.fori_loop loop. Also determines the atoms buffer size.
sampling_rate (int, default = 10) – Interval between saving frames.
buffer_size (int, default = 100) – Number of collected frames to be dumped at once.
dr_threshold (float, default = 0.5) – Skin of the neighbor list.
extra_capacity (int, default = 0) – JaxMD allocates a maximal number of neighbors. This argument lets you add additional capacity to avoid recompilation. The default is usually fine.
initial_structure (str, required) – Path to the starting structure of the simulation.
sim_dir (str, default = ".") – Directory where simulation files will be stored.
traj_name (str, default = "md.h5") – Name of the trajectory file.
restart (bool, default = True) – Whether the simulation should restart from the latest configuration in traj_name.
checkpoint_interval (int, default = 50_000) – Number of time steps between checkpoints of the full simulation state. These are loaded when restart is enabled.
disable_pbar (bool, default = False) – Disables the MD progress bar.
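The interplay of duration, dt, n_inner, sampling_rate, buffer_size and checkpoint_interval is simple arithmetic, sketched below with the documented defaults. md_bookkeeping is a hypothetical helper. It assumes that sampling_rate and checkpoint_interval count integration steps and that partial inner loops or buffers are rounded up; apax's exact handling of remainders may differ.

```python
import math


def md_bookkeeping(duration, dt=0.5, n_inner=100, sampling_rate=10,
                   buffer_size=100, checkpoint_interval=50_000):
    """Rough step and frame counts implied by an MDConfig (hypothetical helper).

    Assumptions: intervals count integration steps; partial loops/buffers round up.
    """
    n_steps = round(duration / dt)                  # total integration steps
    n_outer = math.ceil(n_steps / n_inner)          # calls of the compiled fori_loop
    n_frames = n_steps // sampling_rate             # frames saved to traj_name
    n_dumps = math.ceil(n_frames / buffer_size)     # buffer flushes to disk
    n_checkpoints = n_steps // checkpoint_interval  # full-state checkpoints
    return {"n_steps": n_steps, "n_outer": n_outer, "n_frames": n_frames,
            "n_dumps": n_dumps, "n_checkpoints": n_checkpoints}


# A hypothetical 100 ps (100_000 fs) run with the default 0.5 fs time step.
print(md_bookkeeping(100_000))
# {'n_steps': 200000, 'n_outer': 2000, 'n_frames': 20000, 'n_dumps': 200, 'n_checkpoints': 4}
```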
- class apax.config.md_config.NPTOptions(*, dt: float = 0.5, name: Literal['npt'], temperature: float = 298.15, thermostat_chain: NHCOptions = NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=100), pressure: float = 1.01325, barostat_chain: NHCOptions = NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=1000.0))
Options for NPT ensemble simulations.
- Parameters:
name (Literal["npt"]) – Name of the ensemble.
pressure (PositiveFloat, default = 1.01325) – Pressure in bar.
barostat_chain (NHCOptions, default = NHCOptions(tau=1000)) – Barostat chain options.
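Since pressure is given in bar while positions and energies default to Å and eV (see pos_unit and energy_unit in DataConfig), the conversion to eV/Å^3 is sometimes useful for checking stress output. Whether apax performs this conversion internally is not stated here. The sketch only shows the unit arithmetic, based on the exact SI value of the electronvolt.

```python
# 1 eV = 1.602176634e-19 J (exact SI value) and 1 Å^3 = 1e-30 m^3,
# so 1 eV/Å^3 = 1.602176634e11 Pa = 1.602176634e6 bar.
EV_PER_ANG3_IN_BAR = 1.602176634e6


def bar_to_ev_per_ang3(p_bar):
    """Convert a pressure from bar to eV/Å^3."""
    return p_bar / EV_PER_ANG3_IN_BAR


def ev_per_ang3_to_bar(p):
    """Convert a pressure from eV/Å^3 to bar."""
    return p * EV_PER_ANG3_IN_BAR


print(bar_to_ev_per_ang3(1.01325))  # about 6.32e-7 eV/Å^3 for the default pressure
```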