Configuration
apax.config
Training Configuration
- class apax.config.train_config.Config(*, n_epochs: Annotated[int, Gt(gt=0)], patience: Annotated[int, Gt(gt=0)] | None = None, patience_min_delta: Annotated[float, Ge(ge=0)] = 0.0, seed: int = 1, data_parallel: bool = True, ckpt_interval: Annotated[int, Gt(gt=0)] = 500, data: DataConfig, model: GMNNConfig | EquivMPConfig | So3kratesConfig = GMNNConfig(basis=BesselBasisConfig(name='bessel', n_basis=16, r_max=5.0), nn=[256, 256], w_init='lecun', b_init='zeros', activation_fn='variance_preserving_swish', use_ntk=False, ensemble=None, property_heads=[], empirical_corrections=[], calc_stress=False, descriptor_dtype='fp32', readout_dtype='fp32', scale_shift_dtype='fp64', name='gmnn', n_radial=5, n_contr=8, emb_init='uniform'), metrics: List[MetricsConfig] = [], loss: List[LossConfig], optimizer: OptimizerConfig = OptimizerConfig(name='adam', emb_lr=0.001, nn_lr=0.001, scale_lr=0.0001, shift_lr=0.003, zbl_lr=0.0001, rep_scale_lr=0.001, rep_prefactor_lr=0.0001, gradient_clipping=1000.0, schedule=LinearLR(name='linear', transition_begin=0, end_value=1e-06), kwargs={}), weight_average: WeightAverage | None = None, callbacks: List[Annotated[CSVCallback | TBCallback | MLFlowCallback | KerasPruningCallback, FieldInfo(annotation=NoneType, required=True, discriminator='name')]] = [CSVCallback(name='csv')], progress_bar: TrainProgressbarConfig = TrainProgressbarConfig(disable_epoch_pbar=False, disable_batch_pbar=True), transfer_learning: TransferLearningConfig | None = None)[source]¶
Main configuration of an apax training run. Parameters that are config classes are generated by parsing the config.yaml file and are specified as shown here:
Example
data:
    directory: models/
    experiment: apax
    .
    .
- Parameters:
n_epochs (int, required) – Number of training epochs.
patience (int, optional) – Number of epochs without improvement before training is terminated.
patience_min_delta (float, default = 0.0) – Minimum change in the monitored quantity to qualify as an improvement.
seed (int, default = 1) – Random seed.
data (DataConfig) – Data configuration.
model (ModelConfig) – Model configuration.
metrics (List of MetricsConfig) – Metrics configuration.
loss (List of LossConfig) – Loss configuration.
optimizer (OptimizerConfig) – Loss optimizer configuration.
weight_average (WeightAverage, optional) – Options for averaging weights between epochs.
callbacks (List of various CallBack classes) – Possible callbacks are CSVCallback, TBCallback, MLFlowCallback, and KerasPruningCallback.
progress_bar (TrainProgressbarConfig) – Progressbar configuration.
checkpoints (CheckpointConfig) – Checkpoint configuration.
data_parallel (bool, default = True) – Automatically uses all available GPUs for data parallel training. Set to False to force single device training.
ckpt_interval (int, default = 500) – Number of epochs between checkpoints.
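As a rough illustration of how the required fields fit together, the sketch below builds the nested structure as a plain Python dictionary and checks it for the fields that the signatures on this page mark as required. The helper `missing_required` and all example values (such as the dataset path) are hypothetical and not part of apax; the actual validation is done by the config classes themselves.

```python
# Fields without defaults according to the signatures on this page:
# n_epochs, data and loss at the top level; directory and experiment in data.
REQUIRED_TOP = ("n_epochs", "data", "loss")
REQUIRED_DATA = ("directory", "experiment")


def missing_required(config: dict) -> list[str]:
    """Return dotted names of required fields that are absent from config."""
    missing = [key for key in REQUIRED_TOP if key not in config]
    data = config.get("data", {})
    missing += [f"data.{key}" for key in REQUIRED_DATA if key not in data]
    return missing


# Hypothetical example values; the dataset path is only a placeholder.
config = {
    "n_epochs": 100,
    "patience": 20,
    "data": {
        "directory": "models/",
        "experiment": "apax",
        "data_path": "dataset.extxyz",
        "n_train": 1000,
        "n_valid": 100,
    },
    "loss": [{"name": "energy"}, {"name": "forces", "weight": 4.0}],
    "metrics": [{"name": "energy", "reductions": ["mae"]}],
}

print(missing_required(config))  # prints []
```

The same nesting carries over to config.yaml, where each config class becomes a mapping under its parameter name.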
- class apax.config.train_config.OptimizerConfig(*, name: str = 'adam', emb_lr: Annotated[float, Ge(ge=0)] = 0.001, nn_lr: Annotated[float, Ge(ge=0)] = 0.001, scale_lr: Annotated[float, Ge(ge=0)] = 0.0001, shift_lr: Annotated[float, Ge(ge=0)] = 0.003, zbl_lr: Annotated[float, Ge(ge=0)] = 0.0001, rep_scale_lr: Annotated[float, Ge(ge=0)] = 0.001, rep_prefactor_lr: Annotated[float, Ge(ge=0)] = 0.0001, gradient_clipping: Annotated[float, Ge(ge=0)] = 1000.0, schedule: LinearLR | CyclicCosineLR = LinearLR(name='linear', transition_begin=0, end_value=1e-06), kwargs: dict = {})[source]¶
Configuration of the optimizer. Learning rates of 0 will freeze the respective parameters.
- Parameters:
name (str, default = "adam") – Name of the optimizer. Can be any optax optimizer.
emb_lr (NonNegativeFloat, default = 0.001) – Learning rate of the elemental embedding contraction coefficients.
nn_lr (NonNegativeFloat, default = 0.001) – Learning rate of the neural network parameters.
scale_lr (NonNegativeFloat, default = 0.0001) – Learning rate of the elemental output scaling factors.
shift_lr (NonNegativeFloat, default = 0.003) – Learning rate of the elemental output shifts.
zbl_lr (NonNegativeFloat, default = 0.0001) – Learning rate of the ZBL correction parameters.
rep_scale_lr (NonNegativeFloat, default = 0.001) – Learning rate for the length scale of the exponential repulsion potential.
rep_prefactor_lr (NonNegativeFloat, default = 0.0001) – Learning rate for the strength of the exponential repulsion potential.
gradient_clipping (NonNegativeFloat, default = 1000.0) – Per-element gradient clipping value. The default is so high that clipping is effectively disabled.
schedule (LRSchedule, default = LinearLR) – Learning rate schedule.
kwargs (dict, default = {}) – Optimizer keyword arguments. Passed to the optax optimizer.
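Because a learning rate of 0 freezes the corresponding parameters, it can be handy to list which groups a given optimizer section would freeze. The sketch below does this for a plain dictionary; `frozen_groups` is a hypothetical helper for illustration and not an apax function.

```python
def frozen_groups(optimizer_cfg: dict) -> list[str]:
    """Names of the learning-rate fields set to 0 (frozen parameter groups)."""
    return sorted(
        key
        for key, value in optimizer_cfg.items()
        if key.endswith("_lr") and value == 0
    )


optimizer = {
    "name": "adam",
    "emb_lr": 0.001,
    "nn_lr": 0.001,
    "scale_lr": 0.0,  # freeze the elemental output scaling factors
    "shift_lr": 0.0,  # freeze the elemental output shifts
    "zbl_lr": 0.0001,
}

print(frozen_groups(optimizer))  # prints ['scale_lr', 'shift_lr']
```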
- class apax.config.train_config.DataConfig(*, directory: str, experiment: str, dataset: CachedDataset | OTFDataset | PBPDatset = CachedDataset(processing='cached', shuffle_buffer_size=1000), data_path: str | None = None, train_data_path: str | None = None, val_data_path: str | None = None, test_data_path: str | None = None, n_train: Annotated[int, Gt(gt=0)] = 1000, n_valid: Annotated[int, Gt(gt=0)] = 100, batch_size: Annotated[int, Gt(gt=0)] = 32, valid_batch_size: Annotated[int, Gt(gt=0)] = 100, shift_method: str = 'per_element_regression_shift', shift_options: dict = {'energy_regularisation': 1.0}, scale_method: str = 'per_element_force_rms_scale', scale_options: dict | None = {}, pos_unit: str | None = 'Ang', energy_unit: str | None = 'eV')[source]¶
Configuration for data loading, preprocessing and training.
- Parameters:
directory (str, required) – Path to the directory where training results and checkpoints are written.
experiment (str, required) – Model name distinguishing it from others in directory.
data_path (str, required if train_data_path and val_data_path are not specified) – Path to a single dataset file.
train_data_path (str, required if data_path is not specified) – Path to the training dataset.
val_data_path (str, required if data_path is not specified) – Path to the validation dataset.
test_data_path (str, optional) – Path to the test dataset.
n_train (int, default = 1000) – Number of training datapoints from data_path.
n_valid (int, default = 100) – Number of validation datapoints from data_path.
batch_size (int, default = 32) – Number of training examples to be evaluated at once.
valid_batch_size (int, default = 100) – Number of validation examples to be evaluated at once.
shuffle_buffer_size (int, default = 1000) – Size of the tf.data shuffle buffer.
energy_regularisation – Magnitude of the regularization in the per-element energy regression.
pos_unit (str, default = "Ang") – Unit of length.
energy_unit (str, default = "eV") – Unit of energy.
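The rule for the dataset paths (either a single data_path, or both train_data_path and val_data_path) can be sketched as a small check. The function `data_source_mode` is hypothetical, and how apax treats the case where all three paths are given is not stated here, so the sketch simply prefers data_path in that case.

```python
from typing import Optional


def data_source_mode(
    data_path: Optional[str] = None,
    train_data_path: Optional[str] = None,
    val_data_path: Optional[str] = None,
) -> str:
    """Return 'single' or 'split' depending on which paths are given."""
    if data_path is not None:
        # n_train and n_valid datapoints are taken from the one file.
        return "single"
    if train_data_path is not None and val_data_path is not None:
        # Separate files for training and validation data.
        return "split"
    raise ValueError(
        "Specify data_path, or both train_data_path and val_data_path."
    )


print(data_source_mode(data_path="dataset.extxyz"))  # prints single
```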
- class apax.config.train_config.CachedDataset(*, processing: Literal['cached'] = 'cached', shuffle_buffer_size: Annotated[int, Gt(gt=0)] = 1000)[source]¶
Dataset which pads everything (atoms, neighbors) to the largest system in the dataset. The neighbor list (NL) is computed on the fly during the first epoch and stored to disk using tf.data’s cache. Most performant option for datasets with samples of very similar size.
- Parameters:
shuffle_buffer_size (int) – Size of the buffer that is shuffled by tf.data. Larger values require more RAM.
- class apax.config.train_config.PBPDatset(*, processing: Literal['pbp'] = 'pbp', num_workers: Annotated[int, Gt(gt=0)] = 10, atom_padding: Annotated[int, Gt(gt=0)] = 10, nl_padding: Annotated[int, Gt(gt=0)] = 2000)[source]¶
Dataset which pads everything (atoms, neighbors) to the next largest power of two. This limits the compute wasted due to padding at the (negligible) cost of some recompilations. The NL is computed on the fly in parallel for num_workers batches. Does not use tf.data.
Most performant option for datasets with significantly differently sized systems (e.g. MP, SPICE).
- Parameters:
num_workers (int) – Number of batches to be processed in parallel.
atom_padding (int) – Next nearest integer to which to pad per-atom arrays (positions, forces, …).
nl_padding (int) – Next nearest integer to which to pad neighborlists.
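To see why padding to powers of two keeps the number of recompilations small, consider the generic sketch below. It is not apax's implementation (which additionally exposes atom_padding and nl_padding); `next_power_of_two` and the system sizes are made up for illustration.

```python
def next_power_of_two(n: int) -> int:
    """Smallest power of two that is greater than or equal to n."""
    if n < 1:
        raise ValueError("n must be positive")
    return 1 << (n - 1).bit_length()


# Hypothetical numbers of atoms in differently sized systems.
system_sizes = [3, 17, 20, 33, 60]
padded_sizes = [next_power_of_two(n) for n in system_sizes]

print(padded_sizes)               # prints [4, 32, 32, 64, 64]
print(sorted(set(padded_sizes)))  # prints [4, 32, 64]
```

Five different system sizes collapse to three array shapes, so a jit-compiled function would only need to be compiled three times instead of five.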
- class apax.config.train_config.OTFDataset(*, processing: Literal['otf'] = 'otf', shuffle_buffer_size: Annotated[int, Gt(gt=0)] = 1000)[source]¶
Dataset which pads everything (atoms, neighbors) to the largest system in the dataset. The NL is computed on the fly and fed into a tf.data generator. Mostly for internal purposes.
- Parameters:
shuffle_buffer_size (int) – Size of the buffer that is shuffled by tf.data. Larger values require more RAM.
- class apax.config.train_config.MetricsConfig(*, name: str, reductions: List[str])[source]¶
Configuration for the metrics collected during training.
- Parameters:
name (str) – Keyword of the quantity, e.g., ‘energy’.
reductions (List[str]) – List of reductions performed on the difference between target and predictions. Can be ‘mae’, ‘mse’, ‘rmse’ for energies and forces. For forces, ‘angle’ can also be used.
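For reference, the ‘mae’, ‘mse’ and ‘rmse’ reductions correspond to the usual error measures on the difference between predictions and targets. The sketch below spells them out in plain Python; `reduce_errors` is a hypothetical helper rather than the apax implementation, and the ‘angle’ reduction for forces is left out.

```python
import math


def reduce_errors(targets, predictions, reduction: str) -> float:
    """Reduce prediction errors with 'mae', 'mse' or 'rmse'."""
    diffs = [p - t for t, p in zip(targets, predictions)]
    if reduction == "mae":
        return sum(abs(d) for d in diffs) / len(diffs)
    mse = sum(d * d for d in diffs) / len(diffs)
    if reduction == "mse":
        return mse
    if reduction == "rmse":
        return math.sqrt(mse)
    raise ValueError(f"unknown reduction: {reduction}")


targets = [1.0, 2.0, 3.0]
predictions = [1.5, 2.0, 2.0]
print(reduce_errors(targets, predictions, "mae"))  # prints 0.5
```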
- class apax.config.train_config.LossConfig(*, name: str, loss_type: str = 'mse', weight: Annotated[float, Ge(ge=0)] = 1.0, atoms_exponent: Annotated[float, Ge(ge=0)] = 1, parameters: dict = {})[source]¶
Configuration of the loss functions used during training.
- Parameters:
name (str) – Keyword of the quantity, e.g., ‘energy’.
loss_type (str, optional) – Weighting scheme for atomic contributions, by default “mse”. See the MLIP package reference (10.1088/2632-2153/abc9fe) for details.
weight (NonNegativeFloat, optional) – Weighting factor in the overall loss function, by default 1.0.
atoms_exponent (NonNegativeFloat, optional) – Exponent for atomic contributions weighting, by default 1.
parameters (dict, optional) – Additional parameters for configuring the loss function, by default {}.
Notes
This class specifies the configuration of the loss functions used during training.
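One way to picture the interplay of weight and atoms_exponent is a squared error that is divided by the number of atoms raised to atoms_exponent and then scaled by weight. This functional form is an assumption made for illustration only; it is not taken from this page, and `weighted_squared_error` is a hypothetical function.

```python
def weighted_squared_error(
    target: float,
    prediction: float,
    n_atoms: int,
    weight: float = 1.0,
    atoms_exponent: float = 1.0,
) -> float:
    """Squared error scaled by weight / n_atoms**atoms_exponent (assumed form)."""
    return weight * (prediction - target) ** 2 / n_atoms**atoms_exponent


# Energy error of 2.0 for a structure with 4 atoms.
print(weighted_squared_error(10.0, 12.0, n_atoms=4))                    # prints 1.0
print(weighted_squared_error(10.0, 12.0, n_atoms=4, atoms_exponent=2))  # prints 0.25
```

Under this assumption, a larger atoms_exponent reduces the influence of large structures on the total loss.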
- class apax.config.train_config.TrainProgressbarConfig(*, disable_epoch_pbar: bool = False, disable_batch_pbar: bool = True)[source]¶
Configuration of progressbars.
- Parameters:
disable_epoch_pbar (bool, default = False) – Set to True to disable the epoch progress bar.
disable_batch_pbar (bool, default = True) – Set to True to disable the batch progress bar.
- class apax.config.train_config.CSVCallback(*, name: Literal['csv'])[source]¶
Configuration of the CSVCallback.
- Parameters:
name (Literal['csv']) – Keyword of the callback used.
- class apax.config.train_config.TBCallback(*, name: Literal['tensorboard'])[source]¶
Configuration of the TensorBoard callback.
- Parameters:
name (Literal['tensorboard']) – Keyword of the callback used.
- class apax.config.train_config.MLFlowCallback(*, name: Literal['mlflow'], experiment: str)[source]¶
Configuration of the MLFlow callback.
- Parameters:
name (Literal['mlflow']) – Keyword of the callback used.
experiment (str) – Path to the MLFlow experiment, e.g. /Users/<user>/<my_experiment>.
- class apax.config.train_config.KerasPruningCallback(*, name: Literal['pruning'], trial_id: int, study_name: str, study_log_file: str | Path, interval: int = 1, monitor: str = 'val_loss')[source]¶
Configuration for the pruning callback used during a trial in an Optuna hyperparameter optimization.
- Parameters:
name (Literal['pruning']) – Keyword of the callback used.
trial_id (int) – id of the trial in the study
study_name (str) – name of the study
study_log_file (str | Path) – path to the study log file
interval (int, default = 1) – interval to check whether the trial should be pruned
monitor (str, default = "val_loss") – metric key to monitor to determine pruning
- class apax.config.train_config.TransferLearningConfig(*, base_model_checkpoint: str | None = None, reset_layers: List[str] = [], freeze_layers: List[str] = [])[source]¶
Configuration for transfer learning from a pre-trained model checkpoint.
- Parameters:
base_model_checkpoint (str, optional) – Path to the folder containing a pre-trained model ckpt.
reset_layers (List[str], default = []) – List of layer names for which the parameters will be reinitialized.
freeze_layers (List[str], default = []) – List of layer names for which the parameters will be frozen during training.
- class apax.config.train_config.WeightAverage(*, ema_start: int = 0, alpha: float = 0.9)[source]¶
Applies an exponential moving average to model parameters.
- Parameters:
ema_start (int, default = 0) – Epoch at which to start averaging models.
alpha (float, default = 0.9) – How much of the new model to use. 1.0 would mean no averaging, 0.0 no updates.
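The description of alpha matches the standard exponential moving average update, sketched below on a flat list of parameters. `ema_update` is a hypothetical helper, not apax code, and the handling of ema_start is not modeled here.

```python
def ema_update(averaged, new, alpha: float = 0.9):
    """Blend new parameters into the running average.

    alpha = 1.0 returns the new parameters (no averaging),
    alpha = 0.0 keeps the old average (no updates).
    """
    return [alpha * n + (1.0 - alpha) * a for a, n in zip(averaged, new)]


averaged = [0.0, 2.0]
new = [1.0, 4.0]
print(ema_update(averaged, new, alpha=1.0))  # prints [1.0, 4.0]
print(ema_update(averaged, new, alpha=0.0))  # prints [0.0, 2.0]
```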
Model Configuration
- class apax.config.model_config.GaussianBasisConfig(*, name: Literal['gaussian'] = 'gaussian', n_basis: Annotated[int, Gt(gt=0)] = 7, r_min: Annotated[float, Ge(ge=0)] = 0.5, r_max: Annotated[float, Gt(gt=0)] = 6.0, spacing: Literal['linear', 'exponential'] = 'linear')[source]¶
Gaussian primitive basis functions.
- Parameters:
n_basis (PositiveInt, default = 7) – Number of uncontracted basis functions.
r_min (NonNegativeFloat, default = 0.5) – Position of the first uncontracted basis function’s mean.
r_max (PositiveFloat, default = 6.0) – Cutoff radius of the descriptor.
spacing (Literal['linear', 'exponential'], default = 'linear') – Spacing of the centers of the Gaussians. “exponential” results in more basis functions close to r_min and fewer near r_max. See https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181, Figure 2.
- class apax.config.model_config.BesselBasisConfig(*, name: Literal['bessel'] = 'bessel', n_basis: Annotated[int, Gt(gt=0)] = 16, r_max: Annotated[float, Gt(gt=0)] = 5.0)[source]¶
Bessel primitive basis functions.
- Parameters:
n_basis (PositiveInt, default = 16) – Number of uncontracted basis functions.
r_max (PositiveFloat, default = 5.0) – Cutoff radius of the descriptor.
- class apax.config.model_config.FullEnsembleConfig(*, kind: Literal['full'] = 'full', n_members: int)[source]¶
Configuration for full model ensembles. Usage can improve accuracy and stability at the cost of slower inference. Uncertainties will generally not be calibrated.
- Parameters:
n_members (int) – Number of ensemble members.
- class apax.config.model_config.ShallowEnsembleConfig(*, kind: Literal['shallow'] = 'shallow', n_members: int, force_variance: bool = True, chunk_size: int | None = None)[source]¶
Configuration for shallow (last layer) ensembles. Allows use of probabilistic loss functions. The predicted uncertainties should be well calibrated. See 10.1088/2632-2153/ad594a for details.
- Parameters:
n_members (int) – Number of ensemble members.
force_variance (bool, default = True) – Whether or not to compute force uncertainties. Required for probabilistic force loss and calibration of force uncertainties. Can lead to better force metrics, but enabling it introduces some non-negligible cost.
chunk_size (Optional[int], default = None) – If set to an integer, the Jacobian of ensemble energies with respect to positions will be computed in chunks of that size. This sacrifices some performance for the possibility to use relatively large ensemble sizes.
Hint
Loss type has to be changed to a probabilistic loss like ‘nll’ or ‘crps’.
- class apax.config.model_config.ZBLRepulsion(*, name: Literal['zbl'], r_max: Annotated[float, Ge(ge=0)] = 1.5)[source]¶
- class apax.config.model_config.ExponentialRepulsion(*, name: Literal['exponential'], r_max: Annotated[float, Ge(ge=0)] = 1.5)[source]¶
- class apax.config.model_config.LatentEwald(*, name: Literal['latent_ewald'], kgrid: list, sigma: float = 1.0, use_property: str = 'charges')[source]¶
- class apax.config.model_config.PropertyHead(*, name: str, aggregation: str = 'none', mode: str = 'l0', nn: List[Annotated[int, Gt(gt=0)]] = [128, 128], n_shallow_members: int = 0, w_init: Literal['normal', 'lecun'] = 'lecun', b_init: Literal['normal', 'zeros'] = 'zeros', use_ntk: bool = False, dtype: Literal['fp32', 'fp64'] = 'fp32')[source]¶
- class apax.config.model_config.GMNNConfig(*, basis: GaussianBasisConfig | BesselBasisConfig = BesselBasisConfig(name='bessel', n_basis=16, r_max=5.0), nn: List[Annotated[int, Gt(gt=0)]] = [256, 256], w_init: Literal['normal', 'lecun'] = 'lecun', b_init: Literal['normal', 'zeros'] = 'zeros', activation_fn: str = 'variance_preserving_swish', use_ntk: bool = False, ensemble: FullEnsembleConfig | ShallowEnsembleConfig | None = None, property_heads: list[PropertyHead] = [], empirical_corrections: list[ZBLRepulsion | ExponentialRepulsion | LatentEwald] = [], calc_stress: bool = False, descriptor_dtype: Literal['fp32', 'fp64'] = 'fp32', readout_dtype: Literal['fp32', 'fp64'] = 'fp32', scale_shift_dtype: Literal['fp32', 'fp64'] = 'fp64', name: Literal['gmnn'] = 'gmnn', n_radial: Annotated[int, Gt(gt=0)] = 5, n_contr: int = 8, emb_init: str | None = 'uniform')[source]¶
Configuration for the model.
- Parameters:
n_radial (PositiveInt, default = 5) – Number of contracted basis functions.
n_contr (int, default = 8) – How many Gaussian moment contractions to use.
emb_init (Optional[str], default = "uniform") – Initialization scheme for embedding layer weights.
- class apax.config.model_config.EquivMPConfig(*, basis: GaussianBasisConfig | BesselBasisConfig = BesselBasisConfig(name='bessel', n_basis=16, r_max=5.0), nn: List[Annotated[int, Gt(gt=0)]] = [256, 256], w_init: Literal['normal', 'lecun'] = 'lecun', b_init: Literal['normal', 'zeros'] = 'zeros', activation_fn: str = 'variance_preserving_swish', use_ntk: bool = False, ensemble: FullEnsembleConfig | ShallowEnsembleConfig | None = None, property_heads: list[PropertyHead] = [], empirical_corrections: list[ZBLRepulsion | ExponentialRepulsion | LatentEwald] = [], calc_stress: bool = False, descriptor_dtype: Literal['fp32', 'fp64'] = 'fp32', readout_dtype: Literal['fp32', 'fp64'] = 'fp32', scale_shift_dtype: Literal['fp32', 'fp64'] = 'fp64', name: Literal['equiv-mp'] = 'equiv-mp', features: Annotated[int, Gt(gt=0)] = 32, max_degree: Annotated[int, Gt(gt=0)] = 2, num_iterations: Annotated[int, Gt(gt=0)] = 1)[source]¶
Configuration for the model.
- Parameters:
features (PositiveInt, default = 32) – Feature dimension of the linear layers.
max_degree (PositiveInt, default = 2) – Maximal rotation order for features and tensor products.
num_iterations (PositiveInt, default = 1) – Number of message passing steps.
- class apax.config.model_config.So3kratesConfig(*, basis: GaussianBasisConfig | BesselBasisConfig = BesselBasisConfig(name='bessel', n_basis=16, r_max=5.0), nn: List[Annotated[int, Gt(gt=0)]] = [256, 256], w_init: Literal['normal', 'lecun'] = 'lecun', b_init: Literal['normal', 'zeros'] = 'zeros', activation_fn: str = 'variance_preserving_swish', use_ntk: bool = False, ensemble: FullEnsembleConfig | ShallowEnsembleConfig | None = None, property_heads: list[PropertyHead] = [], empirical_corrections: list[ZBLRepulsion | ExponentialRepulsion | LatentEwald] = [], calc_stress: bool = False, descriptor_dtype: Literal['fp32', 'fp64'] = 'fp32', readout_dtype: Literal['fp32', 'fp64'] = 'fp32', scale_shift_dtype: Literal['fp32', 'fp64'] = 'fp64', name: Literal['so3krates'] = 'so3krates', num_layers: Annotated[int, Gt(gt=0)] = 1, max_degree: Annotated[int, Gt(gt=0)] = 3, num_features: Annotated[int, Gt(gt=0)] = 128, num_heads: Annotated[int, Gt(gt=0)] = 4, use_layer_norm_1: bool = False, use_layer_norm_2: bool = False, use_layer_norm_final: bool = False, activation: str = 'silu', cutoff_fn: str = 'cosine_cutoff', transform_input_features: bool = False)[source]¶
Configuration for the model.
- Parameters:
num_layers (PositiveInt, default = 1) – Number of message passing layers.
max_degree (PositiveInt, default = 3) – Maximum rotation order.
num_features (PositiveInt, default = 128) – Feature dimension.
num_heads (PositiveInt, default = 4) – Number of attention heads.
use_layer_norm_1 (bool, default = False) – Layer norm in transformer block.
use_layer_norm_2 (bool, default = False) – Layer norm in transformer block.
use_layer_norm_final (bool, default = False) – Layer norm before readout.
activation (str, default = "silu") – Activation function.
cutoff_fn (str, default = "cosine_cutoff") – Smooth cutoff function.
transform_input_features (bool, default = False) – Whether or not to apply a dense layer to transformer input features.
Molecular Dynamics Configuration
- class apax.config.md_config.MDConfig(*, seed: int = 1, ensemble: NVEOptions | NVTOptions | NPTOptions = NVTOptions(name='nvt', dt=0.5, temperature_schedule=ConstantTempSchedule(name='constant', T0=298.15), thermostat_chain=NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=100)), duration: Annotated[float, Gt(gt=0)], n_inner: Annotated[int, Gt(gt=0)] = 500, sampling_rate: Annotated[int, Gt(gt=0)] = 10, buffer_size: Annotated[int, Gt(gt=0)] = 2500, dr_threshold: Annotated[float, Gt(gt=0)] = 0.5, extra_capacity: Annotated[int, Ge(ge=0)] = 0, disable_cell_list: bool = False, biases: list[Annotated[SphericalWallEnergy, FieldInfo(annotation=NoneType, required=True, discriminator='name')]] = [], dynamics_checks: list[Annotated[EnergyUncertaintyCheck | ForcesUncertaintyCheck | ReflectionCheck, FieldInfo(annotation=NoneType, required=True, discriminator='name')]] = [], constraints: list[Annotated[FixAtomsConstraint | FixCenterOfMassConstraint | FixLayerConstraint, FieldInfo(annotation=NoneType, required=True, discriminator='name')]] = [], properties: list[str] = ['energy', 'forces', 'stress', 'forces_uncertainty', 'energy_uncertainty', 'stress_uncertainty', 'energy_ensemble', 'forces_ensemble', 'stress_ensemble', 'energy_unbiased', 'forces_unbiased', 'charge', 'charges'], h5md_options: H5MDOptions = H5MDOptions(compression='gzip', compression_opts=4, store='time', author='N/A', author_email='N/A'), initial_structure: str, load_momenta: bool = False, sim_dir: str = '.', traj_name: str = 'md.h5', restart: bool = True, wrapped: bool = True, checkpoint_interval: int = 50000, disable_pbar: bool = False)[source]¶
Configuration for an NHC molecular dynamics simulation. Full config here:
- Parameters:
seed (int, default = 1) – Random seed for momentum initialization.
ensemble – Options for integrating the EoM and which ensemble to use.
dt (float, default = 0.5) – Time step in fs.
duration (float, required) – Total simulation time in fs.
n_inner (int, default = 500) – Number of compiled simulation steps (i.e. number of iterations of the jax.lax.fori_loop loop). Also determines atoms buffer size.
sampling_rate (int, default = 10) – Interval between saving frames.
buffer_size (int, default = 2500) – Number of collected frames to be dumped at once.
dr_threshold (float, default = 0.5) – Skin of the neighborlist.
extra_capacity (int, default = 0) – JaxMD allocates a maximal number of neighbors. This argument lets you add additional capacity to avoid recompilation. The default is usually fine.
biases (list[BiasEnergy]) – List of bias energies. Currently a spherical wall potential is available.
dynamics_checks (list[DynamicsCheck]) – List of termination criteria. Currently energy and force uncertainty are available.
constraints (list[Constraint]) – List of constraints. Currently …
properties (list[str]) – Whitelist of properties to be saved in the trajectory. This does not affect what the model will calculate, e.g., an ensemble will still calculate uncertainties.
initial_structure (str, required) – Path to the starting structure of the simulation.
sim_dir (str, default = ".") – Directory where the simulation file will be stored.
traj_name (str, default = "md.h5") – Name of the trajectory file.
restart (bool, default = True) – Whether the simulation should restart from the latest configuration in traj_name.
wrapped (bool, default = True) – Whether the atoms in the simulation should be wrapped back into the box.
checkpoint_interval (int, default = 50_000) – Number of time steps between saving full simulation state checkpoints. These will be loaded with the restart option.
disable_pbar (bool, default = False) – Disables the MD progressbar.
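The relation between duration, dt, n_inner and sampling_rate can be made concrete with a little arithmetic. The sketch below assumes that sampling_rate is counted in time steps and that step counts are rounded as shown; both are assumptions, and `md_step_counts` is a hypothetical helper rather than part of apax.

```python
import math


def md_step_counts(
    duration: float,
    dt: float = 0.5,
    n_inner: int = 500,
    sampling_rate: int = 10,
) -> tuple[int, int, int]:
    """Return (total steps, outer iterations, saved frames) for an MD run."""
    n_steps = round(duration / dt)           # duration and dt are both in fs
    n_outer = math.ceil(n_steps / n_inner)   # compiled blocks of n_inner steps
    n_frames = n_steps // sampling_rate      # one frame every sampling_rate steps
    return n_steps, n_outer, n_frames


# A 1 ps (1000 fs) run with the default settings.
print(md_step_counts(1000.0))  # prints (2000, 4, 200)
```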
- class apax.config.md_config.NPTOptions(*, name: Literal['npt'], dt: Annotated[float, Gt(gt=0)] = 0.5, temperature_schedule: ConstantTempSchedule | PiecewiseLinearTempSchedule | OscillatingRampTempSchedule = ConstantTempSchedule(name='constant', T0=298.15), thermostat_chain: NHCOptions = NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=100), pressure: Annotated[float, Gt(gt=0)] = 1.01325, barostat_chain: NHCOptions = NHCOptions(chain_length=3, chain_steps=2, sy_steps=3, tau=1000.0))[source]¶
Options for NPT ensemble simulations.
- Parameters:
name (Literal["npt"]) – Name of the ensemble.
pressure (PositiveFloat, default = 1.01325) – Pressure in bar.
barostat_chain (NHCOptions, default = NHCOptions(tau=1000)) – Barostat chain options.
Hyperparameter Optimization Configuration
- class apax.config.optuna_config.OptunaPrunerConfig(*, name: str, interval: Annotated[int, Gt(gt=0)] = 1, kwargs: dict[str, Any] = {})[source]¶
Config for optuna pruner.
- name¶
Name of pruner
- Type:
str
- interval¶
Interval to check whether a trial should be pruned
- Type:
PositiveInt
- kwargs¶
kwargs passed to pruner
- Type:
dict[str, Any]
- class apax.config.optuna_config.OptunaSamplerConfig(*, name: str = 'AutoSampler', kwargs: dict[str, Any] = {})[source]¶
Config for optuna sampler.
- name¶
Name of sampler. Default: “AutoSampler”
- Type:
str
- kwargs¶
kwargs passed to sampler
- Type:
dict[str, Any]
- class apax.config.optuna_config.OptunaConfig(*, n_trials: Annotated[int, Gt(gt=0)], search_space: dict[str, dict[str, Any]], seed: int = 1, monitor: str = 'val_loss', study_name: str = 'study', study_log_file: str | Path = 'study.log', sampler_config: OptunaSamplerConfig = OptunaSamplerConfig(name='AutoSampler', kwargs={}), pruner_config: OptunaPrunerConfig | None = None)[source]¶
Configuration of optuna study.
- n_trials¶
number of trials
- Type:
PositiveInt
- search_space¶
dictionary indicating the search space to use for the hyperparameter optimization. Keys should be keys in the training configuration, and if the keys are nested, they should be prefixed with the parent key, and an underscore, see the example below.
- Type:
dict[str, dict[str, Any]]
- seed¶
seed to use for sampling
- Type:
int
- monitor¶
metric to monitor. Default: “val_loss”
- Type:
str
- study_name¶
name of study. Default: “study”
- Type:
str
- study_log_file¶
path to study log file. Default: “study.log”
- Type:
str
- sampler_config¶
sampler configuration
- Type:
OptunaSamplerConfig
- pruner_config¶
pruner configuration. Default: None
- Type:
Optional[OptunaPrunerConfig]
Example
For a search space with varying radius for the environment between 3 and 8 Angstrom, the number of tensor contractions between 5 and 8, and number of epochs between 100 and 200:
search_space = {
    'model_basis_r_max': {'type': 'float', 'low': 3, 'high': 8},
    'model_n_contr': {'type': 'int', 'low': 5, 'high': 8},
    'n_epochs': {'type': 'int', 'low': 100, 'high': 200},
}
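Since config keys themselves contain underscores (r_max, n_contr), mapping a prefixed key such as model_basis_r_max back to its nested location requires matching against the keys that actually exist. The sketch below shows one possible way to do this on a plain dictionary; `set_by_flat_key` is hypothetical, and the way apax resolves these keys internally may differ.

```python
def set_by_flat_key(config: dict, flat_key: str, value) -> None:
    """Set a nested entry addressed as parent_child_..._name."""
    node, remaining = config, flat_key
    while remaining not in node:
        # Descend into the child dict whose name prefixes the remaining key.
        for name, child in node.items():
            if isinstance(child, dict) and remaining.startswith(name + "_"):
                node, remaining = child, remaining[len(name) + 1:]
                break
        else:
            raise KeyError(flat_key)
    node[remaining] = value


# Minimal stand-in for a training configuration.
config = {"n_epochs": 100, "model": {"basis": {"r_max": 5.0}, "n_contr": 8}}

set_by_flat_key(config, "model_basis_r_max", 4.2)
set_by_flat_key(config, "model_n_contr", 6)
set_by_flat_key(config, "n_epochs", 150)

print(config)
# prints {'n_epochs': 150, 'model': {'basis': {'r_max': 4.2}, 'n_contr': 6}}
```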