Training

class apax.train.callbacks.CSVLoggerApax(filename, separator=',', append=False)[source]
on_epoch_end(epoch, logs=None)[source]

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters:
  • epoch – Integer, index of epoch.

  • logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.
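The logs dictionary described above maps naturally onto one CSV row per epoch. The following is a minimal sketch using only the standard csv module; append_epoch_row is a hypothetical helper written for illustration and is not the apax or Keras implementation.

```python
import csv
import io


def append_epoch_row(stream, epoch, logs, separator=",", write_header=False):
    """Write one CSV row for an epoch (hypothetical helper, not the apax API)."""
    logs = logs or {}
    keys = sorted(logs)  # fixed column order so rows line up with the header
    writer = csv.writer(stream, delimiter=separator)
    if write_header:
        writer.writerow(["epoch"] + keys)
    writer.writerow([epoch] + [logs[k] for k in keys])


# An in-memory buffer stands in for the log file.
buffer = io.StringIO()
append_epoch_row(buffer, 0, {"loss": 0.2, "val_loss": 0.3}, write_header=True)
append_epoch_row(buffer, 1, {"loss": 0.1, "val_loss": 0.25})
rows = buffer.getvalue().splitlines()
```

Validation metrics appear as extra columns because their keys carry the val_ prefix.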

on_test_batch_end(batch, logs=None)[source]

Called at the end of a batch in evaluate methods.

Also called at the end of a validation batch in the fit methods, if validation data is provided.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Parameters:
  • batch – Integer, index of batch within the current epoch.

  • logs – Dict. Aggregated metric results up until this batch.

class apax.train.callbacks.KerasPruningCallback(study_name: str, trial_id: int, study_log_file: str | Path, monitor: str = 'val_loss', interval: int = 1)[source]

Keras callback to prune unpromising trials. Adapted from https://optuna-integration.readthedocs.io/en/latest/_modules/optuna_integration/keras/keras.html#KerasPruningCallback

See the example if you want to add a pruning callback which observes validation accuracy.

Parameters:
  • study_name – name of the study

  • trial_id – id of current trial

  • study_log_file – path to log file

  • monitor – An evaluation metric for pruning, e.g., val_loss and val_accuracy. Please refer to keras.Callback reference for further details.

  • interval – Check if trial should be pruned every n-th epoch. By default interval=1 and pruning is performed after every epoch. Increase interval to run several epochs faster before applying pruning.
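The effect of the interval parameter can be sketched as follows. This assumes epoch indices start at 0 and that the check fires on the n-th, 2n-th, ... epoch; should_check_pruning is a hypothetical name and the actual callback may use a different convention.

```python
def should_check_pruning(epoch, interval=1):
    """Return True on every `interval`-th epoch (epochs assumed 0-indexed)."""
    return (epoch + 1) % interval == 0


# With interval=3, only the third and sixth epochs (indices 2 and 5) are checked.
checked = [e for e in range(6) if should_check_pruning(e, interval=3)]
```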

on_epoch_end(epoch: int, logs: dict[str, float] | None = None) -> None[source]

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters:
  • epoch – Integer, index of epoch.

  • logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.

apax.train.checkpoints.canonicalize_energy_grad_model_parameters(params)[source]

Ensures that parameters from EnergyModels can be loaded into EnergyDerivativeModels by adding the “energy_model” parameter layer.

apax.train.checkpoints.canonicalize_energy_model_parameters(params)[source]

Ensures that parameters from EnergyDerivativeModels can be loaded into EnergyModels by removing the “energy_model” parameter layer.
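The two canonicalization helpers above are inverses of each other. A rough sketch with plain dictionaries, assuming a Flax-style tree whose top-level key is "params" (this layout and both function names are assumptions, not taken from the apax source):

```python
def add_energy_model_layer(params):
    """Wrap the parameters in an "energy_model" layer if it is missing."""
    inner = params["params"]
    if "energy_model" in inner:
        return params  # already in the EnergyDerivativeModel layout
    return {"params": {"energy_model": inner}}


def remove_energy_model_layer(params):
    """Unwrap the "energy_model" layer if it is present."""
    inner = params["params"]
    if "energy_model" not in inner:
        return params  # already in the EnergyModel layout
    return {"params": inner["energy_model"]}


tree = {"params": {"dense": {"kernel": [1.0, 2.0]}}}
wrapped = add_energy_model_layer(tree)
```

Both helpers are idempotent in this sketch, so applying them to an already canonical tree returns it unchanged.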

apax.train.checkpoints.check_for_ensemble(params: FrozenDict) -> int[source]

Checks if a set of parameters belongs to an ensemble model. This is the case if all parameters share the same first dimension (the parameter batch).
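The shared-first-dimension criterion can be illustrated on a nested dictionary of array shapes. This is only a sketch: ensemble_size is a hypothetical name, shapes stand in for real arrays, and returning 1 for a non-ensemble is an assumption about the meaning of the integer result.

```python
def leading_dims(shape_tree):
    """Collect the first dimension of every leaf shape in a nested dict."""
    if isinstance(shape_tree, dict):
        dims = []
        for value in shape_tree.values():
            dims.extend(leading_dims(value))
        return dims
    # A leaf is a shape tuple; scalars (empty shape) have no leading dimension.
    return [shape_tree[0] if len(shape_tree) > 0 else None]


def ensemble_size(shape_tree):
    """Return the shared leading dimension, or 1 if the leaves disagree."""
    dims = set(leading_dims(shape_tree))
    if len(dims) == 1 and None not in dims:
        return dims.pop()
    return 1


stacked_shapes = {"dense": {"kernel": (4, 8, 16), "bias": (4, 16)}}
single_shapes = {"dense": {"kernel": (8, 16), "bias": (16,)}}
```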

apax.train.checkpoints.restore_parameters(model_dir: Path | List[Path]) -> Tuple[Config, FrozenDict][source]

Restores one or more model configs and parameters. Parameters are stacked for ensembling.

apax.train.checkpoints.restore_single_parameters(model_dir: Path) -> Tuple[Config, FrozenDict][source]

Loads the config and parameters of a single model.

apax.train.checkpoints.stack_parameters(param_list: List[FrozenDict]) -> FrozenDict[source]

Combine a list of parameter sets into a stacked version. Used for model ensembles.
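Stacking can be pictured as adding a new leading axis, indexed by model, to every leaf of the parameter tree. The sketch below uses nested dicts and Python lists in place of FrozenDict and arrays; stack_trees is a hypothetical helper, and a JAX version would typically map an array-stacking function over the trees instead.

```python
def stack_trees(param_list):
    """Stack matching leaves of several nested dicts along a new first axis."""
    first = param_list[0]
    if isinstance(first, dict):
        return {key: stack_trees([p[key] for p in param_list]) for key in first}
    return list(param_list)  # leaf: one entry per model


model_a = {"dense": {"w": [1.0, 2.0], "b": 0.5}}
model_b = {"dense": {"w": [3.0, 4.0], "b": 1.5}}
stacked = stack_trees([model_a, model_b])
```

After stacking, every leaf has the number of models as its first dimension, which is the property used to recognise an ensemble.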

apax.train.eval.eval_model(config_path, n_test=-1, log_file='eval.log', log_level='error')[source]

Evaluate the model using the provided configuration.

Parameters:
  • config_path (str) – Path to the configuration file.

  • n_test (int, default = -1) – Number of test samples to evaluate, by default -1 (evaluate all).

  • log_file (str, default = "eval.log") – Path to the log file.

  • log_level (str, default = "error") – Logging level.

apax.train.eval.load_test_data(config, model_version_path, eval_path, n_test=-1)[source]

Load test data for evaluation.

Parameters:
  • config (object) – Configuration object.

  • model_version_path (str) – Path to the model version.

  • eval_path (str) – Path to evaluation directory.

  • n_test (int, default = -1) – Number of test samples to load, by default -1 (load all).

Returns:

List of ase.Atoms containing the test data.

Return type:

atoms_list

apax.train.eval.predict(model, params, Metrics, loss_fn, test_ds, callbacks, is_ensemble=False)[source]

Perform predictions on the test dataset.

Parameters:
  • model – Trained model.

  • params – Model parameters.

  • Metrics – Collection of metrics.

  • loss_fn – Loss function.

  • test_ds – Test dataset.

  • callbacks – Callback functions.

  • is_ensemble (bool, default = False) – Whether the model is an ensemble.

class apax.train.loss.Loss(name: str, loss_type: str, weight: float = 1.0, atoms_exponent: float = 1.0, parameters: dict = <factory>)[source]

Represents a single weighted loss function that is constructed from a name and a type of comparison metric.

class apax.train.loss.LossCollection(loss_list: List[apax.train.loss.Loss])[source]

apax.train.loss.crps_loss(label: Array, prediction: Array, name, parameters: dict = {}) -> Array[source]

Computes the CRPS of a Gaussian distribution given means, targets and standard deviations (uncertainty estimates).
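For reference, the CRPS of a Gaussian predictive distribution has a closed form. The scalar sketch below evaluates it with the standard library; gaussian_crps is a hypothetical name, and how apax reduces or weights the result is not shown here.

```python
import math


def gaussian_crps(target, mean, std):
    """Closed-form CRPS of N(mean, std^2) evaluated at `target`."""
    z = (target - mean) / std
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal pdf
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))  # standard normal cdf
    return std * (z * (2.0 * cdf - 1.0) + 2.0 * pdf - 1.0 / math.sqrt(math.pi))


perfect_mean = gaussian_crps(0.0, 0.0, 1.0)
off_by_one = gaussian_crps(1.0, 0.0, 1.0)
```

The score is in the units of the target and grows as the prediction moves away from it, so lower is better.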

apax.train.loss.force_angle_div_force_label(label: array, prediction: array, name, parameters: dict = {})[source]

Cosine similarity loss function weighted by the norm of the force labels. Contributions are summed in Loss.

apax.train.loss.force_angle_exponential_weight(label: array, prediction: array, name, parameters: dict = {}) -> array[source]

Cosine similarity loss function exponentially scaled by the norm of the force labels. Contributions are summed in Loss.

apax.train.loss.force_angle_loss(label: array, prediction: array, name, parameters: dict = {}) -> array[source]

Cosine similarity loss function. Contributions are summed in Loss.
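A cosine-similarity loss penalises the angle between the predicted and reference force vectors, independent of their magnitudes. One common form is 1 - cos(theta); the sketch below uses that form for a single force vector. Whether apax rescales or offsets the value differently is not stated here, and both function names are hypothetical.

```python
import math


def cosine_similarity(a, b, eps=1e-12):
    """Cosine of the angle between two vectors, guarded against zero norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def angle_loss(force_label, force_prediction):
    """0 for parallel vectors, 1 for orthogonal, 2 for antiparallel."""
    return 1.0 - cosine_similarity(force_label, force_prediction)


parallel = angle_loss([1.0, 0.0, 0.0], [2.0, 0.0, 0.0])
orthogonal = angle_loss([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])
```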

apax.train.loss.nll_loss(label: Array, prediction: Array, name, parameters: dict = {}) -> Array[source]

Computes the Gaussian NLL loss given means, targets and standard deviations (uncertainty estimates).
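The Gaussian negative log-likelihood for a single target can be written out as below. This is a scalar sketch with a hypothetical function name; the actual loss may drop the constant term or apply additional weighting.

```python
import math


def gaussian_nll(target, mean, std):
    """Negative log-likelihood of `target` under N(mean, std^2)."""
    var = std * std
    return 0.5 * (math.log(2.0 * math.pi * var) + (target - mean) ** 2 / var)


at_mean = gaussian_nll(0.0, 0.0, 1.0)
one_sigma_off = gaussian_nll(1.0, 0.0, 1.0)
```

Unlike a plain squared error, this loss also penalises overconfident (too small) and underconfident (too large) standard deviations.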

apax.train.loss.weighted_huber_loss(label: array, prediction: array, name, parameters: dict = {}) -> array[source]

Huber loss function that allows weighting of individual contributions by the number of atoms in the system.

apax.train.loss.weighted_squared_error(label: array, prediction: array, name, parameters: dict = {}) -> array[source]

Squared error function that allows weighting of individual contributions by the number of atoms in the system.
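The two atom-weighted losses above can be sketched for a single scalar label. The sketch assumes the weighting is a division by n_atoms raised to atoms_exponent (the exponent appears in the Loss signature, but the exact convention is an assumption), and that the Huber threshold is passed as delta; the function names are hypothetical.

```python
def squared_error_sketch(label, prediction, n_atoms, atoms_exponent=1.0):
    """Squared error divided by n_atoms ** atoms_exponent (assumed convention)."""
    return (label - prediction) ** 2 / n_atoms ** atoms_exponent


def huber_sketch(label, prediction, n_atoms, atoms_exponent=1.0, delta=1.0):
    """Huber loss: quadratic below `delta`, linear above, then atom-weighted."""
    err = abs(label - prediction)
    if err <= delta:
        value = 0.5 * err * err
    else:
        value = delta * (err - 0.5 * delta)
    return value / n_atoms ** atoms_exponent


small_system = squared_error_sketch(2.0, 0.0, n_atoms=1)
large_system = squared_error_sketch(2.0, 0.0, n_atoms=4)
```

With this convention, the same absolute energy error contributes less for a larger system, which keeps extensive properties comparable across structure sizes.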

class apax.train.metrics.Averagefp64(total: Array, count: Array)[source]
classmethod empty() -> Metric[source]

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

class apax.train.metrics.RootAverage(total: Array, count: Array)[source]

Modifies the compute method of metrics.Average to obtain the root of the average. Meant to be used with mse_fn.

compute() -> array[source]

Computes final metrics from intermediate values.
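The total/count pattern behind these metric classes can be sketched with a small dataclass. This is a standalone illustration, not the clu or apax classes: RunningAverage and RootRunningAverage are hypothetical names, and plain floats replace arrays.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class RunningAverage:
    total: float
    count: float

    @classmethod
    def empty(cls):
        # Merging with an empty instance leaves the other operand unchanged.
        return cls(total=0.0, count=0.0)

    def merge(self, other):
        return type(self)(self.total + other.total, self.count + other.count)

    def compute(self):
        return self.total / self.count


class RootRunningAverage(RunningAverage):
    def compute(self):
        # Root of the averaged squared errors, i.e. an RMSE.
        return math.sqrt(super().compute())


batch_1 = RootRunningAverage(total=6.0, count=2.0)  # summed squared errors
batch_2 = RootRunningAverage(total=10.0, count=2.0)
rmse = batch_1.merge(batch_2).compute()
```

Taking the root only in compute, after all batches are merged, gives the RMSE over the whole dataset rather than an average of per-batch RMSE values.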

apax.train.metrics.cosine_sim(inputs: dict, label: dict[array], prediction: dict[array], key: str) -> array[source]

Computes the cosine similarity of two arrays.

apax.train.metrics.initialize_metrics(metrics_list) -> Collection[source]

Builds a clu metrics Collection by looping over all keys and reductions. The metrics are named according to key_reduction. See make_single_metric for details on the individual metrics.

apax.train.metrics.mae_fn(inputs: dict, label: dict[array], prediction: dict[array], key: str) -> array[source]

Computes the Mean Absolute Error of two arrays.

apax.train.metrics.make_single_metric(key: str, reduction: str) -> Average[source]

Builds a single clu metric where the key picks out the quantity from the model predictions dict. Metric functions (like mae_fn) are curried with the key.
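Currying a metric function with its key, and naming the result key_reduction, can be sketched with functools.partial. The helper names and the metric_fns lookup table below are hypothetical, and lists stand in for arrays.

```python
from functools import partial


def mae_sketch(inputs, label, prediction, key):
    """Mean absolute error of the entries stored under `key`."""
    errors = [abs(l - p) for l, p in zip(label[key], prediction[key])]
    return sum(errors) / len(errors)


def make_named_metric(key, reduction, metric_fns):
    """Return (name, function) with the key already bound to the function."""
    name = f"{key}_{reduction}"
    return name, partial(metric_fns[reduction], key=key)


name, energy_mae = make_named_metric("energy", "mae", {"mae": mae_sketch})
value = energy_mae({}, {"energy": [1.0, 2.0]}, {"energy": [1.5, 2.5]})
```

Binding the key up front means every metric can be called with the same (inputs, label, prediction) arguments, regardless of which quantity it evaluates.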

apax.train.metrics.mse_fn(inputs: dict, label: dict[array], prediction: dict[array], key: str) -> array[source]

Computes the Mean Squared Error of two arrays.

apax.train.metrics.per_atom_mae_fn(inputs: dict, label: dict[array], prediction: dict[array], key: str) -> array[source]

Computes the per-atom Mean Absolute Error of two arrays. Only reasonable when used with structural properties like 'energy'.

apax.train.metrics.per_atom_mse_fn(inputs: dict, label: dict[array], prediction: dict[array], key: str) -> array[source]

Computes the per-atom Mean Squared Error of two arrays. Only reasonable when used with structural properties like 'energy'.
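A per-atom error divides each structure's error by its number of atoms before averaging, which is why it only makes sense for one-value-per-structure properties such as the total energy. The sketch below assumes the atom counts are available as a plain list; how apax reads them from inputs is not shown, and the function name is hypothetical.

```python
def per_atom_mae_sketch(n_atoms, label, prediction):
    """Average of |label - prediction| / n_atoms over all structures."""
    errors = [abs(l - p) / n for l, p, n in zip(label, prediction, n_atoms)]
    return sum(errors) / len(errors)


# Two structures with 2 and 4 atoms; energy errors of 1.0 and 2.0.
per_atom = per_atom_mae_sketch([2, 4], [10.0, 20.0], [11.0, 22.0])
```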

apax.train.run.initialize_datasets(config: Config)[source]

Initialize training and validation datasets based on the provided configuration.

Parameters:

config (Config) – Configuration object containing all parameters.

Returns:

  • train_ds (Dataset) – Training dataset.

  • val_ds (Dataset) – Validation dataset.

  • ds_stats (Dict[str, Tuple[float, float]]) – Dictionary containing scale and shift parameters for normalization.

apax.train.run.initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection[source]

Initialize loss functions based on configuration.

Parameters:

loss_config_list (List[LossConfig]) – List of loss configurations.

Returns:

Collection of initialized loss functions.

Return type:

LossCollection

apax.train.run.run(user_config: str | PathLike | dict, log_level='error')[source]

Starts the training of a model with parameters provided by the config.

Parameters:

user_config (str | os.PathLike | dict) – Training config. A full example can be found here:

apax.train.run.setup_logging(log_file, log_level)[source]

Setup logging configuration.

Parameters:
  • log_file (str) – Path to the log file.

  • log_level (str) – Logging level. Options: {‘debug’, ‘info’, ‘warning’, ‘error’, ‘critical’}.
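Translating the string options above into a standard-library logging setup might look like the following. This is a sketch with hypothetical helper names, not the apax implementation; force=True replaces any handlers that were configured earlier.

```python
import logging

VALID_LEVELS = ("debug", "info", "warning", "error", "critical")


def resolve_log_level(log_level):
    """Map a level name such as 'error' to its numeric logging constant."""
    name = log_level.lower()
    if name not in VALID_LEVELS:
        raise ValueError(f"unknown log level: {log_level!r}")
    return getattr(logging, name.upper())


def configure_file_logging(log_file, log_level):
    """Send log records at or above `log_level` to `log_file`."""
    logging.basicConfig(
        filename=log_file, level=resolve_log_level(log_level), force=True
    )


error_level = resolve_log_level("error")
```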

exception apax.train.trainer.EarlyStop[source]

apax.train.trainer.fit(state: TrainState, train_ds: InMemoryDataset, loss_fn, Metrics: Collection, callbacks: list, n_epochs: int, ckpt_dir, ckpt_interval: int = 1, val_ds: InMemoryDataset | None = None, patience: int | None = None, patience_min_delta: float = 0.0, disable_pbar: bool = False, disable_batch_pbar: bool = True, is_ensemble=False, data_parallel=True, ema_handler: EMAParameters | None = None)[source]

Trains the model using the provided training dataset.

Parameters:
  • state – The initial state of the model.

  • train_ds (InMemoryDataset) – The training dataset.

  • loss_fn – The loss function to be minimized.

  • Metrics (metrics.Collection) – Collection of metrics to evaluate during training.

  • callbacks (list) – List of callback functions to be executed during training.

  • n_epochs (int) – Number of epochs for training.

  • ckpt_dir – Directory to save checkpoints.

  • ckpt_interval (int, default = 1) – Interval for saving checkpoints.

  • val_ds (InMemoryDataset, default = None) – Validation dataset.

  • patience (int, default = None) – Patience for early stopping.

  • disable_pbar (bool, default = False) – Whether to disable progress bar for epochs.

  • disable_batch_pbar (bool, default = True) – Whether to disable progress bar for batches.

  • is_ensemble (bool, default = False) – Whether the model is an ensemble.

  • data_parallel (bool, default = True) – Whether to use data parallelism.
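The interplay of patience and patience_min_delta can be sketched as a plain loop over validation losses. This assumes the common rule that an epoch counts as an improvement only if the loss drops by more than patience_min_delta below the best value so far; the function name is hypothetical and the actual trainer may signal the stop differently, for example by raising EarlyStop.

```python
def epoch_of_early_stop(val_losses, patience, patience_min_delta=0.0):
    """Return the epoch index at which training would stop, or None."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - patience_min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if patience is not None and epochs_without_improvement >= patience:
            return epoch
    return None


losses = [1.0, 0.8, 0.81, 0.79, 0.8, 0.85]
stop_default = epoch_of_early_stop(losses, patience=2)
stop_strict = epoch_of_early_stop(losses, patience=2, patience_min_delta=0.05)
```

A larger patience_min_delta treats small fluctuations as no improvement, so training stops earlier on a plateau.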

apax.train.trainer.make_step_fns(loss_fn: Callable, Metrics: Collection, model: Any, is_ensemble: bool, return_predictions: bool = False) -> tuple[Callable, Callable][source]

Creates JIT-compiled training and validation step functions.

This factory handles the boilerplate for gradient calculation, state updates, metric aggregation, and optional ensemble logic.

Parameters:
  • loss_fn (Callable) – A callable that takes (predictions, labels) and returns a scalar loss.

  • Metrics (metrics.Collection) – A class (typically a clu.metrics.Collection) used to track and merge batch statistics. Must implement single_from_model_output.

  • model (Any) – The model architecture (e.g., a flax.linen.Module).

  • is_ensemble (bool) – If True, wraps the update and eval functions with ensemble-specific handling logic.

  • return_predictions (bool, default = False) – If True, the validation step will return the raw model predictions in addition to metrics and loss.

Returns:

A tuple of (train_step, val_step), where:

  • train_step: (carry, batch) -> (new_carry, loss)

  • val_step: (params, batch, metrics) -> (loss, metrics, [predictions])

Return type:

Tuple[Callable, Callable]
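The factory pattern and the two step signatures can be illustrated with a toy, dependency-free example. Everything here is a stand-in: the model is a single weight w with prediction w * x, a finite difference replaces automatic differentiation, the carry is just the weight, and the metrics are a list of losses. None of these choices reflect the apax internals.

```python
def make_step_fns_sketch(loss_fn, learning_rate=0.1, eps=1e-6):
    """Build (train_step, val_step) closures around a shared loss function."""

    def batch_loss(w, batch):
        predictions = [w * x for x in batch["x"]]
        return loss_fn(predictions, batch["y"])

    def train_step(carry, batch):
        w = carry
        # A central finite difference stands in for automatic differentiation.
        grad = (batch_loss(w + eps, batch) - batch_loss(w - eps, batch)) / (2 * eps)
        return w - learning_rate * grad, batch_loss(w, batch)

    def val_step(params, batch, metrics):
        loss = batch_loss(params, batch)
        return loss, metrics + [loss]

    return train_step, val_step


def mse(predictions, labels):
    return sum((p - l) ** 2 for p, l in zip(predictions, labels)) / len(labels)


train_step, val_step = make_step_fns_sketch(mse)
batch = {"x": [1.0, 2.0], "y": [2.0, 4.0]}  # data generated by w = 2
w = 0.0
for _ in range(50):
    w, train_loss = train_step(w, batch)
val_loss, collected = val_step(w, batch, [])
```

Returning closures keeps the loss function and model fixed while the training loop only passes the changing state and data, which is also what makes such step functions convenient to compile once and reuse.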