Training

class apax.train.callbacks.CSVLoggerApax(filename, separator=',', append=False)
on_test_batch_end(batch, logs=None)

Called at the end of a batch in evaluate methods.

Also called at the end of a validation batch in the fit methods, if validation data is provided.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Parameters:
  • batch – Integer, index of batch within the current epoch.

  • logs – Dict. Aggregated metric results up until this batch.
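As a rough illustration of what a batch-end hook receives, the sketch below appends the `logs` dict to a CSV file using only the standard library. The class `MiniCSVLogger` is hypothetical and is not the CSVLoggerApax implementation; it only mirrors the `filename`, `separator`, and `append` arguments from the signature above.

```python
import csv
import os
import tempfile


class MiniCSVLogger:
    """Hypothetical stand-in; not the actual CSVLoggerApax implementation."""

    def __init__(self, filename, separator=",", append=False):
        self.filename = filename
        self.separator = separator
        # Start from an empty file unless appending was requested.
        if not append and os.path.exists(filename):
            os.remove(filename)

    def on_test_batch_end(self, batch, logs=None):
        logs = logs or {}
        row = {"batch": batch, **logs}
        # Write the header only when the file is created.
        write_header = not os.path.exists(self.filename)
        with open(self.filename, "a", newline="") as f:
            writer = csv.DictWriter(
                f, fieldnames=list(row), delimiter=self.separator
            )
            if write_header:
                writer.writeheader()
            writer.writerow(row)


log_path = os.path.join(tempfile.mkdtemp(), "eval_log.csv")
logger = MiniCSVLogger(log_path)
logger.on_test_batch_end(0, {"loss": 0.5})
logger.on_test_batch_end(1, {"loss": 0.25})

# Read the file back to inspect what was logged.
with open(log_path, newline="") as f:
    rows = list(csv.DictReader(f))
```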

apax.train.checkpoints.canonicalize_energy_grad_model_parameters(params)

Ensures that parameters from EnergyModels can be loaded into EnergyDerivativeModels by adding the “energy_model” parameter layer.

apax.train.checkpoints.canonicalize_energy_model_parameters(params)

Ensures that parameters from EnergyDerivativeModels can be loaded into EnergyModels by removing the “energy_model” parameter layer.
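The two canonicalization helpers are inverses of each other. A minimal sketch of the idea with plain nested dicts follows; the top-level `"params"` key and the function names `add_energy_model_layer` and `remove_energy_model_layer` are assumptions made for illustration, not the apax implementation.

```python
def add_energy_model_layer(params):
    """Wrap the parameters in an "energy_model" layer if it is missing."""
    inner = params["params"]  # assumed Flax-style top-level key
    if "energy_model" in inner:
        return params
    return {"params": {"energy_model": inner}}


def remove_energy_model_layer(params):
    """Strip the "energy_model" layer if it is present."""
    inner = params["params"]
    if "energy_model" not in inner:
        return params
    return {"params": inner["energy_model"]}


energy_params = {"params": {"dense": {"kernel": [1.0, 2.0]}}}
wrapped = add_energy_model_layer(energy_params)
unwrapped = remove_energy_model_layer(wrapped)
```

Both helpers leave already-canonical parameters unchanged, so applying them twice is harmless.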

apax.train.checkpoints.check_for_ensemble(params: FrozenDict) → int

Checks if a set of parameters belongs to an ensemble model. This is the case if all parameters share the same first dimension (the parameter batch dimension).
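The check described above can be sketched with NumPy arrays in a nested dict. What the real function returns is not documented here; this sketch assumes it is the shared leading dimension, with 1 standing for "not an ensemble". The names `leaves` and `shared_leading_dim` are hypothetical.

```python
import numpy as np


def leaves(tree):
    """Yield every array stored in a nested dict."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from leaves(value)
    else:
        yield np.asarray(tree)


def shared_leading_dim(params):
    """Return the common first dimension of all leaves, or 1 if there is none."""
    first_dims = {
        leaf.shape[0] if leaf.ndim > 0 else None for leaf in leaves(params)
    }
    if len(first_dims) == 1 and None not in first_dims:
        return first_dims.pop()
    return 1


single = {"dense": {"kernel": np.zeros((3, 4)), "bias": np.zeros(4)}}
ensemble = {"dense": {"kernel": np.zeros((5, 3, 4)), "bias": np.zeros((5, 4))}}

n_single = shared_leading_dim(single)
n_ensemble = shared_leading_dim(ensemble)
```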

apax.train.checkpoints.restore_parameters(model_dir: Path | List[Path]) → Tuple[Config, FrozenDict]

Restores one or more model configs and parameters. Parameters are stacked for ensembling.

apax.train.checkpoints.restore_single_parameters(model_dir: Path) → Tuple[Config, FrozenDict]

Loads the config and parameters of a single model.

apax.train.checkpoints.stack_parameters(param_list: List[FrozenDict]) → FrozenDict

Combine a list of parameter sets into a stacked version. Used for model ensembles.
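Stacking amounts to adding a new leading axis to every leaf, so that member i of the ensemble lives at index i. A minimal NumPy sketch follows, assuming all parameter sets share the same tree structure; `stack_trees` is a hypothetical name and this is not the library code.

```python
import numpy as np


def stack_trees(param_list):
    """Stack matching leaves of several nested dicts along a new first axis."""
    first = param_list[0]
    if isinstance(first, dict):
        return {
            key: stack_trees([params[key] for params in param_list])
            for key in first
        }
    return np.stack([np.asarray(params) for params in param_list], axis=0)


model_a = {"dense": {"kernel": np.zeros((3, 4)), "bias": np.zeros(4)}}
model_b = {"dense": {"kernel": np.ones((3, 4)), "bias": np.ones(4)}}

stacked = stack_trees([model_a, model_b])
```

After stacking, every leaf shares the same first dimension, which is the property check_for_ensemble looks for.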

apax.train.eval.eval_model(config_path, n_test=-1, log_file='eval.log', log_level='error')

Evaluate the model using the provided configuration.

Parameters:
  • config_path (str) – Path to the configuration file.

  • n_test (int, default = -1) – Number of test samples to evaluate, by default -1 (evaluate all).

  • log_file (str, default = "eval.log") – Path to the log file.

  • log_level (str, default = "error") – Logging level.

apax.train.eval.load_test_data(config, model_version_path, eval_path, n_test=-1)

Load test data for evaluation.

Parameters:
  • config (object) – Configuration object.

  • model_version_path (str) – Path to the model version.

  • eval_path (str) – Path to evaluation directory.

  • n_test (int, default = -1) – Number of test samples to load, by default -1 (load all).

Returns:

atoms_list – List of ase.Atoms containing the test data.

apax.train.eval.predict(model, params, Metrics, loss_fn, test_ds, callbacks, is_ensemble=False)

Perform predictions on the test dataset.

Parameters:
  • model – Trained model.

  • params – Model parameters.

  • Metrics – Collection of metrics.

  • loss_fn – Loss function.

  • test_ds – Test dataset.

  • callbacks – Callback functions.

  • is_ensemble (bool, default = False) – Whether the model is an ensemble.

class apax.train.loss.Loss(name: str, loss_type: str, weight: float = 1.0, atoms_exponent: float = 1.0, parameters: dict = <factory>)

Represents a single weighted loss function that is constructed from a name and a type of comparison metric.

class apax.train.loss.LossCollection(loss_list: List[apax.train.loss.Loss])

apax.train.loss.force_angle_div_force_label(label: array, prediction: array, divisor: float = 1.0, parameters: dict = {})

Cosine similarity loss function weighted by the norm of the force labels. Contributions are summed in Loss.

apax.train.loss.force_angle_exponential_weight(label: array, prediction: array, divisor: float = 1.0, parameters: dict = {}) → array

Cosine similarity loss function exponentially scaled by the norm of the force labels. Contributions are summed in Loss.

apax.train.loss.force_angle_loss(label: array, prediction: array, divisor: float = 1.0, parameters: dict = {}) → array

Cosine similarity loss function. Contributions are summed in Loss.
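To make the cosine similarity loss concrete, here is a minimal NumPy sketch that assumes the loss takes the form (1 - cos θ) / divisor per atom, where θ is the angle between the label and predicted force vectors. That form, the `eps` guard, and the names `cosine_similarity` and `angle_loss` are assumptions for illustration; the weighted variants above are not reproduced here.

```python
import numpy as np


def cosine_similarity(label, prediction, eps=1e-6):
    """Cosine of the angle between vectors along the last axis."""
    dot = np.sum(label * prediction, axis=-1)
    norms = np.linalg.norm(label, axis=-1) * np.linalg.norm(prediction, axis=-1)
    # eps guards against division by zero for vanishing forces.
    return dot / np.maximum(norms, eps)


def angle_loss(label, prediction, divisor=1.0):
    """0 for parallel vectors, 1 for orthogonal, 2 for antiparallel."""
    return (1.0 - cosine_similarity(label, prediction)) / divisor


# Two atoms: the first prediction is parallel to its label,
# the second is orthogonal to it.
label = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
prediction = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 3.0]])

per_atom = angle_loss(label, prediction)
```

Note that the plain cosine loss ignores vector magnitudes: the first prediction is twice as long as its label yet contributes zero loss.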

apax.train.loss.weighted_huber_loss(label: array, prediction: array, divisor: float = 1.0, parameters: dict = {}) → array

Huber loss function that allows weighting of individual contributions by the number of atoms in the system.

apax.train.loss.weighted_squared_error(label: array, prediction: array, divisor: float = 1.0, parameters: dict = {}) → array

Squared error function that allows weighting of individual contributions by the number of atoms in the system.
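The effect of the `divisor` argument in the squared error and Huber losses can be sketched as follows. The sketch assumes the divisor is the per-structure atom count and that the Huber threshold is passed as a hypothetical `delta` argument; how apax actually derives either value is not stated in this section.

```python
import numpy as np


def squared_error(label, prediction, divisor=1.0):
    """Squared error, scaled down by the divisor."""
    return (label - prediction) ** 2 / divisor


def huber(label, prediction, divisor=1.0, delta=1.0):
    """Quadratic for small errors, linear beyond delta, scaled by the divisor."""
    err = np.abs(label - prediction)
    quadratic = 0.5 * err**2
    linear = delta * (err - 0.5 * delta)
    return np.where(err <= delta, quadratic, linear) / divisor


# Total energies of two structures with 2 and 3 atoms, respectively.
label = np.array([10.0, 20.0])
prediction = np.array([9.0, 23.0])
n_atoms = np.array([2.0, 3.0])

se = squared_error(label, prediction, divisor=n_atoms)
hl = huber(label, prediction, divisor=n_atoms)
```

Dividing by the atom count keeps large structures from dominating the loss for extensive quantities such as the total energy, and the Huber form additionally damps the contribution of the outlier in the second structure.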

class apax.train.metrics.Averagefp64(total: array, count: array)
classmethod empty() → Metric

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

class apax.train.metrics.RootAverage(total: array, count: array)

Modifies the compute method of metrics.Average to obtain the root of the average. Meant to be used with mse_fn.

compute() → array

Computes final metrics from intermediate values.
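The `total`/`count` pattern behind these metrics can be illustrated with a small standard-library sketch. `RunningRootAverage` is a hypothetical class that mimics the `empty`, `merge`, and `compute` interface described above; it is not the clu or apax implementation.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class RunningRootAverage:
    total: float  # running sum of squared errors
    count: float  # number of accumulated values

    @classmethod
    def empty(cls):
        # Merging with an empty instance changes nothing.
        return cls(total=0.0, count=0.0)

    def merge(self, other):
        return RunningRootAverage(
            self.total + other.total, self.count + other.count
        )

    def compute(self):
        # Root of the average, i.e. an RMSE when fed squared errors.
        return math.sqrt(self.total / self.count)


# Two batches of squared errors, accumulated one after the other.
batch_1 = RunningRootAverage(total=1.0 + 9.0, count=2.0)
batch_2 = RunningRootAverage(total=4.0 + 4.0, count=2.0)

merged = RunningRootAverage.empty().merge(batch_1).merge(batch_2)
rmse = merged.compute()
```

Taking the root only in `compute` matters: averaging per-batch RMSE values would generally give a different result than the root of the overall mean squared error.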

apax.train.metrics.cosine_sim(label: dict[array], prediction: dict[array], key: str) → array

Computes the cosine similarity of two arrays.

apax.train.metrics.initialize_metrics(metrics_list) → Collection

Builds a clu metrics Collection by looping over all keys and reductions. The metrics are named according to key_reduction. See make_single_metric for details on the individual metrics.

apax.train.metrics.mae_fn(label: dict[array], prediction: dict[array], key: str) → array

Computes the Mean Absolute Error of two arrays.

apax.train.metrics.make_single_metric(key: str, reduction: str) → Average

Builds a single clu metric where the key picks out the quantity from the model predictions dict. Metric functions (like mae_fn) are curried with the key.
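The currying and the key_reduction naming scheme can be sketched with `functools.partial`. The layout of `metrics_list` (a list of dicts with `"key"` and `"reductions"` entries), the `REDUCTIONS` table, and the function names are assumptions for illustration; the real functions build clu metric classes rather than plain callables.

```python
from functools import partial

import numpy as np


def mae(label, prediction, key):
    return np.mean(np.abs(label[key] - prediction[key]))


def mse(label, prediction, key):
    return np.mean((label[key] - prediction[key]) ** 2)


REDUCTIONS = {"mae": mae, "mse": mse}


def build_metric_fns(metrics_list):
    """Return a dict of metric functions named "<key>_<reduction>"."""
    fns = {}
    for entry in metrics_list:
        for reduction in entry["reductions"]:
            name = f"{entry['key']}_{reduction}"
            # Curry the metric function with the key it should pick out.
            fns[name] = partial(REDUCTIONS[reduction], key=entry["key"])
    return fns


label = {"energy": np.array([1.0, 2.0]), "forces": np.array([0.0, 1.0])}
prediction = {"energy": np.array([1.5, 1.0]), "forces": np.array([0.0, 3.0])}

metric_fns = build_metric_fns(
    [
        {"key": "energy", "reductions": ["mae", "mse"]},
        {"key": "forces", "reductions": ["mae"]},
    ]
)
results = {name: float(fn(label, prediction)) for name, fn in metric_fns.items()}
```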

apax.train.metrics.mse_fn(label: dict[array], prediction: dict[array], key: str) → array

Computes the Mean Squared Error of two arrays.

apax.train.run.initialize_datasets(config: Config)

Initialize training and validation datasets based on the provided configuration.

Parameters:

config (Config) – Configuration object containing all parameters.

Returns:

  • train_ds (Dataset) – Training dataset.

  • val_ds (Dataset) – Validation dataset.

  • ds_stats (Dict[str, Tuple[float, float]]) – Dictionary containing scale and shift parameters for normalization.

apax.train.run.initialize_loss_fn(loss_config_list: List[LossConfig]) → LossCollection

Initialize loss functions based on configuration.

Parameters:

loss_config_list (List[LossConfig]) – List of loss configurations.

Returns:

Collection of initialized loss functions.

Return type:

LossCollection

apax.train.run.run(user_config: str | PathLike | dict, log_level='error')

Starts the training of a model with the parameters provided by the config.

Parameters:

user_config (str | os.PathLike | dict) – Training config. A full example can be found here:

apax.train.run.setup_logging(log_file, log_level)

Setup logging configuration.

Parameters:
  • log_file (str) – Path to the log file.

  • log_level (str) – Logging level. Options: {‘debug’, ‘info’, ‘warning’, ‘error’, ‘critical’}.
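A minimal standard-library sketch of mapping the level names listed above onto `logging` constants and writing to a file is shown below. The function `configure_logging`, the logger name, and the message format are hypothetical; setup_logging itself may configure logging differently.

```python
import logging
import os
import tempfile

LOG_LEVELS = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}


def configure_logging(log_file, log_level):
    """Send messages at or above log_level to log_file."""
    if log_level not in LOG_LEVELS:
        raise ValueError(f"Unknown log level: {log_level!r}")
    logger = logging.getLogger("training_sketch")
    logger.handlers.clear()  # avoid duplicate handlers on repeated calls
    logger.setLevel(LOG_LEVELS[log_level])
    handler = logging.FileHandler(log_file)
    handler.setFormatter(logging.Formatter("%(levelname)s | %(message)s"))
    logger.addHandler(handler)
    logger.propagate = False
    return logger


log_file = os.path.join(tempfile.mkdtemp(), "train.log")
logger = configure_logging(log_file, "error")
logger.info("skipped at level 'error'")
logger.error("written")

with open(log_file) as f:
    contents = f.read()
```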

apax.train.trainer.fit(state, train_ds: InMemoryDataset, loss_fn, Metrics: Collection, callbacks: list, n_epochs: int, ckpt_dir, ckpt_interval: int = 1, val_ds: InMemoryDataset | None = None, sam_rho=0.0, patience: int | None = None, disable_pbar: bool = False, disable_batch_pbar: bool = True, is_ensemble=False, data_parallel=True)

Trains the model using the provided training dataset.

Parameters:
  • state – The initial state of the model.

  • train_ds (InMemoryDataset) – The training dataset.

  • loss_fn – The loss function to be minimized.

  • Metrics (metrics.Collection) – Collection of metrics to evaluate during training.

  • callbacks (list) – List of callback functions to be executed during training.

  • n_epochs (int) – Number of epochs for training.

  • ckpt_dir – Directory to save checkpoints.

  • ckpt_interval (int, default = 1) – Interval for saving checkpoints.

  • val_ds (InMemoryDataset, default = None) – Validation dataset.

  • sam_rho (float, default = 0.0) – Rho parameter for Sharpness-Aware Minimization.

  • patience (int, default = None) – Patience for early stopping.

  • disable_pbar (bool, default = False) – Whether to disable progress bar for epochs.

  • disable_batch_pbar (bool, default = True) – Whether to disable progress bar for batches.

  • is_ensemble (bool, default = False) – Whether the model is an ensemble.

  • data_parallel (bool, default = True) – Whether to use data parallelism.
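How `n_epochs`, `ckpt_interval`, and `patience` typically interact can be sketched without any training code. The loop below replays a precomputed list of validation losses; the exact stopping and checkpointing rules used by fit are not documented in this section, so the rules here are assumptions, and `run_epochs` is a hypothetical name.

```python
def run_epochs(val_losses, n_epochs, patience=None, ckpt_interval=1):
    """Replay validation losses and report when training would stop."""
    best_loss = float("inf")
    epochs_since_improvement = 0
    checkpoints = []
    epochs_run = 0

    for epoch in range(n_epochs):
        epochs_run += 1
        val_loss = val_losses[epoch]

        if val_loss < best_loss:
            best_loss = val_loss
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1

        # Assumed rule: checkpoint every ckpt_interval epochs.
        if epoch % ckpt_interval == 0:
            checkpoints.append(epoch)

        # Assumed rule: stop after `patience` epochs without improvement.
        if patience is not None and epochs_since_improvement >= patience:
            break

    return epochs_run, best_loss, checkpoints


losses = [1.0, 0.8, 0.9, 0.85, 0.95, 0.7]
stopped_early = run_epochs(losses, n_epochs=6, patience=2, ckpt_interval=2)
full_run = run_epochs(losses, n_epochs=6, patience=None, ckpt_interval=2)
```

In this toy run, early stopping halts after four epochs and misses the later improvement to 0.7, which is the usual trade-off when choosing a small patience.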

apax.train.trainer.global_norm(updates) → Array

Returns the L2 norm of the input.

Parameters:

updates – A pytree of ndarrays representing the gradient.
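The global L2 norm treats all leaves of the pytree as one long vector: square every entry, sum over all leaves, and take the square root. A minimal NumPy sketch with nested dicts, lists, and tuples standing in for a pytree follows; `tree_leaves` and `l2_global_norm` are hypothetical names, not the apax functions.

```python
import numpy as np


def tree_leaves(tree):
    """Yield all arrays in a nested structure of dicts, lists, and tuples."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from tree_leaves(value)
    elif isinstance(tree, (list, tuple)):
        for value in tree:
            yield from tree_leaves(value)
    else:
        yield np.asarray(tree)


def l2_global_norm(updates):
    """Square root of the sum of squares over every leaf."""
    return np.sqrt(sum(np.sum(np.square(leaf)) for leaf in tree_leaves(updates)))


grads = {
    "dense": {
        "kernel": np.array([[3.0, 0.0], [0.0, 0.0]]),
        "bias": np.array([4.0]),
    }
}
norm = l2_global_norm(grads)
```

Here the only non-zero entries are 3 and 4, so the norm is the familiar 3-4-5 hypotenuse.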