Configuration
apax.config
Training Configuration
- class apax.config.train_config.Config(*, n_epochs: int, patience: int | None = None, seed: int = 1, n_jitted_steps: int = 1, data_parallel: bool = True, data: DataConfig, model: ModelConfig = ModelConfig(basis=GaussianBasisConfig(name='gaussian', n_basis=7, r_min=0.5, r_max=6.0), n_radial=5, n_contr=8, emb_init='uniform', nn=[512, 512], w_init='normal', b_init='normal', use_ntk=True, ensemble=None, use_zbl=False, calc_stress=False, descriptor_dtype='fp64', readout_dtype='fp32', scale_shift_dtype='fp32'), metrics: List[MetricsConfig] = [], loss: List[LossConfig], optimizer: OptimizerConfig = OptimizerConfig(name='adam', emb_lr=0.02, nn_lr=0.03, scale_lr=0.001, shift_lr=0.05, zbl_lr=0.001, schedule=LinearLR(name='linear', transition_begin=0, end_value=1e-06), kwargs={}), weight_average: WeightAverage | None = None, callbacks: List[CSVCallback | TBCallback | MLFlowCallback] = [CSVCallback(name='csv')], progress_bar: TrainProgressbarConfig = TrainProgressbarConfig(disable_epoch_pbar=False, disable_batch_pbar=True), checkpoints: CheckpointConfig = CheckpointConfig(ckpt_interval=1, base_model_checkpoint=None, reset_layers=[]))
Main configuration of an apax training run. Parameters that are config classes are generated by parsing the config.yaml file and are specified as shown here:
Example
data:
    directory: models/
    experiment: apax
    .
    .
- Parameters:
n_epochs (int, required) – Number of training epochs.
patience (int, optional) – Number of epochs without improvement before training is terminated.
seed (int, default = 1) – Random seed.
n_jitted_steps (int, default = 1) – Number of train batches to be processed in a compiled loop. Can yield significant speedups for small structures or small batch sizes.
data (DataConfig) – Data configuration.
model (ModelConfig) – Model configuration.
metrics (List of MetricsConfig) – Metrics configuration.
loss (List of LossConfig) – Loss configuration.
optimizer (OptimizerConfig) – Optimizer configuration.
weight_average (WeightAverage, optional) – Options for averaging weights between epochs.
callbacks (List of various CallBack classes) – Possible callbacks are CSVCallback, TBCallback and MLFlowCallback.
progress_bar (TrainProgressbarConfig) – Progressbar configuration.
checkpoints (CheckpointConfig) – Checkpoint configuration.
data_parallel (bool, default = True) – Automatically uses all available GPUs for data parallel training. Set to false to force single device training.
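To illustrate how these fields nest, the sketch below builds the top-level mapping that a config.yaml would encode. It uses a plain Python dict rather than the apax classes. Only the fields marked as required above are set. The concrete values (epoch count, file name, loss names and weight) are placeholders, and `missing_required` is a hypothetical helper, not part of apax.

```python
# Plain-dict sketch of the top-level training config; all values are placeholders.
REQUIRED_TOP_LEVEL = ("n_epochs", "data", "loss")

config = {
    "n_epochs": 100,  # placeholder value
    "data": {
        "directory": "models/",
        "experiment": "apax",
        "data_path": "dataset.extxyz",  # hypothetical file name
    },
    "loss": [
        {"name": "energy"},
        {"name": "forces", "weight": 4.0},  # placeholder name and weight
    ],
}


def missing_required(cfg):
    """Return the required top-level keys that are absent from cfg."""
    return [key for key in REQUIRED_TOP_LEVEL if key not in cfg]


print(missing_required(config))
print(missing_required({"n_epochs": 10}))
```

Every other field (model, optimizer, callbacks and so on) falls back to the defaults shown in the class signature when it is omitted.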
- class apax.config.train_config.DataConfig(*, directory: str, experiment: str, dataset: CachedDataset | OTFDataset | PBPDatset = CachedDataset(processing='cached', shuffle_buffer_size=1000), data_path: str | None = None, train_data_path: str | None = None, val_data_path: str | None = None, test_data_path: str | None = None, n_train: int = 1000, n_valid: int = 100, batch_size: int = 32, valid_batch_size: int = 100, additional_properties_info: dict[str, str] = {}, shift_method: str = 'per_element_regression_shift', shift_options: dict = {'energy_regularisation': 1.0}, scale_method: str = 'per_element_force_rms_scale', scale_options: dict | None = {}, pos_unit: str | None = 'Ang', energy_unit: str | None = 'eV')
Configuration for data loading, preprocessing and training.
- Parameters:
directory (str, required) – Path to directory where training results and checkpoints are written.
experiment (str, required) – Model name distinguishing it from others in directory.
data_path (str, required if train_data_path and val_data_path are not specified) – Path to single dataset file.
train_data_path (str, required if data_path is not specified) – Path to training dataset.
val_data_path (str, required if data_path is not specified) – Path to validation dataset.
test_data_path (str, optional) – Path to test dataset.
n_train (int, default = 1000) – Number of training datapoints from data_path.
n_valid (int, default = 100) – Number of validation datapoints from data_path.
batch_size (int, default = 32) – Number of training examples to be evaluated at once.
valid_batch_size (int, default = 100) – Number of validation examples to be evaluated at once.
shuffle_buffer_size (int, default = 1000) – Size of the tf.data shuffle buffer.
additional_properties_info (dict, optional) – dict of property name, shape (ragged or fixed) pairs. Currently unused.
energy_regularisation – Magnitude of the regularization in the per-element energy regression.
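The path rules above can be read as a small decision procedure, sketched below in plain Python. The helper names are hypothetical. How apax behaves when both `data_path` and the split paths are given is not stated here, so the sketch simply prefers `data_path`. The check that `n_train + n_valid` must fit into the dataset is likewise an assumption about what a sensible configuration looks like.

```python
def resolve_data_source(data_path=None, train_data_path=None, val_data_path=None):
    """Return which input mode a set of paths selects, or raise if none applies."""
    if data_path is not None:
        # A single file: training and validation points are both drawn from it.
        return "single"
    if train_data_path is not None and val_data_path is not None:
        # Separate, pre-split training and validation files.
        return "separate"
    raise ValueError(
        "Specify data_path, or both train_data_path and val_data_path."
    )


def split_fits(n_datapoints, n_train=1000, n_valid=100):
    """Assumed sanity check: the requested split must not exceed the dataset."""
    return n_train + n_valid <= n_datapoints


print(resolve_data_source(data_path="dataset.extxyz"))  # hypothetical file name
print(split_fits(1050))
```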
- class apax.config.train_config.ModelConfig(*, basis: GaussianBasisConfig | BesselBasisConfig = GaussianBasisConfig(name='gaussian', n_basis=7, r_min=0.5, r_max=6.0), n_radial: int = 5, n_contr: int = 8, emb_init: str | None = 'uniform', nn: List[int] = [512, 512], w_init: Literal['normal', 'lecun'] = 'normal', b_init: Literal['normal', 'zeros'] = 'normal', use_ntk: bool = True, ensemble: FullEnsembleConfig | ShallowEnsembleConfig | None = None, use_zbl: bool = False, calc_stress: bool = False, descriptor_dtype: Literal['fp32', 'fp64'] = 'fp64', readout_dtype: Literal['fp32', 'fp64'] = 'fp32', scale_shift_dtype: Literal['fp32', 'fp64'] = 'fp32')
Configuration for the model.
- Parameters:
basis (BasisConfig, default = GaussianBasisConfig()) – Configuration for primitive basis functions.
n_radial (PositiveInt, default = 5) – Number of contracted basis functions.
n_contr (int, default = 8) – How many gaussian moment contractions to use.
emb_init (Optional[str], default = "uniform") – Initialization scheme for embedding layer weights.
nn (List[PositiveInt], default = [512, 512]) – Number of hidden layers and units in those layers.
w_init (Literal["normal", "lecun"], default = "normal") – Initialization scheme for the neural network weights.
b_init (Literal["normal", "zeros"], default = "normal") – Initialization scheme for the neural network biases.
use_ntk (bool, default = True) – Whether or not to use NTK parametrization.
ensemble (Optional[EnsembleConfig], default = None) – What kind of model ensemble to use (optional).
use_zbl (bool, default = False) – Whether to include the ZBL correction.
calc_stress (bool, default = False) – Whether to calculate stress during model evaluation.
descriptor_dtype (Literal["fp32", "fp64"], default = "fp64") – Data type for descriptor calculations.
readout_dtype (Literal["fp32", "fp64"], default = "fp32") – Data type for readout calculations.
scale_shift_dtype (Literal["fp32", "fp64"], default = "fp32") – Data type for scale and shift parameters.
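The `nn` entry encodes both the depth and the width of the readout network: each list element is one hidden layer. The sketch below shows how such a list translates into dense-layer weight shapes for a generic multilayer perceptron. The descriptor length (`n_features`) and the single output unit are assumptions made for illustration. They are not values taken from apax.

```python
def dense_layer_shapes(n_features, nn, n_out=1):
    """Weight-matrix shapes of an MLP with hidden widths `nn`.

    n_features and n_out are illustrative assumptions, not apax values.
    """
    widths = [n_features, *nn, n_out]
    # Pair each layer width with the next one: (inputs, outputs).
    return list(zip(widths[:-1], widths[1:]))


shapes = dense_layer_shapes(n_features=360, nn=[512, 512])
print(shapes)
```

With the default `nn = [512, 512]` this gives two hidden layers of 512 units each, followed by the output layer.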
- class apax.config.train_config.OptimizerConfig(*, name: str = 'adam', emb_lr: float = 0.02, nn_lr: float = 0.03, scale_lr: float = 0.001, shift_lr: float = 0.05, zbl_lr: float = 0.001, schedule: LinearLR | CyclicCosineLR = LinearLR(name='linear', transition_begin=0, end_value=1e-06), kwargs: dict = {})
Configuration of the optimizer. Learning rates of 0 will freeze the respective parameters.
- Parameters:
name (str, default = "adam") – Name of the optimizer. Can be any optax optimizer.
emb_lr (NonNegativeFloat, default = 0.02) – Learning rate of the elemental embedding contraction coefficients.
nn_lr (NonNegativeFloat, default = 0.03) – Learning rate of the neural network parameters.
scale_lr (NonNegativeFloat, default = 0.001) – Learning rate of the elemental output scaling factors.
shift_lr (NonNegativeFloat, default = 0.05) – Learning rate of the elemental output shifts.
zbl_lr (NonNegativeFloat, default = 0.001) – Learning rate of the ZBL correction parameters.
schedule (LRSchedule, default = LinearLR) – Learning rate schedule.
kwargs (dict, default = {}) – Optimizer keyword arguments. Passed to the optax optimizer.
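The per-group learning rates and the linear schedule can be sketched in plain Python as follows. `frozen_groups` lists the parameter groups whose learning rate is 0 and which therefore stay fixed. `linear_lr` interpolates linearly from an initial value to `end_value`. The assumptions here are that `end_value` is an absolute learning rate and that the decay spans `transition_steps` optimizer steps. Both helper names are hypothetical, and the actual apax schedule may differ in detail.

```python
DEFAULT_LRS = {
    "emb_lr": 0.02,
    "nn_lr": 0.03,
    "scale_lr": 0.001,
    "shift_lr": 0.05,
    "zbl_lr": 0.001,
}


def frozen_groups(lrs):
    """Names of parameter groups that are frozen by a learning rate of 0."""
    return sorted(name for name, lr in lrs.items() if lr == 0.0)


def linear_lr(step, init_value, end_value, transition_steps, transition_begin=0):
    """Linear interpolation from init_value to end_value, held constant afterwards."""
    progress = (step - transition_begin) / transition_steps
    frac = min(max(progress, 0.0), 1.0)  # clamp to [0, 1]
    return init_value + frac * (end_value - init_value)


# Freeze the ZBL parameters while training everything else.
print(frozen_groups({**DEFAULT_LRS, "zbl_lr": 0.0}))
print(linear_lr(50, init_value=0.03, end_value=1e-6, transition_steps=100))
```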
- class apax.config.train_config.MetricsConfig(*, name: str, reductions: List[str])
Configuration for the metrics collected during training.
- Parameters:
name (str) – Keyword of the quantity, e.g., ‘energy’.
reductions (List[str]) – List of reductions performed on the difference between target and predictions. Can be ‘mae’, ‘mse’, ‘rmse’ for energies and forces. For forces, ‘angle’ can also be used.
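For reference, the three scalar reductions named above have their usual definitions, sketched here in plain Python on flat lists of numbers. The result key format (`energy_mae` and so on) is an assumption chosen for the example. The 'angle' reduction is left out because its exact definition is not given here.

```python
import math


def mae(target, prediction):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(target, prediction)) / len(target)


def mse(target, prediction):
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(target, prediction)) / len(target)


def rmse(target, prediction):
    """Root mean squared error."""
    return math.sqrt(mse(target, prediction))


REDUCTIONS = {"mae": mae, "mse": mse, "rmse": rmse}

metric = {"name": "energy", "reductions": ["mae", "rmse"]}
target = [1.0, 2.0, 3.0]       # made-up reference values
prediction = [1.5, 2.0, 2.0]   # made-up model outputs

results = {
    f"{metric['name']}_{r}": REDUCTIONS[r](target, prediction)
    for r in metric["reductions"]
}
print(results)
```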
- class apax.config.train_config.LossConfig(*, name: str, loss_type: str = 'mse', weight: float = 1.0, atoms_exponent: float = 1, parameters: dict = {})
Configuration of the loss functions used during training.
- Parameters:
name (str) – Keyword of the quantity, e.g., ‘energy’.
loss_type (str, optional) – Weighting scheme for atomic contributions, by default “mse”. See the MLIP package (reference: 10.1088/2632-2153/abc9fe) for details.
weight (NonNegativeFloat, optional) – Weighting factor in the overall loss function, by default 1.0.
atoms_exponent (NonNegativeFloat, optional) – Exponent for atomic contributions weighting, by default 1.
parameters (dict, optional) – Additional parameters for configuring the loss function, by default {}.
Notes
This class specifies the configuration of the loss functions used during training.
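One plausible reading of how `weight` and `atoms_exponent` enter an "mse"-type loss term is sketched below: the squared error of a structure is divided by the number of atoms raised to `atoms_exponent` and then multiplied by `weight`. This is an assumption made for illustration, not a statement of the exact apax implementation, and the function name is hypothetical.

```python
def weighted_squared_error(label, prediction, n_atoms, weight=1.0, atoms_exponent=1):
    """Squared error scaled by weight / n_atoms**atoms_exponent (assumed form)."""
    divisor = n_atoms ** atoms_exponent
    return weight * (label - prediction) ** 2 / divisor


# Made-up energies for a 4-atom structure.
print(weighted_squared_error(-10.0, -9.0, n_atoms=4))
print(weighted_squared_error(-10.0, -9.0, n_atoms=4, atoms_exponent=2))
print(weighted_squared_error(-10.0, -9.0, n_atoms=4, weight=4.0))
```

Under this reading, a larger `atoms_exponent` reduces the influence of large structures on the loss, while `weight` balances the different quantities (for example energies against forces) in the total loss.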
- class apax.config.train_config.CheckpointConfig(*, ckpt_interval: int = 1, base_model_checkpoint: str | None = None, reset_layers: List[str] = [])
Checkpoint configuration.
- Parameters:
ckpt_interval (int, default = 1) – Number of epochs between checkpoints.
base_model_checkpoint (str, optional) – Path to the folder containing a pre-trained model checkpoint.
reset_layers (List[str], default = []) – List of layer names for which the parameters will be reinitialized.
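The interplay of these three options can be pictured with the following plain-Python sketch. It assumes that a checkpoint is written whenever the epoch index is a multiple of `ckpt_interval`. It also assumes that, when starting from a base model, every layer keeps its pre-trained parameters unless its name appears in `reset_layers`. Both rules, the helper names and the layer names are illustrative assumptions.

```python
def should_checkpoint(epoch, ckpt_interval=1):
    """Assumed rule: write a checkpoint every ckpt_interval epochs."""
    return epoch % ckpt_interval == 0


def merge_params(base_params, fresh_params, reset_layers):
    """Take pre-trained parameters except for layers listed in reset_layers."""
    return {
        name: fresh_params[name] if name in reset_layers else base_params[name]
        for name in fresh_params
    }


# Hypothetical layer names and stand-in parameter values.
base = {"embedding": "pretrained", "dense_0": "pretrained", "dense_1": "pretrained"}
fresh = {"embedding": "random", "dense_0": "random", "dense_1": "random"}

print(merge_params(base, fresh, reset_layers=["dense_1"]))
print([epoch for epoch in range(10) if should_checkpoint(epoch, ckpt_interval=5)])
```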
- class apax.config.train_config.TrainProgressbarConfig(*, disable_epoch_pbar: bool = False, disable_batch_pbar: bool = True)
Configuration of progressbars.
- Parameters:
disable_epoch_pbar (bool, default = False) – Set to True to disable the epoch progress bar.
disable_batch_pbar (bool, default = True) – Set to True to disable the batch progress bar.
- class apax.config.train_config.CSVCallback(*, name: Literal['csv'])
Configuration of the CSVCallback.
- Parameters:
name (Literal["csv"]) – Keyword of the callback used.
Molecular Dynamics Configuration
- class apax.config.md_config.MDConfig(*, seed: int = 1, ensemble: NVEOptions | NVTOptions | NPTOptions = NVTOptions(dt=0.5, name='nvt', temperature=298.15, thermostat_chain=NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=100)), duration: float, n_inner: int = 100, sampling_rate: int = 10, buffer_size: int = 100, dr_threshold: float = 0.5, extra_capacity: int = 0, initial_structure: str, load_momenta: bool = False, sim_dir: str = '.', traj_name: str = 'md.h5', restart: bool = True, checkpoint_interval: int = 50000, disable_pbar: bool = False)
Configuration for an NHC (Nosé-Hoover chain) molecular dynamics simulation. Full config here:
- Parameters:
seed (int, default = 1) – Random seed for momentum initialization.
temperature (float, default = 298.15) – Temperature of the simulation in Kelvin.
dt (float, default = 0.5) – Time step in fs.
duration (float, required) – Total simulation time in fs.
n_inner (int, default = 100) – Number of compiled simulation steps (i.e. number of iterations of the jax.lax.fori_loop loop). Also determines atoms buffer size.
sampling_rate (int, default = 10) – Interval between saving frames.
buffer_size (int, default = 100) – Number of collected frames to be dumped at once.
dr_threshold (float, default = 0.5) – Skin of the neighborlist.
extra_capacity (int, default = 0) – JaxMD allocates a maximal number of neighbors. This argument lets you add additional capacity to avoid recompilation. The default is usually fine.
initial_structure (str, required) – Path to the starting structure of the simulation.
sim_dir (str, default = ".") – Directory where the simulation files will be stored.
traj_name (str, default = "md.h5") – Name of the trajectory file.
restart (bool, default = True) – Whether the simulation should restart from the latest configuration in traj_name.
checkpoint_interval (int, default = 50_000) – Number of time steps between saving full simulation state checkpoints. These will be loaded with the restart option.
disable_pbar (bool, default = False) – Disables the MD progressbar.
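The relationship between `duration`, `dt`, `n_inner`, `sampling_rate` and `buffer_size` can be made concrete with a little arithmetic. The sketch below assumes that `sampling_rate` is counted in time steps and that the number of steps is `duration / dt` rounded to the nearest integer. The helper name and the returned keys are hypothetical.

```python
import math


def md_step_counts(duration, dt=0.5, n_inner=100, sampling_rate=10, buffer_size=100):
    """Derive step, loop, frame and dump counts from the MD settings (assumed rules)."""
    n_steps = int(round(duration / dt))        # total integration steps
    n_outer = math.ceil(n_steps / n_inner)     # calls of the compiled inner loop
    n_frames = n_steps // sampling_rate        # frames written to the trajectory
    n_dumps = math.ceil(n_frames / buffer_size)  # buffered writes to disk
    return {
        "n_steps": n_steps,
        "n_outer": n_outer,
        "n_frames": n_frames,
        "n_dumps": n_dumps,
    }


# A 10 ps run (10000 fs) with the default settings.
print(md_step_counts(duration=10000.0))
```

For a 10 ps run with the defaults this gives 20000 steps, executed as 200 compiled blocks of 100 steps, with 2000 saved frames written in 20 batches.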
- class apax.config.md_config.NPTOptions(*, dt: float = 0.5, name: Literal['npt'], temperature: float = 298.15, thermostat_chain: NHCOptions = NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=100), pressure: float = 1.01325, barostat_chain: NHCOptions = NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=1000.0))
Options for NPT ensemble simulations.
- Parameters:
name (Literal["npt"]) – Name of the ensemble.
pressure (PositiveFloat, default = 1.01325) – Pressure in bar.
barostat_chain (NHCOptions, default = NHCOptions(tau=1000)) – Barostat chain options.
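As a sketch of how an NPT run might be described, the plain dict below mirrors the nesting of `MDConfig` and `NPTOptions`, using the default values from the signatures above. The duration and the structure file name are placeholders. `check_npt` is a hypothetical helper that only reproduces the two constraints stated in the parameter list: the literal name and a positive pressure.

```python
md_config = {
    "duration": 10000.0,                   # fs, placeholder value
    "initial_structure": "start.extxyz",   # hypothetical file name
    "ensemble": {
        "name": "npt",
        "dt": 0.5,               # fs
        "temperature": 298.15,   # K
        "pressure": 1.01325,     # bar (equals 1 atm)
        "thermostat_chain": {"chain_length": 3, "chain_steps": 2, "sy_steps": 3, "tau": 100},
        "barostat_chain": {"chain_length": 3, "chain_steps": 2, "sy_steps": 3, "tau": 1000.0},
    },
}


def check_npt(ensemble):
    """Return a list of problems with an NPT ensemble mapping (empty if none)."""
    problems = []
    if ensemble.get("name") != "npt":
        problems.append("name must be 'npt'")
    if not ensemble.get("pressure", 1.01325) > 0:
        problems.append("pressure must be positive")
    return problems


print(check_npt(md_config["ensemble"]))
```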