Configuration

apax.config

Training Configuration

class apax.config.train_config.Config(*, n_epochs: Annotated[int, Gt(gt=0)], patience: Annotated[int, Gt(gt=0)] | None = None, patience_min_delta: Annotated[float, Ge(ge=0)] = 0.0, seed: int = 1, data_parallel: bool = True, ckpt_interval: Annotated[int, Gt(gt=0)] = 500, data: DataConfig, model: GMNNConfig | EquivMPConfig | So3kratesConfig = GMNNConfig(basis=BesselBasisConfig(name='bessel', n_basis=16, r_max=5.0), nn=[256, 256], w_init='lecun', b_init='zeros', activation_fn='variance_preserving_swish', use_ntk=False, ensemble=None, property_heads=[], empirical_corrections=[], calc_stress=False, descriptor_dtype='fp32', readout_dtype='fp32', scale_shift_dtype='fp64', name='gmnn', n_radial=5, n_contr=8, emb_init='uniform'), metrics: List[MetricsConfig] = [], loss: List[LossConfig], optimizer: OptimizerConfig = OptimizerConfig(name='adam', emb_lr=0.001, nn_lr=0.001, scale_lr=0.0001, shift_lr=0.003, zbl_lr=0.0001, rep_scale_lr=0.001, rep_prefactor_lr=0.0001, gradient_clipping=1000.0, schedule=LinearLR(name='linear', transition_begin=0, end_value=1e-06), kwargs={}), weight_average: WeightAverage | None = None, callbacks: List[Annotated[CSVCallback | TBCallback | MLFlowCallback | KerasPruningCallback, FieldInfo(annotation=NoneType, required=True, discriminator='name')]] = [CSVCallback(name='csv')], progress_bar: TrainProgressbarConfig = TrainProgressbarConfig(disable_epoch_pbar=False, disable_batch_pbar=True), transfer_learning: TransferLearningConfig | None = None)[source]

Main configuration of an apax training run. Parameters that are config classes are generated by parsing the config.yaml file and are specified as shown here:

Example

data:
    directory: models/
    experiment: apax
        .
        .
Parameters:
  • n_epochs (int, required) –

    Number of training epochs.

  • patience (int, optional) –

    Number of epochs without improvement before training is terminated.

  • patience_min_delta (float, default = 0.0) –

    Minimum change in the monitored quantity to qualify as an improvement.

  • seed (int, default = 1) –

    Random seed.

  • data (DataConfig) –

    Data configuration.

  • model (ModelConfig) –

    Model configuration.

  • metrics (List of MetricsConfig) –

    Metrics configuration.

  • loss (List of LossConfig) –

    Loss configuration.

  • optimizer (OptimizerConfig) –

    Loss optimizer configuration.

  • weight_average (WeightAverage, optional) –

    Options for averaging weights between epochs.

  • callbacks (List of various CallBack classes) –

    Possible callbacks are CSVCallback, TBCallback, MLFlowCallback, and KerasPruningCallback.

  • progress_bar (TrainProgressbarConfig) –

    Progressbar configuration.

  • checkpoints (CheckpointConfig) –

    Checkpoint configuration.

  • data_parallel (bool, default = True) –

    Automatically uses all available GPUs for data parallel training.
    Set to false to force single device training.

  • ckpt_interval (int) –

    Number of epochs between checkpoints.
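How patience and patience_min_delta interact can be illustrated with a small early-stopping sketch. This is a generic illustration rather than apax's own implementation: the helper name epochs_until_stop is hypothetical, and it assumes the monitored quantity is a loss (lower is better) that must improve by more than patience_min_delta to reset the counter.

```python
def epochs_until_stop(val_losses, patience, min_delta=0.0):
    """Return the 1-based epoch at which training would stop, or None.

    Hypothetical helper: an epoch counts as an improvement only if it
    beats the best loss seen so far by more than min_delta.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if best - loss > min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            return epoch
    return None


losses = [1.0, 0.8, 0.79, 0.795, 0.789, 0.70]
# With min_delta = 0.0 every small gain resets the counter.
stop_default = epochs_until_stop(losses, patience=2)
# With min_delta = 0.05 the small gains after epoch 2 no longer count.
stop_strict = epochs_until_stop(losses, patience=2, min_delta=0.05)
```

In this sketch a larger patience_min_delta makes the criterion stricter, so the run with min_delta=0.05 stops earlier than the default one.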

dump_config(save_path)[source]

Writes the current config file to the specified directory.

Parameters:

save_path (Path to the directory.)

class apax.config.train_config.OptimizerConfig(*, name: str = 'adam', emb_lr: Annotated[float, Ge(ge=0)] = 0.001, nn_lr: Annotated[float, Ge(ge=0)] = 0.001, scale_lr: Annotated[float, Ge(ge=0)] = 0.0001, shift_lr: Annotated[float, Ge(ge=0)] = 0.003, zbl_lr: Annotated[float, Ge(ge=0)] = 0.0001, rep_scale_lr: Annotated[float, Ge(ge=0)] = 0.001, rep_prefactor_lr: Annotated[float, Ge(ge=0)] = 0.0001, gradient_clipping: Annotated[float, Ge(ge=0)] = 1000.0, schedule: LinearLR | CyclicCosineLR = LinearLR(name='linear', transition_begin=0, end_value=1e-06), kwargs: dict = {})[source]

Configuration of the optimizer. Learning rates of 0 will freeze the respective parameters.

Parameters:
  • name (str, default = "adam") – Name of the optimizer. Can be any optax optimizer.

  • emb_lr (NonNegativeFloat, default = 0.001) – Learning rate of the elemental embedding contraction coefficients.

  • nn_lr (NonNegativeFloat, default = 0.001) – Learning rate of the neural network parameters.

  • scale_lr (NonNegativeFloat, default = 0.0001) – Learning rate of the elemental output scaling factors.

  • shift_lr (NonNegativeFloat, default = 0.003) – Learning rate of the elemental output shifts.

  • zbl_lr (NonNegativeFloat, default = 0.0001) – Learning rate of the ZBL correction parameters.

  • rep_scale_lr (NonNegativeFloat, default = 0.001) – LR for the length scale of the exponential repulsion potential.

  • rep_prefactor_lr (NonNegativeFloat, default = 0.0001) – LR for the strength of the exponential repulsion potential.

  • gradient_clipping (NonNegativeFloat, default = 1000.0) – Per-element gradient clipping value. The default is so high that clipping is effectively disabled.

  • schedule (LRSchedule = LinearLR) – Learning rate schedule.

  • kwargs (dict, default = {}) – Optimizer keyword arguments. Passed to the optax optimizer.
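The effect of separate learning rates, of a learning rate of 0, and of per-element gradient clipping can be sketched in plain Python. This is an illustration under stated assumptions: plain gradient descent stands in for the configured optax optimizer, and the helper names clip and sgd_step as well as the group names "nn" and "shift" are hypothetical.

```python
def clip(value, limit):
    """Clamp a single gradient element to the range [-limit, limit]."""
    return max(-limit, min(limit, value))


def sgd_step(params, grads, learning_rates, gradient_clipping=1000.0):
    """One gradient-descent step with a separate learning rate per group.

    Assumption: plain SGD is used only for illustration; the real
    update rule depends on the chosen optimizer.
    """
    updated = {}
    for group, values in params.items():
        lr = learning_rates[group]
        updated[group] = [
            p - lr * clip(g, gradient_clipping)
            for p, g in zip(values, grads[group])
        ]
    return updated


params = {"nn": [1.0, -2.0], "shift": [0.5]}
grads = {"nn": [10.0, -5000.0], "shift": [3.0]}
# A learning rate of 0 leaves the "shift" group unchanged (frozen).
new_params = sgd_step(params, grads, {"nn": 0.001, "shift": 0.0})
```

The second "nn" gradient exceeds the clipping value and is clamped to -1000.0 before the update, while the "shift" parameter keeps its value because its learning rate is 0.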

class apax.config.train_config.DataConfig(*, directory: str, experiment: str, dataset: CachedDataset | OTFDataset | PBPDatset = CachedDataset(processing='cached', shuffle_buffer_size=1000), data_path: str | None = None, train_data_path: str | None = None, val_data_path: str | None = None, test_data_path: str | None = None, n_train: Annotated[int, Gt(gt=0)] = 1000, n_valid: Annotated[int, Gt(gt=0)] = 100, batch_size: Annotated[int, Gt(gt=0)] = 32, valid_batch_size: Annotated[int, Gt(gt=0)] = 100, shift_method: str = 'per_element_regression_shift', shift_options: dict = {'energy_regularisation': 1.0}, scale_method: str = 'per_element_force_rms_scale', scale_options: dict | None = {}, pos_unit: str | None = 'Ang', energy_unit: str | None = 'eV')[source]

Configuration for data loading, preprocessing and training.

Parameters:
  • directory (str, required) –

    Path to directory where training results and checkpoints are written.

  • experiment (str, required) –

    Model name, used to distinguish it from others in directory.

  • data_path (str, required if train_data_path and val_data_path are not specified) –

    Path to single dataset file.

  • train_data_path (str, required if data_path is not specified) –

    Path to training dataset.

  • val_data_path (str, required if data_path is not specified) –

    Path to validation dataset.

  • test_data_path (str, optional) –

    Path to test dataset.

  • n_train (int, default = 1000) –

    Number of training datapoints from data_path.

  • n_valid (int, default = 100) –

    Number of validation datapoints from data_path.

  • batch_size (int, default = 32) –

    Number of training examples to be evaluated at once.

  • valid_batch_size (int, default = 100) –

    Number of validation examples to be evaluated at once.

  • shuffle_buffer_size (int, default = 1000) –

    Size of the tf.data shuffle buffer.

  • energy_regularisation

    Magnitude of the regularization in the per-element energy regression.

  • pos_unit (str, default = "Ang") – unit of length

  • energy_unit (str, default = "eV") – unit of energy
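When only data_path is given, n_train and n_valid determine how many datapoints are drawn from that single file. The sketch below shows one common way to obtain disjoint subsets. The helper split_indices is hypothetical, and the seeded shuffle followed by slicing is an assumption, not a description of how apax selects its samples.

```python
import random


def split_indices(n_total, n_train, n_valid, seed=1):
    """Pick disjoint training and validation indices from one dataset.

    Hypothetical helper: shuffles all indices with a fixed seed and
    slices off the first n_train and the following n_valid entries.
    """
    if n_train + n_valid > n_total:
        raise ValueError("n_train + n_valid exceeds the dataset size")
    indices = list(range(n_total))
    random.Random(seed).shuffle(indices)
    return indices[:n_train], indices[n_train:n_train + n_valid]


train_idx, valid_idx = split_indices(n_total=1500, n_train=1000, n_valid=100)
```

The check at the top makes explicit that the dataset must contain at least n_train + n_valid entries for such a split to be possible.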

class apax.config.train_config.DatasetConfig(*, processing: str)[source]
class apax.config.train_config.CachedDataset(*, processing: Literal['cached'] = 'cached', shuffle_buffer_size: Annotated[int, Gt(gt=0)] = 1000)[source]

Dataset which pads everything (atoms, neighbors) to the largest system in the dataset. The NL is computed on the fly during the first epoch and stored to disk using tf.data’s cache. Most performant option for datasets with samples of very similar size.

Parameters:

shuffle_buffer_size (int) –

Size of the buffer that is shuffled by tf.data.
Larger values require more RAM.

class apax.config.train_config.PBPDatset(*, processing: Literal['pbp'] = 'pbp', num_workers: Annotated[int, Gt(gt=0)] = 10, atom_padding: Annotated[int, Gt(gt=0)] = 10, nl_padding: Annotated[int, Gt(gt=0)] = 2000)[source]

Dataset which pads everything (atoms, neighbors) to the next largest power of two. This limits the compute wasted due to padding at the (negligible) cost of some recompilations. The NL is computed on the fly in parallel for num_workers batches. Does not use tf.data.

Most performant option for datasets with significantly differently sized systems (e.g. MP, SPICE).

Parameters:
  • num_workers (int) –

    Number of batches to be processed in parallel.

  • atom_padding (int) –

    Next nearest integer to which to pad per-atom arrays (positions, forces, …).

  • nl_padding (int) –

    Next nearest integer to which to pad neighborlists.
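The idea of padding to a power of two can be sketched as follows. The helper next_power_of_two is hypothetical, and treating a size that already is a power of two as unchanged is an assumption. The point of the sketch is that many different system sizes collapse onto a few padded shapes, which keeps the number of recompilations small.

```python
def next_power_of_two(n):
    """Smallest power of two that is greater than or equal to n (n >= 1)."""
    return 1 << (n - 1).bit_length()


system_sizes = [3, 17, 64, 100, 513]
padded_sizes = [next_power_of_two(n) for n in system_sizes]
# Each distinct padded size corresponds to one array shape to compile for.
distinct_shapes = sorted(set(padded_sizes))
```

For systems between 65 and 128 atoms, for example, only a single padded shape of 128 would be needed in this sketch.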

class apax.config.train_config.OTFDataset(*, processing: Literal['otf'] = 'otf', shuffle_buffer_size: Annotated[int, Gt(gt=0)] = 1000)[source]

Dataset which pads everything (atoms, neighbors) to the largest system in the dataset. The NL is computed on the fly and fed into a tf.data generator. Mostly for internal purposes.

Parameters:

shuffle_buffer_size (int) –

Size of the buffer that is shuffled by tf.data.
Larger values require more RAM.

class apax.config.train_config.MetricsConfig(*, name: str, reductions: List[str])[source]

Configuration for the metrics collected during training.

Parameters:
  • name (str) – Keyword of the quantity, e.g., ‘energy’.

  • reductions (List[str]) – List of reductions performed on the difference between target and predictions. Can be ‘mae’, ‘mse’, ‘rmse’ for energies and forces. For forces, ‘angle’ can also be used.
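The ‘mae’, ‘mse’, and ‘rmse’ reductions follow their usual definitions, sketched below in plain Python. The function reduce_errors is a hypothetical helper for illustration; the ‘angle’ reduction for forces is not covered here.

```python
import math


def reduce_errors(targets, predictions, reduction):
    """Apply a reduction to the differences between predictions and targets."""
    errors = [p - t for t, p in zip(targets, predictions)]
    if reduction == "mae":
        return sum(abs(e) for e in errors) / len(errors)
    mse = sum(e * e for e in errors) / len(errors)
    if reduction == "mse":
        return mse
    if reduction == "rmse":
        return math.sqrt(mse)
    raise ValueError(f"unknown reduction: {reduction}")


targets = [1.0, 2.0, 3.0, 4.0]
predictions = [1.5, 2.0, 2.0, 6.0]
mae = reduce_errors(targets, predictions, "mae")
mse = reduce_errors(targets, predictions, "mse")
rmse = reduce_errors(targets, predictions, "rmse")
```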

class apax.config.train_config.LossConfig(*, name: str, loss_type: str = 'mse', weight: Annotated[float, Ge(ge=0)] = 1.0, atoms_exponent: Annotated[float, Ge(ge=0)] = 1, parameters: dict = {})[source]

Configuration of the loss functions used during training.

Parameters:
  • name (str) – Keyword of the quantity, e.g., ‘energy’.

  • loss_type (str, optional) – Weighting scheme for atomic contributions. See the MLIP package reference (10.1088/2632-2153/abc9fe) for details. Defaults to “mse”.

  • weight (NonNegativeFloat, optional) – Weighting factor in the overall loss function, by default 1.0.

  • atoms_exponent (NonNegativeFloat, optional) – Exponent for atomic contributions weighting, by default 1.

  • parameters (dict, optional) – Additional parameters for configuring the loss function, by default {}.

Notes

This class specifies the configuration of the loss functions used during training.

class apax.config.train_config.TrainProgressbarConfig(*, disable_epoch_pbar: bool = False, disable_batch_pbar: bool = True)[source]

Configuration of progressbars.

Parameters:
  • disable_epoch_pbar (Set to True to disable the epoch progress bar.)

  • disable_batch_pbar (Set to True to disable the batch progress bar.)

class apax.config.train_config.CSVCallback(*, name: Literal['csv'])[source]

Configuration of the CSVCallback.

Parameters:

name (Keyword of the callback used.)

class apax.config.train_config.TBCallback(*, name: Literal['tensorboard'])[source]

Configuration of the TensorBoard callback.

Parameters:

name (Keyword of the callback used.)

class apax.config.train_config.MLFlowCallback(*, name: Literal['mlflow'], experiment: str)[source]

Configuration of the MLFlow callback.

Parameters:
  • name (Keyword of the callback used.)

  • experiment (Path to the MLFlow experiment, e.g. /Users/<user>/<my_experiment>)

class apax.config.train_config.KerasPruningCallback(*, name: Literal['pruning'], trial_id: int, study_name: str, study_log_file: str | Path, interval: int = 1, monitor: str = 'val_loss')[source]

Configuration for the pruning callback used during a trial in Optuna hyperparameter optimization.

Parameters:
  • name (Literal['pruning']) – Keyword of the callback used.

  • trial_id (int) – id of the trial in the study

  • study_name (str) – name of the study

  • study_log_file (str | Path) – path to the study log file

  • interval (int, default = 1) – interval to check whether the trial should be pruned

  • monitor (str, default = "val_loss") – metric key to monitor to determine pruning

class apax.config.train_config.TransferLearningConfig(*, base_model_checkpoint: str | None = None, reset_layers: List[str] = [], freeze_layers: List[str] = [])[source]

Transfer learning configuration.

Parameters:
  • base_model_checkpoint (Path to the folder containing a pre-trained model ckpt.)

  • reset_layers (List of layer names for which the parameters will be reinitialized.)

  • freeze_layers (List of layer names for which the parameters will be frozen during training.)

class apax.config.train_config.WeightAverage(*, ema_start: int = 0, alpha: float = 0.9)[source]

Applies an exponential moving average to model parameters.

Parameters:
  • ema_start (int, default = 0) – Epoch at which to start averaging models.

  • alpha (float, default = 0.9) – How much of the new model to use. 1.0 would mean no averaging, 0.0 no updates.
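The role of alpha can be sketched with a minimal exponential moving average over a flat list of parameters. The helper ema_update is hypothetical, and the exact update formula is an assumption chosen to match the description above: alpha = 1.0 returns the new parameters unchanged, while alpha = 0.0 never updates the average.

```python
def ema_update(averaged, new, alpha=0.9):
    """Blend new parameters into the running average.

    Assumed convention: alpha is the fraction of the new model that
    enters the average at each update.
    """
    return [alpha * n + (1.0 - alpha) * a for a, n in zip(averaged, new)]


averaged = [1.0, 2.0]
new = [3.0, 4.0]
no_averaging = ema_update(averaged, new, alpha=1.0)
no_updates = ema_update(averaged, new, alpha=0.0)
halfway = ema_update(averaged, new, alpha=0.5)
```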

Model Configuration

class apax.config.model_config.GaussianBasisConfig(*, name: Literal['gaussian'] = 'gaussian', n_basis: Annotated[int, Gt(gt=0)] = 7, r_min: Annotated[float, Ge(ge=0)] = 0.5, r_max: Annotated[float, Gt(gt=0)] = 6.0, spacing: Literal['linear', 'exponential'] = 'linear')[source]

Gaussian primitive basis functions.

Parameters:
  • n_basis (PositiveInt, default = 7) – Number of uncontracted basis functions.

  • r_min (NonNegativeFloat, default = 0.5) – Position of the first uncontracted basis function’s mean.

  • r_max (PositiveFloat, default = 6.0) – Cutoff radius of the descriptor.

  • spacing (Literal['linear', 'exponential'], default = 'linear') – Spacing of centers of Gaussians. “exponential” results in more basis functions closer to r_min, and fewer at r_max. See https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181, Figure 2.

class apax.config.model_config.BesselBasisConfig(*, name: Literal['bessel'] = 'bessel', n_basis: Annotated[int, Gt(gt=0)] = 16, r_max: Annotated[float, Gt(gt=0)] = 5.0)[source]

Bessel primitive basis functions.

Parameters:
  • n_basis (PositiveInt, default = 16) – Number of uncontracted basis functions.

  • r_max (PositiveFloat, default = 5.0) – Cutoff radius of the descriptor.

class apax.config.model_config.FullEnsembleConfig(*, kind: Literal['full'] = 'full', n_members: int)[source]

Configuration for full model ensembles. Usage can improve accuracy and stability at the cost of slower inference. Uncertainties will generally not be calibrated.

Parameters:

n_members (int) – Number of ensemble members.

class apax.config.model_config.ShallowEnsembleConfig(*, kind: Literal['shallow'] = 'shallow', n_members: int, force_variance: bool = True, chunk_size: int | None = None)[source]

Configuration for shallow (last layer) ensembles. Allows use of probabilistic loss functions. The predicted uncertainties should be well calibrated. See 10.1088/2632-2153/ad594a for details.

Parameters:
  • n_members (int) – Number of ensemble members.

  • force_variance (bool, default = True) – Whether or not to compute force uncertainties. Required for probabilistic force loss and calibration of force uncertainties. Can lead to better force metrics, but enabling it introduces some non-negligible cost.

  • chunk_size (Optional[int], default = None) – If set to an integer, the Jacobian of ensemble energies with respect to positions will be computed in chunks of that size. This sacrifices some performance for the possibility to use relatively large ensemble sizes.

Hint

Loss type has to be changed to a probabilistic loss such as ‘nll’ or ‘crps’.
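To illustrate what a probabilistic loss such as ‘nll’ evaluates, the sketch below computes a Gaussian negative log-likelihood from the mean and variance of a set of ensemble member predictions. It is a textbook formula, not apax's implementation: the helper gaussian_nll is hypothetical, and using the population variance of the members is an assumption.

```python
import math
from statistics import fmean, pvariance


def gaussian_nll(target, member_predictions):
    """Gaussian negative log-likelihood of target under the ensemble.

    Assumption: the ensemble mean and the population variance of the
    member predictions define the predictive distribution.
    """
    mean = fmean(member_predictions)
    variance = pvariance(member_predictions, mu=mean)
    return 0.5 * (
        math.log(2.0 * math.pi * variance) + (target - mean) ** 2 / variance
    )


members = [1.0, 2.0, 3.0]
nll_near = gaussian_nll(2.0, members)  # target equals the ensemble mean
nll_far = gaussian_nll(5.0, members)   # target lies far from the mean
```

Unlike a plain squared error, this loss depends on the spread of the members, which is why more than one ensemble member is needed to evaluate it.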

class apax.config.model_config.Correction(*, name: str)[source]
class apax.config.model_config.ZBLRepulsion(*, name: Literal['zbl'], r_max: Annotated[float, Ge(ge=0)] = 1.5)[source]
class apax.config.model_config.ExponentialRepulsion(*, name: Literal['exponential'], r_max: Annotated[float, Ge(ge=0)] = 1.5)[source]
class apax.config.model_config.LatentEwald(*, name: Literal['latent_ewald'], kgrid: list, sigma: float = 1.0, use_property: str = 'charges')[source]
class apax.config.model_config.PropertyHead(*, name: str, aggregation: str = 'none', mode: str = 'l0', nn: List[Annotated[int, Gt(gt=0)]] = [128, 128], n_shallow_members: int = 0, w_init: Literal['normal', 'lecun'] = 'lecun', b_init: Literal['normal', 'zeros'] = 'zeros', use_ntk: bool = False, dtype: Literal['fp32', 'fp64'] = 'fp32')[source]
class apax.config.model_config.GMNNConfig(*, basis: GaussianBasisConfig | BesselBasisConfig = BesselBasisConfig(name='bessel', n_basis=16, r_max=5.0), nn: List[Annotated[int, Gt(gt=0)]] = [256, 256], w_init: Literal['normal', 'lecun'] = 'lecun', b_init: Literal['normal', 'zeros'] = 'zeros', activation_fn: str = 'variance_preserving_swish', use_ntk: bool = False, ensemble: FullEnsembleConfig | ShallowEnsembleConfig | None = None, property_heads: list[PropertyHead] = [], empirical_corrections: list[ZBLRepulsion | ExponentialRepulsion | LatentEwald] = [], calc_stress: bool = False, descriptor_dtype: Literal['fp32', 'fp64'] = 'fp32', readout_dtype: Literal['fp32', 'fp64'] = 'fp32', scale_shift_dtype: Literal['fp32', 'fp64'] = 'fp64', name: Literal['gmnn'] = 'gmnn', n_radial: Annotated[int, Gt(gt=0)] = 5, n_contr: int = 8, emb_init: str | None = 'uniform')[source]

Configuration for the model.

Parameters:
  • n_radial (PositiveInt, default = 5) – Number of contracted basis functions.

  • n_contr (int, default = 8) – How many Gaussian moment contractions to use.

  • emb_init (Optional[str], default = "uniform") – Initialization scheme for embedding layer weights.

class apax.config.model_config.EquivMPConfig(*, basis: GaussianBasisConfig | BesselBasisConfig = BesselBasisConfig(name='bessel', n_basis=16, r_max=5.0), nn: List[Annotated[int, Gt(gt=0)]] = [256, 256], w_init: Literal['normal', 'lecun'] = 'lecun', b_init: Literal['normal', 'zeros'] = 'zeros', activation_fn: str = 'variance_preserving_swish', use_ntk: bool = False, ensemble: FullEnsembleConfig | ShallowEnsembleConfig | None = None, property_heads: list[PropertyHead] = [], empirical_corrections: list[ZBLRepulsion | ExponentialRepulsion | LatentEwald] = [], calc_stress: bool = False, descriptor_dtype: Literal['fp32', 'fp64'] = 'fp32', readout_dtype: Literal['fp32', 'fp64'] = 'fp32', scale_shift_dtype: Literal['fp32', 'fp64'] = 'fp64', name: Literal['equiv-mp'] = 'equiv-mp', features: Annotated[int, Gt(gt=0)] = 32, max_degree: Annotated[int, Gt(gt=0)] = 2, num_iterations: Annotated[int, Gt(gt=0)] = 1)[source]

Configuration for the model.

Parameters:
  • features (PositiveInt = 32) – Feature dimension of the linear layers

  • max_degree (PositiveInt = 2) – Maximal rotation order for features and tensor products

  • num_iterations (PositiveInt = 1) – Number of message passing steps.

class apax.config.model_config.So3kratesConfig(*, basis: GaussianBasisConfig | BesselBasisConfig = BesselBasisConfig(name='bessel', n_basis=16, r_max=5.0), nn: List[Annotated[int, Gt(gt=0)]] = [256, 256], w_init: Literal['normal', 'lecun'] = 'lecun', b_init: Literal['normal', 'zeros'] = 'zeros', activation_fn: str = 'variance_preserving_swish', use_ntk: bool = False, ensemble: FullEnsembleConfig | ShallowEnsembleConfig | None = None, property_heads: list[PropertyHead] = [], empirical_corrections: list[ZBLRepulsion | ExponentialRepulsion | LatentEwald] = [], calc_stress: bool = False, descriptor_dtype: Literal['fp32', 'fp64'] = 'fp32', readout_dtype: Literal['fp32', 'fp64'] = 'fp32', scale_shift_dtype: Literal['fp32', 'fp64'] = 'fp64', name: Literal['so3krates'] = 'so3krates', num_layers: Annotated[int, Gt(gt=0)] = 1, max_degree: Annotated[int, Gt(gt=0)] = 3, num_features: Annotated[int, Gt(gt=0)] = 128, num_heads: Annotated[int, Gt(gt=0)] = 4, use_layer_norm_1: bool = False, use_layer_norm_2: bool = False, use_layer_norm_final: bool = False, activation: str = 'silu', cutoff_fn: str = 'cosine_cutoff', transform_input_features: bool = False)[source]

Configuration for the model.

Parameters:
  • num_layers (PositiveInt = 1) – Number of message passing layers

  • max_degree (PositiveInt = 3) – Maximum rotation order

  • num_features (PositiveInt = 128) – Feature dimension

  • num_heads (PositiveInt = 4) – Number of attention heads

  • use_layer_norm_1 (bool = False) – Layer norm in transformer block

  • use_layer_norm_2 (bool = False) – Layer norm in transformer block

  • use_layer_norm_final (bool = False) – Layer norm before readout

  • activation (str = "silu") – Activation function

  • cutoff_fn (str = "cosine_cutoff") – Smooth cutoff function

  • transform_input_features (bool = False) – Whether or not to apply a dense layer to transformer input features

Molecular Dynamics Configuration

class apax.config.md_config.MDConfig(*, seed: int = 1, ensemble: NVEOptions | NVTOptions | NPTOptions = NVTOptions(name='nvt', dt=0.5, temperature_schedule=ConstantTempSchedule(name='constant', T0=298.15), thermostat_chain=NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=100)), duration: Annotated[float, Gt(gt=0)], n_inner: Annotated[int, Gt(gt=0)] = 500, sampling_rate: Annotated[int, Gt(gt=0)] = 10, buffer_size: Annotated[int, Gt(gt=0)] = 2500, dr_threshold: Annotated[float, Gt(gt=0)] = 0.5, extra_capacity: Annotated[int, Ge(ge=0)] = 0, disable_cell_list: bool = False, biases: list[Annotated[SphericalWallEnergy, FieldInfo(annotation=NoneType, required=True, discriminator='name')]] = [], dynamics_checks: list[Annotated[EnergyUncertaintyCheck | ForcesUncertaintyCheck | ReflectionCheck, FieldInfo(annotation=NoneType, required=True, discriminator='name')]] = [], constraints: list[Annotated[FixAtomsConstraint | FixCenterOfMassConstraint | FixLayerConstraint, FieldInfo(annotation=NoneType, required=True, discriminator='name')]] = [], properties: list[str] = ['energy', 'forces', 'stress', 'forces_uncertainty', 'energy_uncertainty', 'stress_uncertainty', 'energy_ensemble', 'forces_ensemble', 'stress_ensemble', 'energy_unbiased', 'forces_unbiased', 'charge', 'charges'], h5md_options: H5MDOptions = H5MDOptions(compression='gzip', compression_opts=4, store='time', author='N/A', author_email='N/A'), initial_structure: str, load_momenta: bool = False, sim_dir: str = '.', traj_name: str = 'md.h5', restart: bool = True, wrapped: bool = True, checkpoint_interval: int = 50000, disable_pbar: bool = False)[source]

Configuration for an NHC molecular dynamics simulation. Full config here:

Parameters:
  • seed (int, default = 1) –

    Random seed for momentum initialization.

  • ensemble

    Options for integrating the EoM and which ensemble to use.

  • dt (float, default = 0.5) –

    Time step in fs.

  • duration (float, required) –

    Total simulation time in fs.

  • n_inner (int, default = 500) –

    Number of compiled simulation steps (i.e. number of iterations of the
    jax.lax.fori_loop loop). Also determines atoms buffer size.

  • sampling_rate (int, default = 10) –

    Interval between saving frames.

  • buffer_size (int, default = 2500) –

    Number of collected frames to be dumped at once.

  • dr_threshold (float, default = 0.5) –

    Skin of the neighborlist.

  • extra_capacity (int, default = 0) –

    JaxMD allocates a maximal number of neighbors. This argument lets you add
    additional capacity to avoid recompilation. The default is usually fine.

  • biases (list[BiasEnergy]) –

    List of bias energies. Currently a spherical wall potential is available.

  • dynamics_checks (list[DynamicsCheck]) –

    List of termination criteria. Currently energy and force uncertainty
    are available.

  • constraints (list[Constraint]) –

    List of constraints. Currently …

  • properties (list[str]) –

    Whitelist of properties to be saved in the trajectory.
    This does not affect what the model will calculate, e.g.,
    an ensemble will still calculate uncertainties.

  • initial_structure (str, required) –

    Path to the starting structure of the simulation.

  • sim_dir (str, default = ".") –

    Directory where the simulation file will be stored.

  • traj_name (str, default = "md.h5") –

    Name of the trajectory file.

  • restart (bool, default = True) –

    Whether the simulation should restart from the latest configuration in
    traj_name.

  • wrapped (bool, default = True) –

    Whether the atoms in the simulation should be wrapped back into the box.

  • checkpoint_interval (int, default = 50_000) –

    Number of time steps between saving full simulation state checkpoints.
    These will be loaded with the restart option.

  • disable_pbar (bool, default = False) –

    Disables the MD progressbar.
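The relation between duration, dt, n_inner, and sampling_rate can be made concrete with a little arithmetic. The helper md_step_counts is hypothetical and the rounding conventions are assumptions; dt is passed as a plain argument here for simplicity.

```python
import math


def md_step_counts(duration, dt=0.5, n_inner=500, sampling_rate=10):
    """Estimate step, outer-loop, and saved-frame counts for an MD run.

    Assumptions: the number of steps is duration / dt rounded to the
    nearest integer, the compiled inner loop of n_inner steps is
    repeated until all steps are done, and one frame is saved every
    sampling_rate steps.
    """
    n_steps = int(round(duration / dt))
    n_outer = math.ceil(n_steps / n_inner)
    n_frames = n_steps // sampling_rate
    return n_steps, n_outer, n_frames


# 10 ps of simulation time with the default settings listed above.
counts = md_step_counts(duration=10_000.0)
```

With these defaults, 10,000 fs correspond to 20,000 time steps, 40 repetitions of the compiled inner loop, and 2,000 saved frames.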

dump_config()[source]

Writes the current config file to the MD directory.

class apax.config.md_config.NPTOptions(*, name: Literal['npt'], dt: Annotated[float, Gt(gt=0)] = 0.5, temperature_schedule: ConstantTempSchedule | PiecewiseLinearTempSchedule | OscillatingRampTempSchedule = ConstantTempSchedule(name='constant', T0=298.15), thermostat_chain: NHCOptions = NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=100), pressure: Annotated[float, Gt(gt=0)] = 1.01325, barostat_chain: NHCOptions = NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=1000.0))[source]

Options for NPT ensemble simulations.

Parameters:
  • name (Literal["npt"]) – Name of the ensemble.

  • pressure (PositiveFloat, default = 1.01325) – Pressure in bar.

  • barostat_chain (NHCOptions, default = NHCOptions(tau=1000)) – Barostat chain options.

Hyperparameter Optimization Configuration

class apax.config.optuna_config.OptunaPrunerConfig(*, name: str, interval: Annotated[int, Gt(gt=0)] = 1, kwargs: dict[str, Any] = {})[source]

Config for optuna pruner.

name

Name of pruner

Type:

str

interval

Interval to check whether a trial should be pruned

Type:

PositiveInt

kwargs

kwargs passed to pruner

Type:

dict[str, Any]

class apax.config.optuna_config.OptunaSamplerConfig(*, name: str = 'AutoSampler', kwargs: dict[str, Any] = {})[source]

Config for optuna sampler.

name

Name of sampler. Default: “AutoSampler”

Type:

str

kwargs

kwargs passed to sampler

Type:

dict[str, Any]

class apax.config.optuna_config.OptunaConfig(*, n_trials: Annotated[int, Gt(gt=0)], search_space: dict[str, dict[str, Any]], seed: int = 1, monitor: str = 'val_loss', study_name: str = 'study', study_log_file: str | Path = 'study.log', sampler_config: OptunaSamplerConfig = OptunaSamplerConfig(name='AutoSampler', kwargs={}), pruner_config: OptunaPrunerConfig | None = None)[source]

Configuration of optuna study.

n_trials

number of trials

Type:

PositiveInt

search_space

dictionary indicating the search space to use for the hyperparameter optimization. Keys should be keys in the training configuration; nested keys are prefixed with the parent key and an underscore (see the example below).

Type:

dict[str, dict[str, Any]]

seed

seed to use for sampling

Type:

int

monitor

metric to monitor. Default: “val_loss”

Type:

str

study_name

name of study. Default: “study”

Type:

str

study_log_file

path to study log file. Default: “study.log”

Type:

str | Path

sampler_config

sampler configuration

Type:

OptunaSamplerConfig

pruner_config

pruner configuration. Default: None

Type:

Optional[OptunaPrunerConfig]

Example

For a search space with varying radius for the environment between 3 and 8 Angstrom, the number of tensor contractions between 5 and 8, and number of epochs between 100 and 200:

search_space = {
    'model_basis_r_max': {'type': 'float', 'low': 3, 'high': 8},
    'model_n_contr': {'type': 'int', 'low': 5, 'high': 8},
    'n_epochs': {'type': 'int', 'low': 100, 'high': 200},
}
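Because option names such as r_max and n_epochs contain underscores themselves, a flattened key like model_basis_r_max cannot be split on underscores alone. The sketch below shows one way such a key could be mapped back to a nested path by matching against the keys that exist in a nested configuration dictionary. The helper resolve_key and the small example config are hypothetical; how apax resolves these keys internally is not shown in this reference.

```python
def resolve_key(config, flat_key):
    """Translate e.g. 'model_basis_r_max' into ('model', 'basis', 'r_max').

    Hypothetical helper: at each nesting level the longest candidate
    that exists as a key is chosen, so 'r_max' wins over 'r'.
    """
    parts = flat_key.split("_")
    path = []
    node = config
    i = 0
    while i < len(parts):
        for j in range(len(parts), i, -1):
            candidate = "_".join(parts[i:j])
            if isinstance(node, dict) and candidate in node:
                path.append(candidate)
                node = node[candidate]
                i = j
                break
        else:
            # No key at this level matches the remaining parts.
            raise KeyError(flat_key)
    return tuple(path)


# Minimal stand-in for a parsed training configuration.
config = {
    "n_epochs": 100,
    "model": {"n_contr": 8, "basis": {"r_max": 5.0}},
}
paths = {
    key: resolve_key(config, key)
    for key in ("model_basis_r_max", "model_n_contr", "n_epochs")
}
```

Matching against existing keys also means that a misspelled search space entry raises a KeyError in this sketch instead of silently creating a new option.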