Configuration

apax.config

Training Configuration

class apax.config.train_config.Config(*, n_epochs: int, patience: int | None = None, seed: int = 1, n_models: int = 1, n_jitted_steps: int = 1, data_parallel: bool = True, data: DataConfig, model: ModelConfig = ModelConfig(n_basis=7, n_radial=5, r_min=0.5, r_max=6.0, n_contr=-1, emb_init='uniform', nn=[512, 512], b_init='normal', use_zbl=False, calc_stress=False, descriptor_dtype='fp64', readout_dtype='fp32', scale_shift_dtype='fp32'), metrics: List[MetricsConfig] = [], loss: List[LossConfig], optimizer: OptimizerConfig = OptimizerConfig(opt_name='adam', emb_lr=0.02, nn_lr=0.03, scale_lr=0.001, shift_lr=0.05, zbl_lr=0.001, transition_begin=0, opt_kwargs={}, sam_rho=0.0), callbacks: List[CSVCallback | TBCallback | MLFlowCallback] = [CSVCallback(name='csv')], progress_bar: TrainProgressbarConfig = TrainProgressbarConfig(disable_epoch_pbar=False, disable_batch_pbar=True), checkpoints: CheckpointConfig = CheckpointConfig(ckpt_interval=1, base_model_checkpoint=None, reset_layers=[]))

Main configuration of an apax training run. Parameters that are themselves config classes are generated by parsing the config.yaml file and are specified as nested sections, as shown here:

Example

data:
    directory: models/
    experiment: apax
    ...
Parameters:
  • n_epochs (int, required) –

    Number of training epochs.

  • patience (int, optional) –

    Number of epochs without improvement before training is terminated.

  • seed (int, default = 1) –

    Random seed.

  • n_models (int, default = 1) –

    Number of models to be trained at once.

  • n_jitted_steps (int, default = 1) –

    Number of training batches processed in a single compiled loop.
    Can yield significant speedups for small structures or small batch sizes.

  • data (DataConfig) –

    Data configuration.

  • model (ModelConfig) –

    Model configuration.

  • metrics (List of MetricsConfig) –

    Metrics configuration.

  • loss (List of LossConfig) –

    Loss configuration.

  • optimizer (OptimizerConfig) –

    Optimizer configuration.

  • callbacks (List of callback configs) –

    Possible callbacks are CSVCallback, TBCallback and MLFlowCallback.

  • progress_bar (TrainProgressbarConfig) –

    Progressbar configuration.

  • checkpoints (CheckpointConfig) –

    Checkpoint configuration.

  • data_parallel (bool, default = True) –

    Automatically uses all available GPUs for data parallel training.
    Set to false to force single device training.
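Because nested sections of config.yaml map onto the config classes documented here, a training configuration can be pictured as a nested mapping. The sketch below is plain Python so it runs without apax installed: it builds such a mapping and checks only the fields that have no default in the signatures (n_epochs, data.directory, data.experiment, loss, and each loss entry's name). The helper missing_required, the dataset file name and the loss weights are hypothetical; the real validation is done by the config classes when config.yaml is parsed.

```python
# Hypothetical sketch: a config.yaml expressed as a nested Python mapping.
# Field names follow the Config/DataConfig/LossConfig signatures;
# the values (paths, epoch count, weights) are illustrative only.
config = {
    "n_epochs": 100,
    "data": {
        "directory": "models/",
        "experiment": "apax",
        "data_path": "dataset.extxyz",  # hypothetical file name
    },
    "loss": [
        {"name": "energy"},
        {"name": "forces", "weight": 4.0},  # illustrative weight
    ],
}

# Fields without a default value in the signatures above.
REQUIRED_TOP_LEVEL = ("n_epochs", "data", "loss")
REQUIRED_DATA = ("directory", "experiment")


def missing_required(cfg: dict) -> list[str]:
    """Return dotted names of required fields that are absent (hypothetical helper)."""
    missing = [key for key in REQUIRED_TOP_LEVEL if key not in cfg]
    data = cfg.get("data", {})
    missing += [f"data.{key}" for key in REQUIRED_DATA if key not in data]
    for i, item in enumerate(cfg.get("loss", [])):
        if "name" not in item:
            missing.append(f"loss[{i}].name")
    return missing


print(missing_required(config))  # []
print(missing_required({"n_epochs": 10, "data": {"directory": "models/"}}))
# ['loss', 'data.experiment']
```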

dump_config(save_path)

Writes the current config file to the specified directory.

Parameters:

save_path – Path to the directory.

class apax.config.train_config.DataConfig(*, directory: str, experiment: str, ds_type: Literal['cached', 'otf'] = 'cached', data_path: str | None = None, train_data_path: str | None = None, val_data_path: str | None = None, test_data_path: str | None = None, n_train: int = 1000, n_valid: int = 100, batch_size: int = 32, valid_batch_size: int = 100, shuffle_buffer_size: int = 1000, additional_properties_info: dict[str, str] = {}, shift_method: str = 'per_element_regression_shift', shift_options: dict = {'energy_regularisation': 1.0}, scale_method: str = 'per_element_force_rms_scale', scale_options: dict | None = {}, pos_unit: str | None = 'Ang', energy_unit: str | None = 'eV')

Configuration for data loading, preprocessing and training.

Parameters:
  • directory (str, required) –

    Path to directory where training results and checkpoints are written.

  • experiment (str, required) –

    Model name that distinguishes this run from others in directory.

  • data_path (str, required if train_data_path and val_data_path are not specified) –

    Path to single dataset file.

  • train_data_path (str, required if data_path is not specified) –

    Path to training dataset.

  • val_data_path (str, required if data_path is not specified) –

    Path to validation dataset.

  • test_data_path (str, optional) –

    Path to test dataset.

  • n_train (int, default = 1000) –

    Number of training datapoints from data_path.

  • n_valid (int, default = 100) –

    Number of validation datapoints from data_path.

  • batch_size (int, default = 32) –

    Number of training examples to be evaluated at once.

  • valid_batch_size (int, default = 100) –

    Number of validation examples to be evaluated at once.

  • shuffle_buffer_size (int, default = 1000) –

    Size of the tf.data shuffle buffer.

  • additional_properties_info (dict, optional) –

    Dict mapping each additional property name to its shape type (ragged or fixed).

  • energy_regularisation (float, key of shift_options, default = 1.0) –

    Magnitude of the regularization in the per-element energy regression.
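To illustrate what energy_regularisation controls, the sketch below fits one energy shift per element so that a structure's energy is approximated by the sum over its atoms of their element's shift, using ridge regression with energy_regularisation as the ridge strength. Treating the per-element regression as ridge regression is an assumption here; apax's per_element_regression_shift may differ in details. The molecules and energies are made up for illustration.

```python
# Minimal sketch of a regularised per-element energy regression, assuming
# E_total ≈ sum_Z n_Z * shift_Z fitted by ridge regression. It illustrates
# the role of energy_regularisation and may differ from apax's implementation.


def per_element_shifts(counts, energies, elements, reg=1.0):
    """Solve (X^T X + reg * I) w = X^T y for the per-element shifts w."""
    X = [[c.get(z, 0) for z in elements] for c in counts]  # element counts
    k = len(elements)
    # Normal equations with the ridge term added on the diagonal.
    A = [[sum(row[i] * row[j] for row in X) + (reg if i == j else 0.0)
          for j in range(k)] for i in range(k)]
    b = [sum(row[i] * y for row, y in zip(X, energies)) for i in range(k)]
    # Gauss-Jordan elimination with partial pivoting on the augmented matrix.
    M = [A[i] + [b[i]] for i in range(k)]
    for col in range(k):
        pivot = max(range(col, k), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(k):
            if r != col:
                factor = M[r][col] / M[col][col]
                M[r] = [a - factor * p for a, p in zip(M[r], M[col])]
    # M is now diagonal; divide out to get each element's shift.
    return {z: M[i][k] / M[i][i] for i, z in enumerate(elements)}


structures = [{1: 2, 8: 1}, {1: 2}, {8: 2}]  # H2O, H2, O2 as atomic-number counts
energies = [-76.0, -1.0, -150.0]             # illustrative energies in eV
print(per_element_shifts(structures, energies, [1, 8], reg=0.0))  # ≈ {1: -0.5, 8: -75.0}
print(per_element_shifts(structures, energies, [1, 8], reg=1.0))  # ≈ {1: -3.44, 8: -61.52}
```

A larger regularisation shrinks the overall size of the fitted shifts, which can stabilise the fit when some elements are rare in the training data.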

class apax.config.train_config.ModelConfig(*, n_basis: int = 7, n_radial: int = 5, r_min: float = 0.5, r_max: float = 6.0, n_contr: int = -1, emb_init: str | None = 'uniform', nn: List[int] = [512, 512], b_init: Literal['normal', 'zeros'] = 'normal', use_zbl: bool = False, calc_stress: bool = False, descriptor_dtype: Literal['fp32', 'fp64'] = 'fp64', readout_dtype: Literal['fp32', 'fp64'] = 'fp32', scale_shift_dtype: Literal['fp32', 'fp64'] = 'fp32')

Configuration for the model.

Parameters:
  • n_basis (PositiveInt, default = 7) – Number of uncontracted Gaussian basis functions.

  • n_radial (PositiveInt, default = 5) – Number of contracted basis functions.

  • r_min (NonNegativeFloat, default = 0.5) – Position of the first uncontracted basis function’s mean.

  • r_max (PositiveFloat, default = 6.0) – Cutoff radius of the descriptor.

  • nn (List[PositiveInt], default = [512, 512]) – Hidden layer sizes of the neural network: the list length sets the number of hidden layers and each entry the number of units in that layer.

  • b_init (Literal["normal", "zeros"], default = "normal") – Initialization scheme for the neural network biases.

  • emb_init (Optional[str], default = "uniform") – Initialization scheme for embedding layer weights.

  • use_zbl (bool, default = False) – Whether to include the ZBL correction.

  • calc_stress (bool, default = False) – Whether to calculate stress during model evaluation.

  • descriptor_dtype (Literal["fp32", "fp64"], default = "fp64") – Data type for descriptor calculations.

  • readout_dtype (Literal["fp32", "fp64"], default = "fp32") – Data type for readout calculations.

  • scale_shift_dtype (Literal["fp32", "fp64"], default = "fp32") – Data type for scale and shift parameters.
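To make the roles of n_basis, n_radial, r_min and r_max concrete, the sketch below evaluates n_basis Gaussians whose means are spaced evenly from r_min to r_max, multiplies them by a cosine cutoff that vanishes at r_max, and contracts them into n_radial functions with a coefficient matrix. The Gaussian width, the cosine cutoff and the uniformly drawn coefficients (standing in for emb_init = "uniform") are assumptions for illustration; apax's descriptor may differ in normalization and uses learned, element-dependent coefficients.

```python
import math
import random


def radial_functions(r, n_basis=7, n_radial=5, r_min=0.5, r_max=6.0, seed=0):
    """Contract n_basis Gaussians into n_radial radial functions at distance r.

    Assumes n_basis >= 2 so that the means have a non-zero spacing.
    """
    # Evenly spaced Gaussian means from r_min to r_max (assumed layout).
    step = (r_max - r_min) / (n_basis - 1)
    means = [r_min + i * step for i in range(n_basis)]
    sigma = step  # assumed width: equal to the spacing of the means
    # Cosine cutoff: 1 at r = 0, falling smoothly to 0 at r = r_max.
    cutoff = 0.5 * (math.cos(math.pi * r / r_max) + 1.0) if r < r_max else 0.0
    basis = [math.exp(-((r - mu) ** 2) / (2 * sigma**2)) * cutoff for mu in means]
    # Contraction coefficients; uniform random init stands in for emb_init.
    rng = random.Random(seed)
    coeffs = [[rng.uniform(-1.0, 1.0) for _ in range(n_basis)]
              for _ in range(n_radial)]
    return [sum(c * g for c, g in zip(row, basis)) for row in coeffs]


print(len(radial_functions(2.0)))  # 5 contracted functions (n_radial)
print(radial_functions(6.5))       # all zeros: beyond the r_max cutoff
```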

class apax.config.train_config.OptimizerConfig(*, opt_name: str = 'adam', emb_lr: float = 0.02, nn_lr: float = 0.03, scale_lr: float = 0.001, shift_lr: float = 0.05, zbl_lr: float = 0.001, transition_begin: int = 0, opt_kwargs: dict = {}, sam_rho: float = 0.0)

Configuration of the optimizer. Learning rates of 0 will freeze the respective parameters.

Parameters:
  • opt_name (str, default = "adam") – Name of the optimizer. Can be any optax optimizer.

  • emb_lr (NonNegativeFloat, default = 0.02) – Learning rate of the elemental embedding contraction coefficients.

  • nn_lr (NonNegativeFloat, default = 0.03) – Learning rate of the neural network parameters.

  • scale_lr (NonNegativeFloat, default = 0.001) – Learning rate of the elemental output scaling factors.

  • shift_lr (NonNegativeFloat, default = 0.05) – Learning rate of the elemental output shifts.

  • zbl_lr (NonNegativeFloat, default = 0.001) – Learning rate of the ZBL correction parameters.

  • transition_begin (int, default = 0) – Number of training steps (not epochs) before the start of the linear learning rate schedule.

  • opt_kwargs (dict, default = {}) – Optimizer keyword arguments. Passed to the optax optimizer.

  • sam_rho (NonNegativeFloat, default = 0.0) – Rho parameter for Sharpness-Aware Minimization.
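The transition_begin option implies a linear learning rate schedule that holds each learning rate constant for the first transition_begin steps and then changes it linearly. The sketch below mirrors the semantics of optax.linear_schedule in plain Python; the end value and transition_steps are hypothetical, since they are not among the documented options. It also shows why a learning rate of 0 freezes a parameter group: a plain gradient step scaled by 0 leaves the parameters unchanged.

```python
def linear_schedule(init_value, end_value, transition_steps, transition_begin=0):
    """Return lr(step): constant until transition_begin, then linear to end_value.

    Mirrors the behaviour of optax.linear_schedule.
    """
    def schedule(step):
        count = min(max(step - transition_begin, 0), transition_steps)
        frac = 1.0 - count / transition_steps
        return (init_value - end_value) * frac + end_value
    return schedule


# Hypothetical numbers: nn_lr = 0.03 decays to 0.03e-3 over 10_000 steps
# after a 1_000-step delay (transition_begin = 1_000).
nn_schedule = linear_schedule(0.03, 0.03e-3, 10_000, transition_begin=1_000)
for step in (0, 1_000, 6_000, 11_000, 20_000):
    print(step, nn_schedule(step))


def sgd_step(params, grads, lr):
    """Plain gradient descent step; lr = 0 leaves the parameters unchanged."""
    return [p - lr * g for p, g in zip(params, grads)]


print(sgd_step([1.0, 2.0], [0.5, -0.5], lr=0.0))  # [1.0, 2.0]
```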

class apax.config.train_config.MetricsConfig(*, name: str, reductions: List[str])

Configuration for the metrics collected during training.

Parameters:
  • name (str) – Keyword of the quantity, e.g., ‘energy’.

  • reductions (List[str]) – List of reductions applied to the difference between targets and predictions. Can be ‘mae’, ‘mse’ or ‘rmse’ for energies and forces; for forces, ‘angle’ can also be used.
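For reference, the sketch below spells out the standard definitions behind the reduction names, applied to flat lists of values (energies, or flattened force components) and, for ‘angle’, to pairs of 3D force vectors. Reporting the angle in radians is an assumption here; apax may use a different unit or averaging.

```python
import math


def mae(pred, target):
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def mse(pred, target):
    """Mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def rmse(pred, target):
    """Root mean squared error."""
    return math.sqrt(mse(pred, target))


def mean_angle(pred_vecs, target_vecs):
    """Mean angle (radians) between predicted and target force vectors."""
    angles = []
    for p, t in zip(pred_vecs, target_vecs):
        dot = sum(a * b for a, b in zip(p, t))
        norm = math.sqrt(sum(a * a for a in p)) * math.sqrt(sum(b * b for b in t))
        # Clamp to [-1, 1] so rounding errors cannot push acos out of its domain.
        angles.append(math.acos(max(-1.0, min(1.0, dot / norm))))
    return sum(angles) / len(angles)


energies_pred, energies_ref = [1.0, 2.0], [1.0, 4.0]
print(mae(energies_pred, energies_ref))   # 1.0
print(mse(energies_pred, energies_ref))   # 2.0
print(rmse(energies_pred, energies_ref))  # 1.4142135623730951
print(mean_angle([(1.0, 0.0, 0.0)], [(0.0, 1.0, 0.0)]))  # 1.5707963267948966 (pi / 2)
```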

class apax.config.train_config.LossConfig(*, name: str, loss_type: str = 'mse', weight: float = 1.0, atoms_exponent: float = 1, parameters: dict = {})

Configuration of the loss functions used during training.

Parameters:
  • name (str) – Keyword of the quantity, e.g., ‘energy’.

  • loss_type (str, optional) – Weighting scheme for atomic contributions, by default “mse”. See the MLIP package (DOI 10.1088/2632-2153/abc9fe) for details.

  • weight (NonNegativeFloat, optional) – Weighting factor in the overall loss function, by default 1.0.

  • atoms_exponent (NonNegativeFloat, optional) – Exponent for atomic contributions weighting, by default 1.

  • parameters (dict, optional) – Additional parameters for configuring the loss function, by default {}.
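The sketch below shows one plausible reading of weight and atoms_exponent, assumed here rather than taken from apax: each structure's squared energy error is divided by n_atoms ** atoms_exponent, so atoms_exponent = 0 gives a plain MSE and larger values down-weight large structures, and each loss term is then scaled by its weight before the terms are summed. The force-loss value in the example is hypothetical.

```python
def energy_loss(pred, target, n_atoms, atoms_exponent=1.0):
    """Mean squared energy error, each structure divided by n_atoms**atoms_exponent."""
    terms = [(p - t) ** 2 / n ** atoms_exponent
             for p, t, n in zip(pred, target, n_atoms)]
    return sum(terms) / len(terms)


def total_loss(weighted_terms):
    """Combine (weight, loss_value) pairs into the overall training loss."""
    return sum(weight * value for weight, value in weighted_terms)


pred, ref, n_atoms = [-10.0, -21.0], [-11.0, -20.0], [2, 4]
print(energy_loss(pred, ref, n_atoms, atoms_exponent=1.0))  # 0.375
print(energy_loss(pred, ref, n_atoms, atoms_exponent=0.0))  # 1.0
# Hypothetical weights: energy term 1.0, force term 4.0 with an assumed force loss of 0.1.
print(total_loss([(1.0, 0.375), (4.0, 0.1)]))
```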


class apax.config.train_config.CheckpointConfig(*, ckpt_interval: int = 1, base_model_checkpoint: str | None = None, reset_layers: List[str] = [])

Checkpoint configuration.

Parameters:
  • ckpt_interval (int, default = 1) – Number of epochs between checkpoints.

  • base_model_checkpoint (str, optional) – Path to the folder containing a pre-trained model checkpoint.

  • reset_layers (List[str], default = []) – List of layer names whose parameters will be reinitialized.

class apax.config.train_config.TrainProgressbarConfig(*, disable_epoch_pbar: bool = False, disable_batch_pbar: bool = True)

Configuration of progress bars.

Parameters:
  • disable_epoch_pbar (bool, default = False) – Set to True to disable the epoch progress bar.

  • disable_batch_pbar (bool, default = True) – Set to True to disable the batch progress bar.

class apax.config.train_config.CSVCallback(*, name: Literal['csv'])

Configuration of the CSVCallback.

Parameters:

name (Literal['csv']) – Keyword of the callback used.

class apax.config.train_config.TBCallback(*, name: Literal['tensorboard'])

Configuration of the TensorBoard callback.

Parameters:

name (Literal['tensorboard']) – Keyword of the callback used.

class apax.config.train_config.MLFlowCallback(*, name: Literal['mlflow'], experiment: str)

Configuration of the MLFlow callback.

Parameters:
  • name (Literal['mlflow']) – Keyword of the callback used.

  • experiment (str) – Path to the MLFlow experiment, e.g. /Users/<user>/<my_experiment>.
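The callbacks option is a list whose entries are selected by their name keyword. The sketch below writes such a list as Python dicts and checks it against the names allowed by the Literal types above (csv, tensorboard, mlflow), including the experiment field that the MLFlow callback requires. The helper check_callbacks is hypothetical; in apax this checking is done by the config classes themselves.

```python
# Allowed callback names, taken from the Literal types in the signatures above.
CALLBACK_NAMES = {"csv", "tensorboard", "mlflow"}


def check_callbacks(callbacks):
    """Return a list of problems found in a callbacks section (hypothetical helper)."""
    problems = []
    for i, cb in enumerate(callbacks):
        name = cb.get("name")
        if name not in CALLBACK_NAMES:
            problems.append(f"callbacks[{i}]: unknown name {name!r}")
        elif name == "mlflow" and "experiment" not in cb:
            problems.append(f"callbacks[{i}]: mlflow requires 'experiment'")
    return problems


callbacks = [
    {"name": "csv"},
    {"name": "tensorboard"},
    {"name": "mlflow", "experiment": "/Users/<user>/<my_experiment>"},
]
print(check_callbacks(callbacks))  # []
print(check_callbacks([{"name": "mlflow"}, {"name": "wandb"}]))
```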

Molecular Dynamics Configuration

class apax.config.md_config.MDConfig(*, seed: int = 1, ensemble: NVEOptions | NVTOptions | NPTOptions = NVTOptions(dt=0.5, name='nvt', temperature=298.15, thermostat_chain=NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=100)), duration: float, n_inner: int = 100, sampling_rate: int = 10, buffer_size: int = 100, dr_threshold: float = 0.5, extra_capacity: int = 0, initial_structure: str, load_momenta: bool = False, sim_dir: str = '.', traj_name: str = 'md.h5', restart: bool = True, checkpoint_interval: int = 50000, disable_pbar: bool = False)

Configuration for a Nosé-Hoover chain (NHC) molecular dynamics simulation.

Parameters:
  • seed (int, default = 1) –

    Random seed for momentum initialization.

  • temperature (float, default = 298.15) –

    Temperature of the simulation in Kelvin. Set within the ensemble options.

  • dt (float, default = 0.5) –

    Time step in fs. Set within the ensemble options.

  • duration (float, required) –

    Total simulation time in fs.

  • n_inner (int, default = 100) –

    Number of compiled simulation steps (i.e. number of iterations of the
    jax.lax.fori_loop loop). Also determines atoms buffer size.

  • sampling_rate (int, default = 10) –

    Interval between saving frames.

  • buffer_size (int, default = 100) –

    Number of collected frames to be dumped at once.

  • dr_threshold (float, default = 0.5) –

    Skin of the neighborlist.

  • extra_capacity (int, default = 0) –

    JAX MD allocates a fixed maximum number of neighbors. This argument adds
    extra capacity to avoid recompilation. The default is usually fine.

  • initial_structure (str, required) –

    Path to the starting structure of the simulation.

  • sim_dir (str, default = ".") –

    Directory where simulation files will be stored.

  • traj_name (str, default = "md.h5") –

    Name of the trajectory file.

  • restart (bool, default = True) –

    Whether the simulation should restart from the latest configuration in
    traj_name.

  • checkpoint_interval (int, default = 50_000) –

    Number of time steps between saving full simulation state checkpoints.
    These will be loaded with the restart option.

  • disable_pbar (bool, default = False) –

    Disables the MD progress bar.
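Several of these options interact through simple arithmetic. The sketch below derives the number of integration steps, compiled inner loops, saved frames, buffer flushes and full-state checkpoints for a given duration. It assumes that sampling_rate and checkpoint_interval are counted in time steps and that remainders are simply truncated; apax may handle partial chunks differently.

```python
def md_bookkeeping(duration, dt=0.5, n_inner=100, sampling_rate=10,
                   buffer_size=100, checkpoint_interval=50_000):
    """Derive step and I/O counts from MDConfig-style options (assumed semantics)."""
    n_steps = round(duration / dt)          # total integration steps
    n_outer = n_steps // n_inner            # compiled fori_loop chunks
    n_frames = n_steps // sampling_rate     # frames written to traj_name
    n_dumps = n_frames // buffer_size       # buffer flushes to disk
    n_checkpoints = n_steps // checkpoint_interval
    return {
        "n_steps": n_steps,
        "n_outer": n_outer,
        "n_frames": n_frames,
        "n_dumps": n_dumps,
        "n_checkpoints": n_checkpoints,
    }


# 100 ps (100_000 fs) with all other options at their defaults.
print(md_bookkeeping(duration=100_000))
# {'n_steps': 200000, 'n_outer': 2000, 'n_frames': 20000, 'n_dumps': 200, 'n_checkpoints': 4}
```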

dump_config()

Writes the current config file to the MD directory.

class apax.config.md_config.NPTOptions(*, dt: float = 0.5, name: Literal['npt'], temperature: float = 298.15, thermostat_chain: NHCOptions = NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=100), pressure: float = 1.01325, barostat_chain: NHCOptions = NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=1000.0))

Options for NPT ensemble simulations.

Parameters:
  • name (Literal["npt"]) – Name of the ensemble.

  • pressure (PositiveFloat, default = 1.01325) – Pressure in bar.

  • barostat_chain (NHCOptions, default = NHCOptions(tau=1000)) – Barostat chain options.
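The pressure is given in bar, while simulations working in eV and Å express pressure in eV/Å³. Whether and where apax performs this conversion internally is an assumption here; the sketch only shows the unit arithmetic, using the exact SI value of the electronvolt. The default of 1.01325 bar (one standard atmosphere) corresponds to roughly 6.32e-7 eV/Å³.

```python
EV_IN_JOULE = 1.602176634e-19  # exact SI definition of the electronvolt
ANGSTROM_IN_METER = 1e-10
BAR_IN_PASCAL = 1e5


def bar_to_ev_per_cubic_angstrom(pressure_bar):
    """Convert a pressure in bar to eV/Å^3."""
    ev_per_a3_in_pascal = EV_IN_JOULE / ANGSTROM_IN_METER**3  # ≈ 1.602e11 Pa
    return pressure_bar * BAR_IN_PASCAL / ev_per_a3_in_pascal


print(bar_to_ev_per_cubic_angstrom(1.01325))  # ≈ 6.324e-07
```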