Training
- class apax.train.callbacks.CSVLoggerApax(filename, separator=',', append=False)
- on_epoch_end(epoch, logs=None)
Called at the end of an epoch.
Subclasses should override for any actions to run. This function should only be called during TRAIN mode.
- Parameters:
epoch – Integer, index of epoch.
logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For training epoch, the values of the Model’s metrics are returned. Example: {‘loss’: 0.2, ‘accuracy’: 0.7}.
- on_test_batch_end(batch, logs=None)
Called at the end of a batch in evaluate methods.
Also called at the end of a validation batch in the fit methods, if validation data is provided.
Subclasses should override for any actions to run.
Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.
- Parameters:
batch – Integer, index of batch within the current epoch.
logs – Dict. Aggregated metric results up until this batch.
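This page does not spell out how CSVLoggerApax turns the per-epoch `logs` dict into CSV. As a rough, stdlib-only illustration, the sketch below writes an `epoch` column followed by one column per metric, using a configurable delimiter in the spirit of the `separator` argument. The helper `epoch_logs_to_csv` and the sorted column order are assumptions for illustration, not the apax implementation.

```python
import csv
import io


def epoch_logs_to_csv(history, separator=","):
    """Render (epoch, logs) pairs as CSV text (hypothetical helper).

    The header is 'epoch' followed by the metric keys of the first epoch,
    sorted alphabetically so the column order is deterministic.
    """
    keys = sorted(history[0][1]) if history else []
    buffer = io.StringIO()
    writer = csv.writer(buffer, delimiter=separator, lineterminator="\n")
    writer.writerow(["epoch", *keys])
    for epoch, logs in history:
        # Each row: epoch index, then the metric values in header order.
        writer.writerow([epoch, *(logs[key] for key in keys)])
    return buffer.getvalue()


history = [
    (0, {"loss": 0.2, "val_loss": 0.25}),
    (1, {"loss": 0.15, "val_loss": 0.22}),
]
print(epoch_logs_to_csv(history))
# epoch,loss,val_loss
# 0,0.2,0.25
# 1,0.15,0.22
```

Validation metrics show up as ordinary `val_`-prefixed columns, so no special handling is needed for them in this sketch.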
- apax.train.checkpoints.canonicalize_energy_grad_model_parameters(params)
Ensures that parameters from EnergyModels can be loaded into EnergyDerivativeModels by adding the “energy_model” parameter layer.
- apax.train.checkpoints.canonicalize_energy_model_parameters(params)
Ensures that parameters from EnergyDerivativeModels can be loaded into EnergyModels by removing the “energy_model” parameter layer.
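The two canonicalization helpers undo each other: one adds the "energy_model" layer, the other removes it. The sketch below shows the idea on plain nested dicts, assuming a Flax-style tree with a single top-level "params" key. The real functions operate on FrozenDict objects, and the helper names here are hypothetical.

```python
from copy import deepcopy


def add_energy_model_layer(params):
    """Nest the parameter tree under 'energy_model' unless it already is."""
    inner = params["params"]
    if "energy_model" in inner:
        return deepcopy(params)
    return {"params": {"energy_model": deepcopy(inner)}}


def remove_energy_model_layer(params):
    """Lift the tree out of the 'energy_model' layer if that layer is present."""
    inner = params["params"]
    if "energy_model" not in inner:
        return deepcopy(params)
    return {"params": deepcopy(inner["energy_model"])}


energy_params = {"params": {"dense_0": {"kernel": [[0.1, 0.2]]}}}
grad_params = add_energy_model_layer(energy_params)
print(grad_params)
# {'params': {'energy_model': {'dense_0': {'kernel': [[0.1, 0.2]]}}}}
```

Because both sketch helpers leave a tree untouched when it is already in the target layout, they can be applied unconditionally before loading.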
- apax.train.checkpoints.check_for_ensemble(params: FrozenDict) → int
Checks whether a set of parameters belongs to an ensemble model. This is the case if all parameters share the same first dimension (the parameter batch).
- apax.train.checkpoints.restore_parameters(model_dir: Path | List[Path]) → Tuple[Config, FrozenDict]
Restores one or more model configs and parameters. Parameters are stacked for ensembling.
- apax.train.checkpoints.restore_single_parameters(model_dir: Path) → Tuple[Config, FrozenDict]
Load the config and parameters of a single model.
- apax.train.checkpoints.stack_parameters(param_list: List[FrozenDict]) → FrozenDict
Combine a list of parameter sets into a stacked version. Used for model ensembles.
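Stacking and the ensemble check fit together: stacking N identically structured parameter trees adds a leading axis of length N to every leaf, and that shared first dimension is what check_for_ensemble looks for. The stdlib sketch below uses nested dicts with Python lists as leaves in place of FrozenDict and JAX arrays. The function names are hypothetical, and returning None for a non-ensemble tree is a choice made here, since the exact integer that check_for_ensemble returns in that case is not documented on this page.

```python
def stack_param_trees(param_list):
    """Stack identically structured nested dicts leaf by leaf.

    A stacked leaf is the list of the members' leaves, so its length
    (its "first dimension") equals the number of ensemble members.
    """
    first = param_list[0]
    if isinstance(first, dict):
        return {key: stack_param_trees([p[key] for p in param_list]) for key in first}
    return list(param_list)


def iter_leaves(tree):
    """Yield every non-dict leaf of a nested dict."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from iter_leaves(value)
    else:
        yield tree


def shared_leading_dim(params):
    """Return the common first dimension of all leaves, or None if they differ."""
    dims = {len(leaf) for leaf in iter_leaves(params)}
    return dims.pop() if len(dims) == 1 else None


members = [{"params": {"w": [float(i), 2.0], "b": [0.5]}} for i in range(3)]
stacked = stack_param_trees(members)
print(shared_leading_dim(stacked))     # 3
print(shared_leading_dim(members[0]))  # None, since "w" has length 2 and "b" length 1
```

As the documented rule implies, a single model whose leaves all happened to share a first dimension would also pass this check.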
- apax.train.eval.eval_model(config_path, n_test=-1, log_file='eval.log', log_level='error')
Evaluate the model using the provided configuration.
- Parameters:
config_path (str) – Path to the configuration file.
n_test (int, default = -1) – Number of test samples to evaluate; -1 evaluates all samples.
log_file (str, default = "eval.log") – Path to the log file.
log_level (str, default = "error") – Logging level.
- apax.train.eval.load_test_data(config, model_version_path, eval_path, n_test=-1)
Load test data for evaluation.
- Parameters:
config (object) – Configuration object.
model_version_path (str) – Path to the model version.
eval_path (str) – Path to evaluation directory.
n_test (int, default = -1) – Number of test samples to load; -1 loads all samples.
- Returns:
atoms_list – List of ase.Atoms containing the test data.
- Return type:
List[ase.Atoms]
- apax.train.eval.predict(model, params, Metrics, loss_fn, test_ds, callbacks, is_ensemble=False)
Perform predictions on the test dataset.
- Parameters:
model – Trained model.
params – Model parameters.
Metrics – Collection of metrics.
loss_fn – Loss function.
test_ds – Test dataset.
callbacks – Callback functions.
is_ensemble (bool, default = False) – Whether the model is an ensemble.
- class apax.train.loss.Loss(name: str, loss_type: str, weight: float = 1.0, atoms_exponent: float = 1.0, parameters: dict = <factory>)
Represents a single weighted loss function that is constructed from a name and a type of comparison metric.
- class apax.train.loss.LossCollection(loss_list: List[apax.train.loss.Loss])
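The Loss signature exposes a weight and an atoms_exponent, and the loss functions below take a divisor. One plausible reading, which this page does not confirm, is that each structure's contribution is divided by n_atoms ** atoms_exponent, multiplied by weight, and summed, and that LossCollection adds up its members. The sketch below encodes that reading with stdlib types. WeightedLoss and WeightedLossCollection are hypothetical stand-ins, and a plain callable replaces the loss_type string of the real class.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Sequence


def squared_error_terms(label: Sequence[float], prediction: Sequence[float],
                        divisor: float = 1.0) -> List[float]:
    """Per-component squared error, divided by a per-structure divisor."""
    return [(y - y_hat) ** 2 / divisor for y, y_hat in zip(label, prediction)]


@dataclass
class WeightedLoss:
    """Hypothetical stand-in for apax.train.loss.Loss."""

    name: str          # key into the label/prediction dicts, e.g. "energy"
    loss_fn: Callable  # replaces the loss_type string of the real class
    weight: float = 1.0
    atoms_exponent: float = 1.0

    def __call__(self, labels: Dict, predictions: Dict, n_atoms: int) -> float:
        # Assumption: the divisor grows as n_atoms ** atoms_exponent.
        divisor = n_atoms ** self.atoms_exponent
        terms = self.loss_fn(labels[self.name], predictions[self.name], divisor)
        return self.weight * sum(terms)


@dataclass
class WeightedLossCollection:
    """Sums the weighted losses of its members."""

    losses: List[WeightedLoss] = field(default_factory=list)

    def __call__(self, labels: Dict, predictions: Dict, n_atoms: int) -> float:
        return sum(loss(labels, predictions, n_atoms) for loss in self.losses)


labels = {"energy": [10.0], "forces": [1.0, 0.0, -1.0]}
predictions = {"energy": [12.0], "forces": [0.5, 0.0, -1.0]}
collection = WeightedLossCollection([
    WeightedLoss("energy", squared_error_terms, weight=1.0, atoms_exponent=1.0),
    WeightedLoss("forces", squared_error_terms, weight=4.0, atoms_exponent=0.0),
])
print(collection(labels, predictions, n_atoms=4))  # 2.0
```

Setting atoms_exponent to 0.0 turns the divisor into 1, so that term is not normalized by system size at all in this sketch.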
- apax.train.loss.crps_loss(label: Array, prediction: Array, name, divisor: float = 1.0, parameters: dict = {}) → Array
Computes the continuous ranked probability score (CRPS) of a Gaussian distribution given means, targets, and standard deviations (uncertainty estimates).
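For a Gaussian predictive distribution the CRPS has a closed form. With z = (y - μ) / σ, standard normal density φ and cumulative distribution Φ, CRPS = σ [z (2Φ(z) - 1) + 2φ(z) - 1/√π]. The scalar sketch below evaluates that expression with the math module. How crps_loss vectorizes it, applies divisor, or reads parameters is not shown on this page and may differ.

```python
import math


def gaussian_crps(y: float, mu: float, sigma: float) -> float:
    """Closed-form CRPS of N(mu, sigma**2) evaluated at the observation y."""
    z = (y - mu) / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal density
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))         # standard normal CDF
    return sigma * (z * (2.0 * cdf - 1.0) + 2.0 * pdf - 1.0 / math.sqrt(math.pi))


# A perfectly centred prediction is still penalised in proportion to sigma.
print(gaussian_crps(0.0, 0.0, 1.0))  # about 0.2337
```

For a fixed σ the CRPS grows only linearly with the error, and as σ goes to 0 it approaches the absolute error |y - μ|.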
- apax.train.loss.force_angle_div_force_label(label: array, prediction: array, name, divisor: float = 1.0, parameters: dict = {})
Cosine similarity loss function weighted by the norm of the force labels. Contributions are summed in Loss.
- apax.train.loss.force_angle_exponential_weight(label: array, prediction: array, name, divisor: float = 1.0, parameters: dict = {}) → array
Cosine similarity loss function exponentially scaled by the norm of the force labels. Contributions are summed in Loss.
- apax.train.loss.force_angle_loss(label: array, prediction: array, name, divisor: float = 1.0, parameters: dict = {}) → array
Cosine similarity loss function. Contributions are summed in Loss.
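The force-angle losses compare the direction of predicted and reference force vectors. A common way to turn cosine similarity into a non-negative per-atom loss is 1 - cos θ, which is 0 for parallel vectors and 2 for anti-parallel ones. Whether apax uses exactly this form, and how the two label-norm-weighted variants scale it, is not stated on this page. The stdlib sketch below shows the unweighted idea with hypothetical helper names.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float], eps: float = 1e-12) -> float:
    """Cosine of the angle between two vectors; eps guards against zero norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norms, eps)


def force_angle_terms(label_forces, pred_forces, divisor: float = 1.0) -> List[float]:
    """Per-atom (1 - cos theta) / divisor; the caller sums the contributions."""
    return [
        (1.0 - cosine_similarity(f_ref, f_pred)) / divisor
        for f_ref, f_pred in zip(label_forces, pred_forces)
    ]


label_forces = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
pred_forces = [[2.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
print(force_angle_terms(label_forces, pred_forces))  # [0.0, 1.0]
```

The first atom scores 0.0 even though the predicted magnitude is off by a factor of two, which is why an angle term is typically combined with a magnitude-sensitive loss.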
- apax.train.loss.nll_loss(label: Array, prediction: Array, name, divisor: float = 1.0, parameters: dict = {}) → Array
Computes the Gaussian negative log-likelihood (NLL) loss given means, targets, and standard deviations (uncertainty estimates).
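For a Gaussian with mean μ and standard deviation σ, the negative log-likelihood of an observation y is ½ log(2πσ²) + (y - μ)² / (2σ²). The scalar sketch below evaluates it directly. Some implementations drop the constant ½ log 2π term, and how nll_loss handles divisor and parameters is not documented here.

```python
import math


def gaussian_nll(y: float, mu: float, sigma: float) -> float:
    """Negative log-likelihood of y under N(mu, sigma**2), constant term included."""
    variance = sigma * sigma
    return 0.5 * math.log(2.0 * math.pi * variance) + (y - mu) ** 2 / (2.0 * variance)


# Same absolute error, different uncertainty estimates:
print(gaussian_nll(1.0, 0.0, 0.1) > gaussian_nll(1.0, 0.0, 1.0))  # True
```

Overconfident predictions (small σ with a large error) are penalized heavily, while the log term keeps the model from inflating σ without bound.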
- apax.train.loss.weighted_huber_loss(label: array, prediction: array, name, divisor: float = 1.0, parameters: dict = {}) → array
Huber loss function that allows weighting of individual contributions by the number of atoms in the system.
- apax.train.loss.weighted_squared_error(label: array, prediction: array, name, divisor: float = 1.0, parameters: dict = {}) → array
Squared error function that allows weighting of individual contributions by the number of atoms in the system.
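Both functions produce per-component terms that are divided by a per-structure divisor, such as a function of the number of atoms. The sketch below shows the two element-wise formulas. The Huber threshold `delta` and its default of 1.0 are assumptions, since this page does not say which entry of `parameters`, if any, controls it, and the helper names are hypothetical.

```python
from typing import List, Sequence


def squared_error_terms(label: Sequence[float], prediction: Sequence[float],
                        divisor: float = 1.0) -> List[float]:
    """(y - y_hat)**2 / divisor for every component."""
    return [(y - y_hat) ** 2 / divisor for y, y_hat in zip(label, prediction)]


def huber_terms(label: Sequence[float], prediction: Sequence[float],
                divisor: float = 1.0, delta: float = 1.0) -> List[float]:
    """Quadratic for residuals up to delta, linear beyond, then divided by divisor."""
    terms = []
    for y, y_hat in zip(label, prediction):
        r = abs(y - y_hat)
        loss = 0.5 * r * r if r <= delta else delta * (r - 0.5 * delta)
        terms.append(loss / divisor)
    return terms


label, prediction = [0.0, 0.0], [0.5, 3.0]
print(squared_error_terms(label, prediction))       # [0.25, 9.0]
print(huber_terms(label, prediction, divisor=2.0))  # [0.0625, 1.25]
```

The linear branch makes the Huber loss less sensitive to a few large residuals: for the residual of 3.0 above it contributes 2.5 before division, compared with 9.0 for the squared error.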
- class apax.train.metrics.RootAverage(total: array, count: array)
Modifies the compute method of metrics.Average to obtain the root of the average. Meant to be used with mse_fn.
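Keeping a running total and count, and taking the square root only when computing the final value, yields a root-mean-square error over all accumulated values rather than an average of per-batch RMSEs. The dataclass below mimics that pattern without clu. The class and method names are hypothetical.

```python
import math
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class RootAverageSketch:
    """Accumulates squared errors and reports the root of their mean."""

    total: float = 0.0
    count: int = 0

    def update(self, squared_errors: Sequence[float]) -> "RootAverageSketch":
        # Return a new instance instead of mutating this one.
        return RootAverageSketch(self.total + sum(squared_errors),
                                 self.count + len(squared_errors))

    def merge(self, other: "RootAverageSketch") -> "RootAverageSketch":
        return RootAverageSketch(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        return math.sqrt(self.total / self.count)


batch_1 = RootAverageSketch().update([1.0, 9.0])  # batch RMSE would be sqrt(5)
batch_2 = RootAverageSketch().update([4.0, 2.0])  # batch RMSE would be sqrt(3)
print(batch_1.merge(batch_2).compute())           # 2.0, i.e. sqrt(16 / 4)
```

Averaging the two per-batch RMSEs instead would give about 1.984, which is why the root is deferred until all batches are merged.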
- apax.train.metrics.cosine_sim(label: dict[array], prediction: dict[array], key: str) → array
Computes the cosine similarity of two arrays.
- apax.train.metrics.initialize_metrics(metrics_list) → Collection
Builds a clu metrics Collection by looping over all keys and reductions. The metrics are named according to the pattern key_reduction. See make_single_metric for details on the individual metrics.
- apax.train.metrics.mae_fn(label: dict[array], prediction: dict[array], key: str) → array
Computes the Mean Absolute Error of two arrays.
- apax.train.metrics.make_single_metric(key: str, reduction: str) → Average
Builds a single clu metric where the key picks out the quantity from the model predictions dict. Metric functions (like mae_fn) are curried with the key.
- apax.train.metrics.mse_fn(label: dict[array], prediction: dict[array], key: str) → array
Computes the Mean Squared Error of two arrays.
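initialize_metrics and make_single_metric combine a quantity key with a reduction name and curry the metric function with that key. The stdlib sketch below reproduces just the naming and currying with functools.partial. The input format of metrics_list (dicts with "name" and "reductions" entries) is an assumption, and the helpers return plain means instead of clu Average metrics.

```python
from functools import partial
from typing import Callable, Dict, List


def mean_abs_error(label: Dict, prediction: Dict, key: str) -> float:
    errors = [abs(y - y_hat) for y, y_hat in zip(label[key], prediction[key])]
    return sum(errors) / len(errors)


def mean_sq_error(label: Dict, prediction: Dict, key: str) -> float:
    errors = [(y - y_hat) ** 2 for y, y_hat in zip(label[key], prediction[key])]
    return sum(errors) / len(errors)


REDUCTIONS: Dict[str, Callable] = {"mae": mean_abs_error, "mse": mean_sq_error}


def build_metric_fns(metrics_list: List[Dict]) -> Dict[str, Callable]:
    """Map '<key>_<reduction>' names to metric functions with the key pre-bound."""
    metric_fns = {}
    for metric in metrics_list:
        key = metric["name"]
        for reduction in metric["reductions"]:
            metric_fns[f"{key}_{reduction}"] = partial(REDUCTIONS[reduction], key=key)
    return metric_fns


metric_fns = build_metric_fns([
    {"name": "energy", "reductions": ["mae", "mse"]},
    {"name": "forces", "reductions": ["mae"]},
])
label = {"energy": [1.0, 2.0], "forces": [0.0, 0.0]}
prediction = {"energy": [2.0, 4.0], "forces": [1.0, -1.0]}
print({name: fn(label, prediction) for name, fn in metric_fns.items()})
# {'energy_mae': 1.5, 'energy_mse': 2.5, 'forces_mae': 1.0}
```

Binding the key up front gives every metric the same (label, prediction) call signature, so a training loop can evaluate them uniformly.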
- apax.train.run.initialize_datasets(config: Config)
Initialize training and validation datasets based on the provided configuration.
- Parameters:
config (Config) – Configuration object containing all parameters.
- Returns:
train_ds (Dataset) – Training dataset.
val_ds (Dataset) – Validation dataset.
ds_stats (Dict[str, Tuple[float, float]]) – Dictionary containing scale and shift parameters for normalization.
- apax.train.run.initialize_loss_fn(loss_config_list: List[LossConfig]) → LossCollection
Initialize loss functions based on configuration.
- Parameters:
loss_config_list (List[LossConfig]) – List of loss configurations.
- Returns:
Collection of initialized loss functions.
- Return type:
LossCollection
- apax.train.run.run(user_config: str | PathLike | dict, log_level='error')
Starts training a model with the parameters provided by the config.
- Parameters:
user_config (str | os.PathLike | dict) – Training config. A full example can be found here:
- apax.train.run.setup_logging(log_file, log_level)
Sets up the logging configuration.
- Parameters:
log_file (str) – Path to the log file.
log_level (str) – Logging level. Options: {‘debug’, ‘info’, ‘warning’, ‘error’, ‘critical’}.
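A minimal way to implement a setup step like this with the standard library is logging.basicConfig with force=True (Python 3.8+), which replaces any handlers already attached to the root logger. The sketch below maps the documented level names to logging constants. The helper name, the format string, and the choice to overwrite the file are assumptions, not necessarily what setup_logging does.

```python
import logging

LEVELS = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}


def setup_logging_sketch(log_file: str, log_level: str) -> None:
    """Send root-logger records at or above log_level to log_file."""
    level = LEVELS[log_level.lower()]  # KeyError for unknown level names
    logging.basicConfig(
        filename=log_file,
        filemode="w",  # assumption: start a fresh log on every run
        level=level,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        force=True,    # drop handlers left over from earlier configuration
    )
```

With the default log_level="error" of run and eval_model, only error and critical records reach the file in this sketch.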
- apax.train.trainer.fit(state: TrainState, train_ds: InMemoryDataset, loss_fn, Metrics: Collection, callbacks: list, n_epochs: int, ckpt_dir, ckpt_interval: int = 1, val_ds: InMemoryDataset | None = None, patience: int | None = None, disable_pbar: bool = False, disable_batch_pbar: bool = True, is_ensemble=False, data_parallel=True, ema_handler: EMAParameters | None = None)
Trains the model using the provided training dataset.
- Parameters:
state – The initial state of the model.
train_ds (InMemoryDataset) – The training dataset.
loss_fn – The loss function to be minimized.
Metrics (metrics.Collection) – Collection of metrics to evaluate during training.
callbacks (list) – List of callback functions to be executed during training.
n_epochs (int) – Number of epochs for training.
ckpt_dir – Directory to save checkpoints.
ckpt_interval (int, default = 1) – Interval for saving checkpoints.
val_ds (InMemoryDataset, default = None) – Validation dataset.
patience (int, default = None) – Patience for early stopping.
disable_pbar (bool, default = False) – Whether to disable the progress bar for epochs.
disable_batch_pbar (bool, default = True) – Whether to disable progress bar for batches.
is_ensemble (bool, default = False) – Whether the model is an ensemble.
data_parallel (bool, default = True) – Whether to use data parallelism.
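The patience and ckpt_interval arguments suggest a familiar loop: track the best validation loss, count epochs without improvement, stop once that count reaches patience, and write a checkpoint every ckpt_interval epochs. The sketch below replays a list of validation losses through that logic so the behavior can be inspected without JAX. The exact comparison, the epoch indexing for checkpoints, and how fit resets its counter are assumptions, not taken from the apax source.

```python
from typing import List, Optional, Sequence, Tuple


def replay_early_stopping(
    val_losses: Sequence[float],
    patience: Optional[int] = None,
    ckpt_interval: int = 1,
) -> Tuple[int, List[int], int]:
    """Return (epochs_run, checkpoint_epochs, best_epoch) for a loss history."""
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    checkpoint_epochs: List[int] = []
    epochs_run = 0

    for epoch, loss in enumerate(val_losses):
        epochs_run = epoch + 1
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        # Periodic checkpoint (assumed to be keyed on the 0-based epoch index).
        if epoch % ckpt_interval == 0:
            checkpoint_epochs.append(epoch)

        # patience=None disables early stopping here (assumed to match fit's default).
        if patience is not None and epochs_without_improvement >= patience:
            break

    return epochs_run, checkpoint_epochs, best_epoch


history = [1.0, 0.8, 0.85, 0.9, 0.7]
print(replay_early_stopping(history, patience=2, ckpt_interval=2))  # (4, [0, 2], 1)
print(replay_early_stopping(history, ckpt_interval=2))              # (5, [0, 2, 4], 4)
```

With patience=2 the loop stops before the improvement at the last epoch is seen, which is the usual trade-off of early stopping.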