Training
- class apax.train.callbacks.CSVLoggerApax(filename, separator=',', append=False)
- on_epoch_end(epoch, logs=None)
Called at the end of an epoch.
Subclasses should override for any actions to run. This function should only be called during TRAIN mode.
- Parameters:
epoch – Integer, index of epoch.
logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For training epoch, the values of the Model’s metrics are returned. Example: {‘loss’: 0.2, ‘accuracy’: 0.7}.
- on_test_batch_end(batch, logs=None)
Called at the end of a batch in evaluate methods.
Also called at the end of a validation batch in the fit methods, if validation data is provided.
Subclasses should override for any actions to run.
Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.
- Parameters:
batch – Integer, index of batch within the current epoch.
logs – Dict. Aggregated metric results up until this batch.
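The per-epoch hook can be pictured as turning the `logs` dict into one CSV row. The sketch below is a hypothetical plain-Python stand-in (`write_epoch_row` is not part of apax or Keras) that only shows how `epoch`, `logs`, and `separator` relate; it ignores `append`, and CSVLoggerApax's actual behaviour may differ.

```python
import csv
import io


def write_epoch_row(stream, epoch, logs, separator=",", write_header=False):
    """Hypothetical helper: write one CSV row for a finished epoch.

    Columns are 'epoch' followed by the sorted metric names in `logs`,
    e.g. 'loss' and 'val_loss' when validation is performed.
    """
    logs = logs or {}
    keys = sorted(logs)
    writer = csv.writer(stream, delimiter=separator)
    if write_header:
        writer.writerow(["epoch"] + keys)
    writer.writerow([epoch] + [logs[k] for k in keys])


buffer = io.StringIO()
write_epoch_row(buffer, 0, {"loss": 0.2, "val_loss": 0.3}, write_header=True)
write_epoch_row(buffer, 1, {"loss": 0.1, "val_loss": 0.25})
print(buffer.getvalue())
```

Sorting the keys keeps the column order stable across epochs, which matters because a CSV header is written only once.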
- class apax.train.callbacks.KerasPruningCallback(study_name: str, trial_id: int, study_log_file: str | Path, monitor: str = 'val_loss', interval: int = 1)
Keras callback to prune unpromising trials, adapted from https://optuna-integration.readthedocs.io/en/latest/_modules/optuna_integration/keras/keras.html#KerasPruningCallback
See the example if you want to add a pruning callback which observes validation accuracy.
- Parameters:
study_name – Name of the study.
trial_id – ID of the current trial.
study_log_file – Path to the study log file.
monitor – An evaluation metric for pruning, e.g., val_loss or val_accuracy. Please refer to the keras.Callback reference for further details.
interval – Check whether the trial should be pruned every n-th epoch. By default, interval=1 and pruning is performed after every epoch. Increase interval to run several epochs faster before applying pruning.
- on_epoch_end(epoch: int, logs: dict[str, float] | None = None) → None
Called at the end of an epoch.
Subclasses should override for any actions to run. This function should only be called during TRAIN mode.
- Parameters:
epoch – Integer, index of epoch.
logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For training epoch, the values of the Model’s metrics are returned. Example: {‘loss’: 0.2, ‘accuracy’: 0.7}.
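The interaction of `monitor` and `interval` can be sketched as follows. `monitored_value` is hypothetical; it follows the pattern of Optuna's Keras pruning callback (check only every `interval`-th epoch counted from 1, warn and skip if the metric is missing). In the real callback the returned score would be reported to the Optuna trial, which then decides whether to prune; how apax's adaptation uses `study_log_file` is not shown here.

```python
import warnings
from typing import Optional


def monitored_value(epoch: int, logs: Optional[dict], monitor: str = "val_loss",
                    interval: int = 1) -> Optional[float]:
    """Return the score to report for this epoch, or None to skip the check."""
    # Only every `interval`-th epoch (1-based count) is considered for pruning.
    if (epoch + 1) % interval != 0:
        return None
    logs = logs or {}
    score = logs.get(monitor)
    if score is None:
        # A missing metric (e.g. no validation data) skips pruning instead of failing.
        warnings.warn(f"Metric {monitor!r} not found in logs; skipping pruning check.")
        return None
    return float(score)


history = [{"loss": 0.5, "val_loss": 0.6}, {"loss": 0.4, "val_loss": 0.55},
           {"loss": 0.35, "val_loss": 0.5}, {"loss": 0.3, "val_loss": 0.48}]
for epoch, logs in enumerate(history):
    print(epoch, monitored_value(epoch, logs, interval=2))
```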
- apax.train.checkpoints.canonicalize_energy_grad_model_parameters(params)
Ensures that parameters from EnergyModels can be loaded into EnergyDerivativeModels by adding the “energy_model” parameter layer.
- apax.train.checkpoints.canonicalize_energy_model_parameters(params)
Ensures that parameters from EnergyDerivativeModels can be loaded into EnergyModels by removing the “energy_model” parameter layer.
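The two canonicalization directions can be sketched on plain nested dicts. This assumes a Flax-style top-level "params" collection and uses ordinary dicts instead of FrozenDict; the helper names and layer names (`dense_0`) are hypothetical. Making both helpers no-ops when the tree is already in the target form is a design choice of the sketch, not something the reference states.

```python
from typing import Any, Dict


def add_energy_model_layer(params: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap the model parameters in an 'energy_model' sub-tree (no-op if present)."""
    inner = params["params"]
    if "energy_model" in inner:
        return params
    return {**params, "params": {"energy_model": inner}}


def remove_energy_model_layer(params: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap the 'energy_model' sub-tree (no-op if absent)."""
    inner = params["params"]
    if "energy_model" not in inner:
        return params
    return {**params, "params": inner["energy_model"]}


energy_params = {"params": {"dense_0": {"kernel": [[1.0]], "bias": [0.0]}}}
grad_params = add_energy_model_layer(energy_params)
print(grad_params)
```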
- apax.train.checkpoints.check_for_ensemble(params: FrozenDict) → int
Checks whether a set of parameters belongs to an ensemble model. This is the case if all parameters share the same first dimension (the parameter batch).
- apax.train.checkpoints.restore_parameters(model_dir: Path | List[Path]) → Tuple[Config, FrozenDict]
Restores one or more model configs and parameters. Parameters are stacked for ensembling.
- apax.train.checkpoints.restore_single_parameters(model_dir: Path) → Tuple[Config, FrozenDict]
Loads the config and parameters of a single model.
- apax.train.checkpoints.stack_parameters(param_list: List[FrozenDict]) → FrozenDict
Combines a list of parameter sets into a stacked version. Used for model ensembles.
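Stacking and ensemble detection go together: stacking adds a leading "parameter batch" axis to every leaf, and that shared leading axis is what the ensemble check looks for. The sketch below uses NumPy and plain dicts; `stack_params`, `leaves`, and `ensemble_size` are hypothetical, and returning 1 for a non-ensemble is an assumption, since the reference does not say what check_for_ensemble returns in that case.

```python
from typing import Any, Dict, List

import numpy as np


def stack_params(param_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Stack matching leaves of several nested parameter dicts along a new axis 0."""
    first = param_list[0]
    if isinstance(first, dict):
        return {k: stack_params([p[k] for p in param_list]) for k in first}
    return np.stack([np.asarray(p) for p in param_list])


def leaves(tree):
    """Yield all array leaves of a nested dict."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from leaves(value)
    else:
        yield tree


def ensemble_size(params: Dict[str, Any]) -> int:
    """Return the shared leading dimension if all leaves agree, else 1."""
    leading = {np.shape(leaf)[:1] for leaf in leaves(params)}
    if len(leading) == 1:
        (dims,) = leading
        if dims:
            return dims[0]
    return 1


model_a = {"dense": {"kernel": np.ones((3, 4)), "bias": np.zeros(4)}}
model_b = {"dense": {"kernel": 2 * np.ones((3, 4)), "bias": np.ones(4)}}
stacked = stack_params([model_a, model_b])
print(stacked["dense"]["kernel"].shape, ensemble_size(stacked), ensemble_size(model_a))
```

Like the documented criterion, this heuristic would misclassify a single model whose leaves happen to share their first dimension.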
- apax.train.eval.eval_model(config_path, n_test=-1, log_file='eval.log', log_level='error')
Evaluate the model using the provided configuration.
- Parameters:
config_path (str) – Path to the configuration file.
n_test (int, default = -1) – Number of test samples to evaluate, by default -1 (evaluate all).
log_file (str, default = "eval.log") – Path to the log file.
log_level (str, default = "error") – Logging level.
- apax.train.eval.load_test_data(config, model_version_path, eval_path, n_test=-1)
Load test data for evaluation.
- Parameters:
config (object) – Configuration object.
model_version_path (str) – Path to the model version.
eval_path (str) – Path to the evaluation directory.
n_test (int, default = -1) – Number of test samples to load, by default -1 (load all).
- Returns:
atoms_list – List of ase.Atoms containing the test data.
- Return type:
List[ase.Atoms]
- apax.train.eval.predict(model, params, Metrics, loss_fn, test_ds, callbacks, is_ensemble=False)
Perform predictions on the test dataset.
- Parameters:
model – Trained model.
params – Model parameters.
Metrics – Collection of metrics.
loss_fn – Loss function.
test_ds – Test dataset.
callbacks – Callback functions.
is_ensemble (bool, default = False) – Whether the model is an ensemble.
- class apax.train.loss.Loss(name: str, loss_type: str, weight: float = 1.0, atoms_exponent: float = 1.0, parameters: dict = <factory>)
Represents a single weighted loss function that is constructed from a name and a type of comparison metric.
- class apax.train.loss.LossCollection(loss_list: List[apax.train.loss.Loss])
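How a named, weighted loss term and a collection of such terms fit together can be sketched as below. `SimpleLoss` and `SimpleLossCollection` are hypothetical stand-ins: here the comparison function is passed directly, whereas apax's Loss is built from a `loss_type` string, and the real classes may differ in how they reduce and combine terms. Following the reference, contributions are summed per term.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

import numpy as np


@dataclass
class SimpleLoss:
    """Hypothetical weighted loss term for one predicted quantity."""
    name: str  # key into the label/prediction dicts, e.g. "energy"
    fn: Callable[[np.ndarray, np.ndarray], np.ndarray]
    weight: float = 1.0

    def __call__(self, label: Dict[str, np.ndarray], prediction: Dict[str, np.ndarray]) -> float:
        # Element-wise contributions are summed, then scaled by the term's weight.
        return self.weight * float(np.sum(self.fn(label[self.name], prediction[self.name])))


@dataclass
class SimpleLossCollection:
    loss_list: List[SimpleLoss]

    def __call__(self, label, prediction) -> float:
        # The total loss is the sum of all weighted terms.
        return sum(loss(label, prediction) for loss in self.loss_list)


def squared_error(y, y_hat):
    return (y - y_hat) ** 2


losses = SimpleLossCollection([
    SimpleLoss("energy", squared_error, weight=1.0),
    SimpleLoss("forces", squared_error, weight=4.0),
])
label = {"energy": np.array([1.0, 2.0]), "forces": np.zeros((2, 3, 3))}
prediction = {"energy": np.array([1.5, 2.0]), "forces": np.full((2, 3, 3), 0.1)}
total = losses(label, prediction)
print(total)
```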
- apax.train.loss.crps_loss(label: Array, prediction: Array, name, parameters: dict = {}) → Array
Computes the continuous ranked probability score (CRPS) of a Gaussian distribution given means, targets, and standard deviations (uncertainty estimates).
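For a Gaussian predictive distribution N(μ, σ²) and an observed target y, the CRPS has the standard closed form σ [z (2Φ(z) − 1) + 2φ(z) − 1/√π] with z = (y − μ)/σ, where φ and Φ are the standard normal density and CDF. The scalar sketch below evaluates that formula with the stdlib; whether crps_loss uses exactly this expression, and which `parameters` it reads, is not stated in the reference.

```python
import math


def gaussian_crps(y: float, mu: float, sigma: float) -> float:
    """Closed-form CRPS of N(mu, sigma^2) evaluated at observation y."""
    z = (y - mu) / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal density
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))          # standard normal CDF
    return sigma * (z * (2.0 * cdf - 1.0) + 2.0 * pdf - 1.0 / math.sqrt(math.pi))


print(gaussian_crps(1.0, 0.0, 0.5))
```

Unlike the NLL, the CRPS grows only linearly in |y − μ| for large errors, which makes it less sensitive to overconfident outliers.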
- apax.train.loss.force_angle_div_force_label(label: array, prediction: array, name, parameters: dict = {})
Cosine similarity loss function weighted by the norm of the force labels. Contributions are summed in Loss.
- apax.train.loss.force_angle_exponential_weight(label: array, prediction: array, name, parameters: dict = {}) → array
Cosine similarity loss function exponentially scaled by the norm of the force labels. Contributions are summed in Loss.
- apax.train.loss.force_angle_loss(label: array, prediction: array, name, parameters: dict = {}) → array
Cosine similarity loss function. Contributions are summed in Loss.
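A cosine-similarity force loss penalizes the angle between predicted and reference force vectors independently of their magnitudes. The sketch assumes the per-atom term is 1 − cos θ and adds a small `eps` against zero-length vectors; both choices are assumptions, and the label-norm weightings used by the two weighted variants are not specified in the reference, so they are omitted.

```python
import numpy as np


def force_angle_loss_sketch(label: np.ndarray, prediction: np.ndarray,
                            eps: float = 1e-8) -> np.ndarray:
    """Per-atom (1 - cos(theta)) between label and predicted force vectors.

    Both arrays have shape (..., 3); the result has shape (...,).
    """
    dot = np.sum(label * prediction, axis=-1)
    norms = np.linalg.norm(label, axis=-1) * np.linalg.norm(prediction, axis=-1)
    return 1.0 - dot / (norms + eps)


label = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
prediction = np.array([[2.0, 0.0, 0.0], [1.0, 0.0, 0.0]])  # parallel, then orthogonal
print(force_angle_loss_sketch(label, prediction))
```

The term ranges from 0 (same direction) through 1 (orthogonal) to 2 (opposite directions).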
- apax.train.loss.nll_loss(label: Array, prediction: Array, name, parameters: dict = {}) → Array
Computes the Gaussian negative log-likelihood (NLL) loss given means, targets, and standard deviations (uncertainty estimates).
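For a target y and a predicted Gaussian N(μ, σ²), the NLL is log σ + (y − μ)²/(2σ²) + ½ log(2π). The scalar sketch below makes the constant optional because it does not affect gradients; whether nll_loss keeps it is an assumption not settled by the reference.

```python
import math


def gaussian_nll(y: float, mu: float, sigma: float, include_constant: bool = True) -> float:
    """Negative log-likelihood of y under N(mu, sigma^2)."""
    nll = math.log(sigma) + 0.5 * ((y - mu) / sigma) ** 2
    if include_constant:
        nll += 0.5 * math.log(2.0 * math.pi)
    return nll


print(gaussian_nll(0.3, 0.0, 0.5))
```

The log σ term keeps the model from inflating its uncertainty indefinitely, while the squared-error term divided by σ² penalizes overconfident wrong predictions.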
- apax.train.loss.weighted_huber_loss(label: array, prediction: array, name, parameters: dict = {}) → array
Huber loss function that allows weighting of individual contributions by the number of atoms in the system.
- apax.train.loss.weighted_squared_error(label: array, prediction: array, name, parameters: dict = {}) → array
Squared error function that allows weighting of individual contributions by the number of atoms in the system.
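One way to weight per-structure errors by system size is to divide by n_atoms raised to a power, which would match the `atoms_exponent` field of Loss. The sketch assumes exactly that form and a Huber threshold `delta`; both the formula and the parameter names are assumptions, not confirmed details of weighted_squared_error or weighted_huber_loss.

```python
import numpy as np


def atom_weighted_squared_error(label, prediction, n_atoms, atoms_exponent=1.0):
    """Squared error per structure, divided by n_atoms ** atoms_exponent."""
    return (label - prediction) ** 2 / n_atoms ** atoms_exponent


def atom_weighted_huber(label, prediction, n_atoms, atoms_exponent=1.0, delta=1.0):
    """Huber loss per structure with the same atom-count weighting."""
    residual = np.abs(label - prediction)
    quadratic = 0.5 * residual ** 2             # used where residual <= delta
    linear = delta * (residual - 0.5 * delta)   # used where residual > delta
    huber = np.where(residual <= delta, quadratic, linear)
    return huber / n_atoms ** atoms_exponent


label = np.array([10.0, 20.0])       # e.g. total energies of two structures
prediction = np.array([12.0, 20.5])
n_atoms = np.array([2.0, 4.0])
print(atom_weighted_squared_error(label, prediction, n_atoms))
print(atom_weighted_huber(label, prediction, n_atoms))
```

Raising `atoms_exponent` shifts emphasis toward small systems, since errors in extensive quantities such as total energy tend to grow with system size.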
- class apax.train.metrics.RootAverage(total: Array, count: Array)
Modifies the compute method of metrics.Average to obtain the root of the average. Meant to be used with mse_fn.
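The point of taking the root inside compute is that `total` and `count` are accumulated over all batches first, so the result is the RMSE over the whole epoch rather than an average of per-batch RMSEs. The dataclass below is a hypothetical stdlib stand-in for that behaviour, not clu's or apax's class.

```python
import math
from dataclasses import dataclass


@dataclass
class RootAverageSketch:
    """Running (total, count) pair whose compute() returns sqrt(total / count)."""
    total: float = 0.0
    count: int = 0

    def update(self, values):
        # Accumulate per-sample squared errors, as an MSE-style metric would produce them.
        self.total += sum(values)
        self.count += len(values)

    def compute(self) -> float:
        # The root is taken only once, after all batches have been merged.
        return math.sqrt(self.total / self.count)


metric = RootAverageSketch()
metric.update([1.0, 1.0])  # squared errors of batch 1
metric.update([4.0, 4.0])  # squared errors of batch 2
print(metric.compute())
```

Averaging per-batch roots instead would give (1 + 2)/2 = 1.5, which differs from the epoch RMSE √2.5.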
- apax.train.metrics.cosine_sim(inputs: dict, label: dict[array], prediction: dict[array], key: str) → array
Computes the cosine similarity of two arrays.
- apax.train.metrics.initialize_metrics(metrics_list) → Collection
Builds a clu metrics Collection by looping over all keys and reductions. The metrics are named according to the pattern key_reduction. See make_single_metric for details on the individual metrics.
- apax.train.metrics.mae_fn(inputs: dict, label: dict[array], prediction: dict[array], key: str) → array
Computes the Mean Absolute Error of two arrays.
- apax.train.metrics.make_single_metric(key: str, reduction: str) → Average
Builds a single clu metric where the key picks out the quantity from the model predictions dict. Metric functions (like mae_fn) are curried with the key.
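Currying a metric function with its key and naming the result key_reduction can be sketched with `functools.partial`. The structure of `metrics_list` (dicts with "name" and "reductions"), the `REDUCTIONS` registry, and all function names are hypothetical; the real functions additionally wrap each curried function into a clu metric, which is omitted here.

```python
from functools import partial

import numpy as np


def mae_sketch(inputs, label, prediction, key):
    """Mean absolute error of the quantity stored under `key`."""
    return np.mean(np.abs(label[key] - prediction[key]))


def mse_sketch(inputs, label, prediction, key):
    """Mean squared error of the quantity stored under `key`."""
    return np.mean((label[key] - prediction[key]) ** 2)


REDUCTIONS = {"mae": mae_sketch, "mse": mse_sketch}  # hypothetical registry


def build_metric_fns(metrics_list):
    """Map 'key_reduction' names to metric functions curried with their key."""
    fns = {}
    for entry in metrics_list:
        for reduction in entry["reductions"]:
            fns[f"{entry['name']}_{reduction}"] = partial(REDUCTIONS[reduction], key=entry["name"])
    return fns


fns = build_metric_fns([{"name": "energy", "reductions": ["mae", "mse"]}])
label = {"energy": np.array([1.0, 2.0])}
prediction = {"energy": np.array([1.5, 1.0])}
for name, fn in fns.items():
    print(name, fn({}, label, prediction))
```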
- apax.train.metrics.mse_fn(inputs: dict, label: dict[array], prediction: dict[array], key: str) → array
Computes the Mean Squared Error of two arrays.
- apax.train.metrics.per_atom_mae_fn(inputs: dict, label: dict[array], prediction: dict[array], key: str) → array
Computes the per-atom Mean Absolute Error of two arrays. Only reasonable when used with structural properties like ‘energy’.
- apax.train.metrics.per_atom_mse_fn(inputs: dict, label: dict[array], prediction: dict[array], key: str) → array
Computes the per-atom Mean Squared Error of two arrays. Only reasonable when used with structural properties like ‘energy’.
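Per-atom errors only make sense for quantities defined per structure, such as total energy: dividing the error by the atom count makes structures of different sizes comparable, whereas forces are already per-atom. The sketch assumes the atom counts are available as `inputs["n_atoms"]` and that the per-atom MSE squares the already-normalized error; both are assumptions.

```python
import numpy as np


def per_atom_mae_sketch(inputs, label, prediction, key):
    """MAE of a per-structure quantity after dividing by each structure's atom count."""
    n_atoms = inputs["n_atoms"]  # assumed key holding atoms per structure
    return np.mean(np.abs(label[key] - prediction[key]) / n_atoms)


def per_atom_mse_sketch(inputs, label, prediction, key):
    """MSE of a per-structure quantity after dividing the error by the atom count."""
    n_atoms = inputs["n_atoms"]
    return np.mean(((label[key] - prediction[key]) / n_atoms) ** 2)


inputs = {"n_atoms": np.array([2.0, 10.0])}
label = {"energy": np.array([-10.0, -50.0])}
prediction = {"energy": np.array([-9.0, -45.0])}  # same error per atom in both structures
print(per_atom_mae_sketch(inputs, label, prediction, "energy"))
print(per_atom_mse_sketch(inputs, label, prediction, "energy"))
```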
- apax.train.run.initialize_datasets(config: Config)
Initialize training and validation datasets based on the provided configuration.
- Parameters:
config (Config) – Configuration object containing all parameters.
- Returns:
train_ds (Dataset) – Training dataset.
val_ds (Dataset) – Validation dataset.
ds_stats (Dict[str, Tuple[float, float]]) – Dictionary containing scale and shift parameters for normalization.
- apax.train.run.initialize_loss_fn(loss_config_list: List[LossConfig]) → LossCollection
Initialize loss functions based on configuration.
- Parameters:
loss_config_list (List[LossConfig]) – List of loss configurations.
- Returns:
Collection of initialized loss functions.
- Return type:
LossCollection
- apax.train.run.run(user_config: str | PathLike | dict, log_level='error')
Starts the training of a model with parameters provided by the config.
- Parameters:
user_config (str | os.PathLike | dict) – Training config. A full example can be found here:
- apax.train.run.setup_logging(log_file, log_level)
Sets up the logging configuration.
- Parameters:
log_file (str) – Path to the log file.
log_level (str) – Logging level. Options: {‘debug’, ‘info’, ‘warning’, ‘error’, ‘critical’}.
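Mapping the documented level names onto Python's `logging` module and routing output to a file can be sketched with the stdlib alone. `setup_logging_sketch`, the log format, and the choice to overwrite the file (`mode="w"`) are assumptions; apax's function may configure handlers differently.

```python
import logging
import os
import tempfile

LEVELS = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}


def setup_logging_sketch(log_file: str, log_level: str) -> logging.Logger:
    """Route the root logger to `log_file` at the level named by `log_level`."""
    if log_level not in LEVELS:
        raise ValueError(f"Unknown log level {log_level!r}; expected one of {sorted(LEVELS)}")
    root = logging.getLogger()
    # Drop handlers left over from earlier calls so messages are not duplicated.
    for handler in list(root.handlers):
        root.removeHandler(handler)
        handler.close()
    handler = logging.FileHandler(log_file, mode="w")
    handler.setFormatter(logging.Formatter("%(levelname)s | %(name)s | %(message)s"))
    root.addHandler(handler)
    root.setLevel(LEVELS[log_level])
    return root


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train.log")
    root = setup_logging_sketch(path, "error")
    logging.getLogger("demo").warning("filtered out")  # below ERROR, dropped
    logging.getLogger("demo").error("written")
    for handler in list(root.handlers):  # close the file before the directory is deleted
        root.removeHandler(handler)
        handler.close()
    with open(path) as f:
        contents = f.read()
print(contents)
```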
- apax.train.trainer.fit(state: TrainState, train_ds: InMemoryDataset, loss_fn, Metrics: Collection, callbacks: list, n_epochs: int, ckpt_dir, ckpt_interval: int = 1, val_ds: InMemoryDataset | None = None, patience: int | None = None, patience_min_delta: float = 0.0, disable_pbar: bool = False, disable_batch_pbar: bool = True, is_ensemble=False, data_parallel=True, ema_handler: EMAParameters | None = None)
Trains the model using the provided training dataset.
- Parameters:
state – The initial state of the model.
train_ds (InMemoryDataset) – The training dataset.
loss_fn – The loss function to be minimized.
Metrics (metrics.Collection) – Collection of metrics to evaluate during training.
callbacks (list) – List of callback functions to be executed during training.
n_epochs (int) – Number of epochs for training.
ckpt_dir – Directory to save checkpoints.
ckpt_interval (int, default = 1) – Interval for saving checkpoints.
val_ds (InMemoryDataset, default = None) – Validation dataset.
patience (int, default = None) – Patience for early stopping.
disable_pbar (bool, default = False) – Whether to disable the progress bar for epochs.
disable_batch_pbar (bool, default = True) – Whether to disable the progress bar for batches.
is_ensemble (bool, default = False) – Whether the model is an ensemble.
data_parallel (bool, default = True) – Whether to use data parallelism.
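The `patience` and `patience_min_delta` arguments suggest a conventional early-stopping rule. The sketch below assumes one: an epoch counts as an improvement only if the validation loss falls more than `min_delta` below the best value so far, and training stops once `patience` consecutive epochs fail to improve. `early_stopping_epoch` is hypothetical, and fit's exact comparison may differ.

```python
from typing import List, Optional


def early_stopping_epoch(val_losses: List[float], patience: Optional[int],
                         min_delta: float = 0.0) -> Optional[int]:
    """Return the epoch index at which training would stop, or None."""
    if patience is None:
        return None  # early stopping disabled
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


val_losses = [1.0, 0.8, 0.79, 0.795, 0.81]
print(early_stopping_epoch(val_losses, patience=2, min_delta=0.05))
print(early_stopping_epoch(val_losses, patience=2))
```

A nonzero `min_delta` makes the rule ignore marginal improvements, so the same loss curve stops earlier.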
- apax.train.trainer.make_step_fns(loss_fn: Callable, Metrics: Collection, model: Any, is_ensemble: bool, return_predictions: bool = False) → tuple[Callable, Callable]
Creates JIT-compiled training and validation step functions.
This factory handles the boilerplate for gradient calculation, state updates, metric aggregation, and optional ensemble logic.
- Parameters:
loss_fn (Callable) – A callable that takes (predictions, labels) and returns a scalar loss.
Metrics (metrics.Collection) – A class (typically a clu.metrics.Collection) used to track and merge batch statistics. Must implement single_from_model_output.
model (Any) – The model architecture (e.g., a flax.linen.Module).
is_ensemble (bool) – If True, wraps the update and eval functions with ensemble-specific handling logic.
return_predictions (bool, default = False) – If True, the validation step will return the raw model predictions in addition to metrics and loss.
- Returns:
A tuple of (train_step, val_step), where:
train_step: (carry, batch) -> (new_carry, loss)
val_step: (params, batch, metrics) -> (loss, metrics, [predictions]); predictions are only returned if return_predictions is True.
- Return type:
Tuple[Callable, Callable]
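The calling contract of the two returned functions can be illustrated with a toy factory. `make_toy_step_fns` is hypothetical: it fits a one-parameter model y = w·x with an analytic gradient in NumPy, uses the bare parameter as the carry, and uses a Python list in place of a clu Collection. apax's real step functions use JAX autodiff and JIT, and their carry contents are not documented here. The (carry, batch) -> (new_carry, loss) shape is the same one `jax.lax.scan` expects of its step function, which makes such a signature convenient for looping over batches.

```python
from typing import Callable, Tuple

import numpy as np


def make_toy_step_fns(lr: float = 0.1) -> Tuple[Callable, Callable]:
    """Build (train_step, val_step) for a 1-parameter model y = w * x with MSE loss."""

    def loss_and_grad(w, batch):
        x, y = batch
        residual = w * x - y
        loss = np.mean(residual ** 2)
        grad = np.mean(2.0 * residual * x)  # d(loss)/dw
        return loss, grad

    def train_step(carry, batch):
        w = carry  # in this toy, the carry is just the parameter
        loss, grad = loss_and_grad(w, batch)
        return w - lr * grad, loss

    def val_step(params, batch, metrics):
        loss, _ = loss_and_grad(params, batch)
        metrics = metrics + [loss]  # stand-in for merging batch statistics
        return loss, metrics

    return train_step, val_step


x = np.linspace(-1.0, 1.0, 8)
batch = (x, 3.0 * x)  # targets generated with w_true = 3
train_step, val_step = make_toy_step_fns(lr=0.5)
w = 0.0
for _ in range(50):
    w, train_loss = train_step(w, batch)
val_loss, metrics = val_step(w, batch, [])
print(float(w), float(val_loss))
```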