Configuration

apax.config

Training Configuration

class apax.config.train_config.Config(*, n_epochs: int, patience: int | None = None, seed: int = 1, n_jitted_steps: int = 1, data_parallel: bool = True, data: DataConfig, model: ModelConfig = ModelConfig(basis=GaussianBasisConfig(name='gaussian', n_basis=7, r_min=0.5, r_max=6.0), n_radial=5, n_contr=8, emb_init='uniform', nn=[512, 512], w_init='normal', b_init='normal', use_ntk=True, ensemble=None, use_zbl=False, calc_stress=False, descriptor_dtype='fp64', readout_dtype='fp32', scale_shift_dtype='fp32'), metrics: List[MetricsConfig] = [], loss: List[LossConfig], optimizer: OptimizerConfig = OptimizerConfig(name='adam', emb_lr=0.02, nn_lr=0.03, scale_lr=0.001, shift_lr=0.05, zbl_lr=0.001, schedule=LinearLR(name='linear', transition_begin=0, end_value=1e-06), kwargs={}), weight_average: WeightAverage | None = None, callbacks: List[CSVCallback | TBCallback | MLFlowCallback] = [CSVCallback(name='csv')], progress_bar: TrainProgressbarConfig = TrainProgressbarConfig(disable_epoch_pbar=False, disable_batch_pbar=True), checkpoints: CheckpointConfig = CheckpointConfig(ckpt_interval=1, base_model_checkpoint=None, reset_layers=[]))[source]

Main configuration of an apax training run. Parameters that are config classes are generated by parsing the config.yaml file and are specified as shown here:

Example

data:
    directory: models/
    experiment: apax
        .
        .
Parameters:
  • n_epochs (int, required) –

    Number of training epochs.

  • patience (int, optional) –

    Number of epochs without improvement before training is terminated.

  • seed (int, default = 1) –

    Random seed.

  • n_jitted_steps (int, default = 1) –

    Number of train batches to be processed in a compiled loop.
    Can yield significant speedups for small structures or small batch sizes.

  • data (DataConfig) –

    Data configuration.

  • model (ModelConfig) –

    Model configuration.

  • metrics (List of MetricsConfig) –

    Metrics configuration.

  • loss (List of LossConfig) –

    Loss configuration.

  • optimizer (OptimizerConfig) –

    Optimizer configuration.

  • weight_average (WeightAverage, optional) –

    Options for averaging weights between epochs.

  • callbacks (List of various CallBack classes) –

    Possible callbacks are CSVCallback, TBCallback, and MLFlowCallback.

  • progress_bar (TrainProgressbarConfig) –

    Progressbar configuration.

  • checkpoints (CheckpointConfig) –

    Checkpoint configuration.

  • data_parallel (bool, default = True) –

    Automatically uses all available GPUs for data parallel training.
    Set to false to force single device training.
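How n_epochs and patience interact can be sketched in plain Python. This illustrates patience-based early stopping in general and is not taken from apax's training loop; the helper name epochs_run, the list of validation losses, and the exact stopping rule (stop once patience consecutive epochs bring no improvement) are assumptions.

```python
from typing import Optional, Sequence


def epochs_run(val_losses: Sequence[float], n_epochs: int,
               patience: Optional[int] = None) -> int:
    """Return how many epochs run before training stops.

    Hypothetical helper: stops once `patience` consecutive epochs
    fail to improve on the best validation loss seen so far.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses[:n_epochs], start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if patience is not None and epochs_without_improvement >= patience:
            return epoch
    # No early stop: all available epochs were used.
    return min(n_epochs, len(val_losses))


losses = [1.0, 0.8, 0.9, 0.85, 0.82, 0.7]  # made-up validation losses
print(epochs_run(losses, n_epochs=6, patience=2))     # stops after epoch 4
print(epochs_run(losses, n_epochs=6, patience=None))  # runs all 6 epochs
```

With patience left at None, training in this sketch always runs for the full n_epochs.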

dump_config(save_path)[source]

Writes the current config file to the specified directory.

Parameters:

save_path – Path to the directory.

class apax.config.train_config.DataConfig(*, directory: str, experiment: str, dataset: CachedDataset | OTFDataset | PBPDatset = CachedDataset(processing='cached', shuffle_buffer_size=1000), data_path: str | None = None, train_data_path: str | None = None, val_data_path: str | None = None, test_data_path: str | None = None, n_train: int = 1000, n_valid: int = 100, batch_size: int = 32, valid_batch_size: int = 100, additional_properties_info: dict[str, str] = {}, shift_method: str = 'per_element_regression_shift', shift_options: dict = {'energy_regularisation': 1.0}, scale_method: str = 'per_element_force_rms_scale', scale_options: dict | None = {}, pos_unit: str | None = 'Ang', energy_unit: str | None = 'eV')[source]

Configuration for data loading, preprocessing and training.

Parameters:
  • directory (str, required) –

    Path to directory where training results and checkpoints are written.

  • experiment (str, required) –

    Name of the model, used to distinguish it from other models in directory.

  • data_path (str, required if train_data_path and val_data_path are not specified) –

    Path to single dataset file.

  • train_data_path (str, required if data_path is not specified) –

    Path to training dataset.

  • val_data_path (str, required if data_path is not specified) –

    Path to validation dataset.

  • test_data_path (str, optional) –

    Path to test dataset.

  • n_train (int, default = 1000) –

    Number of training datapoints from data_path.

  • n_valid (int, default = 100) –

    Number of validation datapoints from data_path.

  • batch_size (int, default = 32) –

    Number of training examples to be evaluated at once.

  • valid_batch_size (int, default = 100) –

    Number of validation examples to be evaluated at once.

  • shuffle_buffer_size (int, default = 1000) –

    Size of the tf.data shuffle buffer.

  • additional_properties_info (dict, optional) –

    dict of property name, shape (ragged or fixed) pairs. Currently unused.

  • energy_regularisation (float, default = 1.0, set in shift_options) –

    Magnitude of the regularization in the per-element energy regression.
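The requirement on the dataset paths can be written out as a small check. This is a sketch of the documented rule only, not apax's validation code: the function name resolve_data_source, the file names, and the choice to let data_path take precedence when several paths are given are assumptions.

```python
from typing import Optional


def resolve_data_source(data_path: Optional[str] = None,
                        train_data_path: Optional[str] = None,
                        val_data_path: Optional[str] = None) -> str:
    """Hypothetical check mirroring the documented path requirement."""
    if data_path is not None:
        # n_train and n_valid datapoints are taken from this single file.
        return "single"
    if train_data_path is not None and val_data_path is not None:
        return "separate"
    raise ValueError(
        "Specify data_path, or both train_data_path and val_data_path."
    )


print(resolve_data_source(data_path="dataset.extxyz"))
print(resolve_data_source(train_data_path="train.extxyz",
                          val_data_path="val.extxyz"))

# With the defaults n_train = 1000 and batch_size = 32:
full_batches, leftover = divmod(1000, 32)
print(full_batches, leftover)  # 31 full batches, 8 datapoints left over
```

How the leftover datapoints are handled is not stated in this section, so the sketch only reports them.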

class apax.config.train_config.ModelConfig(*, basis: GaussianBasisConfig | BesselBasisConfig = GaussianBasisConfig(name='gaussian', n_basis=7, r_min=0.5, r_max=6.0), n_radial: int = 5, n_contr: int = 8, emb_init: str | None = 'uniform', nn: List[int] = [512, 512], w_init: Literal['normal', 'lecun'] = 'normal', b_init: Literal['normal', 'zeros'] = 'normal', use_ntk: bool = True, ensemble: FullEnsembleConfig | ShallowEnsembleConfig | None = None, use_zbl: bool = False, calc_stress: bool = False, descriptor_dtype: Literal['fp32', 'fp64'] = 'fp64', readout_dtype: Literal['fp32', 'fp64'] = 'fp32', scale_shift_dtype: Literal['fp32', 'fp64'] = 'fp32')[source]

Configuration for the model.

Parameters:
  • basis (BasisConfig, default = GaussianBasisConfig()) – Configuration for primitive basis functions.

  • n_radial (PositiveInt, default = 5) – Number of contracted basis functions.

  • n_contr (int, default = 8) – How many Gaussian moment contractions to use.

  • emb_init (Optional[str], default = "uniform") – Initialization scheme for embedding layer weights.

  • nn (List[PositiveInt], default = [512, 512]) – Number of hidden layers and units in those layers.

  • w_init (Literal["normal", "lecun"], default = "normal") – Initialization scheme for the neural network weights.

  • b_init (Literal["normal", "zeros"], default = "normal") – Initialization scheme for the neural network biases.

  • use_ntk (bool, default = True) – Whether or not to use NTK parametrization.

  • ensemble (Optional[EnsembleConfig], default = None) – What kind of model ensemble to use (optional).

  • use_zbl (bool, default = False) – Whether to include the ZBL correction.

  • calc_stress (bool, default = False) – Whether to calculate stress during model evaluation.

  • descriptor_dtype (Literal["fp32", "fp64"], default = "fp64") – Data type for descriptor calculations.

  • readout_dtype (Literal["fp32", "fp64"], default = "fp32") – Data type for readout calculations.

  • scale_shift_dtype (Literal["fp32", "fp64"], default = "fp32") – Data type for scale and shift parameters.

class apax.config.train_config.OptimizerConfig(*, name: str = 'adam', emb_lr: float = 0.02, nn_lr: float = 0.03, scale_lr: float = 0.001, shift_lr: float = 0.05, zbl_lr: float = 0.001, schedule: LinearLR | CyclicCosineLR = LinearLR(name='linear', transition_begin=0, end_value=1e-06), kwargs: dict = {})[source]

Configuration of the optimizer. Learning rates of 0 will freeze the respective parameters.

Parameters:
  • name (str, default = "adam") – Name of the optimizer. Can be any optax optimizer.

  • emb_lr (NonNegativeFloat, default = 0.02) – Learning rate of the elemental embedding contraction coefficients.

  • nn_lr (NonNegativeFloat, default = 0.03) – Learning rate of the neural network parameters.

  • scale_lr (NonNegativeFloat, default = 0.001) – Learning rate of the elemental output scaling factors.

  • shift_lr (NonNegativeFloat, default = 0.05) – Learning rate of the elemental output shifts.

  • zbl_lr (NonNegativeFloat, default = 0.001) – Learning rate of the ZBL correction parameters.

  • schedule (LRSchedule, default = LinearLR) – Learning rate schedule.

  • kwargs (dict, default = {}) – Optimizer keyword arguments. Passed to the optax optimizer.
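The default LinearLR schedule can be pictured as a linear interpolation from the initial learning rate down to end_value. The sketch below is a plain-Python illustration, not apax's or optax's code. It assumes that end_value is an absolute learning rate and that the decay spans a fixed number of steps; transition_steps is a hypothetical parameter standing in for the total number of training steps and is not an OptimizerConfig field.

```python
def linear_lr(step: int, init_value: float, end_value: float = 1e-6,
              transition_begin: int = 0, transition_steps: int = 1000) -> float:
    """Linearly interpolate from init_value to end_value.

    `transition_steps` is a hypothetical stand-in for the total number
    of training steps.
    """
    progress = (step - transition_begin) / transition_steps
    progress = min(max(progress, 0.0), 1.0)  # clamp to [0, 1]
    return init_value + (end_value - init_value) * progress


nn_lr = 0.03  # default nn_lr from OptimizerConfig
for step in (0, 500, 1000):
    print(step, linear_lr(step, nn_lr))
```

Before transition_begin the learning rate stays at its initial value; after the transition it stays at end_value.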

class apax.config.train_config.MetricsConfig(*, name: str, reductions: List[str])[source]

Configuration for the metrics collected during training.

Parameters:
  • name (str) – Keyword of the quantity, e.g., ‘energy’.

  • reductions (List[str]) – List of reductions performed on the difference between target and predictions. Can be ‘mae’, ‘mse’, ‘rmse’ for energies and forces. For forces, ‘angle’ can also be used.
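As a reminder of what the 'mae', 'mse' and 'rmse' reductions compute, here is a minimal sketch in plain Python. The function name reduce_errors and the sample energies are made up for illustration; the 'angle' reduction for forces is not covered.

```python
import math
from typing import Dict, Sequence


def reduce_errors(targets: Sequence[float],
                  predictions: Sequence[float]) -> Dict[str, float]:
    """Reduce the difference between predictions and targets."""
    diffs = [p - t for p, t in zip(predictions, targets)]
    mae = sum(abs(d) for d in diffs) / len(diffs)
    mse = sum(d * d for d in diffs) / len(diffs)
    return {"mae": mae, "mse": mse, "rmse": math.sqrt(mse)}


# A metrics entry using the two documented fields.
metric = {"name": "energy", "reductions": ["mae", "rmse"]}

energies_ref = [1.0, 2.0, 3.0]   # made-up reference energies
energies_pred = [1.5, 2.0, 2.0]  # made-up predictions
results = reduce_errors(energies_ref, energies_pred)
for reduction in metric["reductions"]:
    print(metric["name"], reduction, results[reduction])
```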

class apax.config.train_config.LossConfig(*, name: str, loss_type: str = 'mse', weight: float = 1.0, atoms_exponent: float = 1, parameters: dict = {})[source]

Configuration of the loss functions used during training.

Parameters:
  • name (str) – Keyword of the quantity, e.g., ‘energy’.

  • loss_type (str, optional) – Weighting scheme for atomic contributions, by default “mse”. For details, see the MLIP package reference, DOI 10.1088/2632-2153/abc9fe.

  • weight (NonNegativeFloat, optional) – Weighting factor in the overall loss function, by default 1.0.

  • atoms_exponent (NonNegativeFloat, optional) – Exponent for atomic contributions weighting, by default 1.

  • parameters (dict, optional) – Additional parameters for configuring the loss function, by default {}.

Notes

This class specifies the configuration of the loss functions used during training.
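One plausible reading of how weight and atoms_exponent enter an "mse"-type loss is sketched below. The formula, weight times the mean of the squared error divided by n_atoms raised to atoms_exponent, is an assumption made for illustration and is not confirmed by this section; the function name weighted_mse and the sample values are hypothetical.

```python
from typing import Sequence


def weighted_mse(targets: Sequence[float], predictions: Sequence[float],
                 n_atoms: Sequence[int], weight: float = 1.0,
                 atoms_exponent: float = 1.0) -> float:
    """Assumed form: weight * mean((pred - target)^2 / n_atoms^atoms_exponent)."""
    terms = [
        (p - t) ** 2 / (n ** atoms_exponent)
        for t, p, n in zip(targets, predictions, n_atoms)
    ]
    return weight * sum(terms) / len(terms)


targets = [-10.0, -20.0]     # made-up energies of two structures
predictions = [-9.0, -22.0]
n_atoms = [2, 4]

print(weighted_mse(targets, predictions, n_atoms))                      # 0.75
print(weighted_mse(targets, predictions, n_atoms, atoms_exponent=0.0))  # 2.5
print(weighted_mse(targets, predictions, n_atoms, weight=2.0))          # 1.5
```

Under this reading, a larger atoms_exponent reduces the influence of large structures on the total loss.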

class apax.config.train_config.CheckpointConfig(*, ckpt_interval: int = 1, base_model_checkpoint: str | None = None, reset_layers: List[str] = [])[source]

Checkpoint configuration.

Parameters:
  • ckpt_interval (int, default = 1) – Number of epochs between checkpoints.

  • base_model_checkpoint (str, optional) – Path to the folder containing a pre-trained model ckpt.

  • reset_layers (List[str], default = []) – List of layer names for which the parameters will be reinitialized.

class apax.config.train_config.TrainProgressbarConfig(*, disable_epoch_pbar: bool = False, disable_batch_pbar: bool = True)[source]

Configuration of progressbars.

Parameters:
  • disable_epoch_pbar (bool, default = False) – Set to True to disable the epoch progress bar.

  • disable_batch_pbar (bool, default = True) – Set to True to disable the batch progress bar.

class apax.config.train_config.CSVCallback(*, name: Literal['csv'])[source]

Configuration of the CSVCallback.

Parameters:

name (Literal["csv"]) – Keyword of the callback used.

class apax.config.train_config.TBCallback(*, name: Literal['tensorboard'])[source]

Configuration of the TensorBoard callback.

Parameters:

name (Literal["tensorboard"]) – Keyword of the callback used.

class apax.config.train_config.MLFlowCallback(*, name: Literal['mlflow'], experiment: str)[source]

Configuration of the MLFlow callback.

Parameters:
  • name (Literal["mlflow"]) – Keyword of the callback used.

  • experiment (str) – Path to the MLFlow experiment, e.g. /Users/<user>/<my_experiment>

Molecular Dynamics Configuration

class apax.config.md_config.MDConfig(*, seed: int = 1, ensemble: NVEOptions | NVTOptions | NPTOptions = NVTOptions(dt=0.5, name='nvt', temperature=298.15, thermostat_chain=NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=100)), duration: float, n_inner: int = 100, sampling_rate: int = 10, buffer_size: int = 100, dr_threshold: float = 0.5, extra_capacity: int = 0, initial_structure: str, load_momenta: bool = False, sim_dir: str = '.', traj_name: str = 'md.h5', restart: bool = True, checkpoint_interval: int = 50000, disable_pbar: bool = False)[source]

Configuration for an NHC molecular dynamics simulation. Full config here:

Parameters:
  • seed (int, default = 1) –

    Random seed for momentum initialization.

  • temperature (float, default = 298.15) –

    Temperature of the simulation in Kelvin.

  • dt (float, default = 0.5) –

    Time step in fs.

  • duration (float, required) –

    Total simulation time in fs.

  • n_inner (int, default = 100) –

    Number of compiled simulation steps (i.e. number of iterations of the
    jax.lax.fori_loop loop). Also determines atoms buffer size.

  • sampling_rate (int, default = 10) –

    Interval between saving frames.

  • buffer_size (int, default = 100) –

    Number of collected frames to be dumped at once.

  • dr_threshold (float, default = 0.5) –

    Skin of the neighborlist.

  • extra_capacity (int, default = 0) –

    JaxMD allocates a maximal number of neighbors. This argument lets you add
    additional capacity to avoid recompilation. The default is usually fine.

  • initial_structure (str, required) –

    Path to the starting structure of the simulation.

  • sim_dir (str, default = ".") –

    Directory where simulation files will be stored.

  • traj_name (str, default = "md.h5") –

    Name of the trajectory file.

  • restart (bool, default = True) –

    Whether the simulation should restart from the latest configuration in
    traj_name.

  • checkpoint_interval (int, default = 50_000) –

    Number of time steps between saving full simulation state checkpoints.
    These will be loaded with the restart option.

  • disable_pbar (bool, default = False) –

    Disables the MD progressbar.
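The relationship between duration, dt, n_inner and sampling_rate comes down to simple arithmetic, sketched here in plain Python. The function name md_step_counts, the rounding, and the 10 000 fs duration are assumptions for illustration; apax may count steps and frames differently.

```python
def md_step_counts(duration: float, dt: float = 0.5, n_inner: int = 100,
                   sampling_rate: int = 10) -> dict:
    """Hypothetical helper deriving step counts from the MD settings."""
    n_steps = round(duration / dt)       # total integration steps
    n_outer = n_steps // n_inner         # number of compiled inner loops
    n_frames = n_steps // sampling_rate  # frames saved to the trajectory
    return {"n_steps": n_steps, "n_outer": n_outer, "n_frames": n_frames}


# A 10 000 fs run with the documented defaults.
counts = md_step_counts(duration=10_000.0)
print(counts)  # {'n_steps': 20000, 'n_outer': 200, 'n_frames': 2000}
```

With the default checkpoint_interval of 50 000 steps, a run of this length would end before reaching the first full-state checkpoint.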

dump_config()[source]

Writes the current config file to the MD directory.

class apax.config.md_config.NPTOptions(*, dt: float = 0.5, name: Literal['npt'], temperature: float = 298.15, thermostat_chain: NHCOptions = NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=100), pressure: float = 1.01325, barostat_chain: NHCOptions = NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=1000.0))[source]

Options for NPT ensemble simulations.

Parameters:
  • name (Literal["npt"]) – Name of the ensemble.

  • pressure (PositiveFloat, default = 1.01325) – Pressure in bar.

  • barostat_chain (NHCOptions, default = NHCOptions(tau=1000)) – Barostat chain options.
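To show how the NPT options nest inside an MD configuration, the sketch below writes the settings as a plain Python mapping. Field names and default values are copied from the signatures above; the values for duration and initial_structure are hypothetical placeholders, and the nesting of the ensemble section is assumed to mirror the MDConfig signature.

```python
# Field names and defaults come from the MDConfig, NPTOptions and
# NHCOptions signatures; duration and initial_structure are placeholders.
md_config = {
    "duration": 10_000.0,                 # fs, hypothetical value
    "initial_structure": "start.extxyz",  # hypothetical path
    "ensemble": {
        "name": "npt",
        "dt": 0.5,              # fs
        "temperature": 298.15,  # K
        "pressure": 1.01325,    # bar
        "thermostat_chain": {
            "chain_length": 3, "chain_steps": 2, "sy_steps": 3, "tau": 100,
        },
        "barostat_chain": {
            "chain_length": 3, "chain_steps": 2, "sy_steps": 3, "tau": 1000.0,
        },
    },
}

# duration and initial_structure have no defaults in MDConfig.
required = {"duration", "initial_structure"}
missing = required - md_config.keys()
print(sorted(missing))  # []
```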