# Source code for apax.train.checkpoints

import logging
from pathlib import Path
from typing import List, Tuple, Union

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.training import checkpoints, train_state
from flax.traverse_util import flatten_dict, unflatten_dict

from apax.config.common import parse_config
from apax.config.train_config import Config

log = logging.getLogger(__name__)


def check_for_ensemble(params: FrozenDict) -> int:
    """Check whether a set of parameters belongs to an ensemble model.

    Ensemble parameters are stacked along a leading "parameter batch" axis,
    so this is the case if all parameter leaves share the same first
    dimension.

    Note: this is a heuristic. A single model whose leaves all happen to have
    the same leading dimension is indistinguishable from an ensemble here.

    Args:
        params: (Nested) parameter dictionary, frozen or plain.

    Returns:
        The shared leading dimension if the parameters look stacked,
        otherwise 1.
    """
    leaf_shapes = [jnp.shape(leaf) for leaf in jax.tree_util.tree_leaves(params)]

    # A 0-d leaf has no leading axis, so the parameters cannot be stacked.
    if any(len(shape) == 0 for shape in leaf_shapes):
        return 1

    leading_dims = {shape[0] for shape in leaf_shapes}
    if len(leading_dims) == 1:
        return leading_dims.pop()
    return 1
def create_train_state(model, params: FrozenDict, tx):
    """Create a flax ``TrainState`` for a single model or an ensemble.

    If ``check_for_ensemble`` reports more than one model, the state creation
    is vmapped over the leading parameter axis so that each ensemble member
    gets its own optimizer state.

    Args:
        model: Callable stored as the state's ``apply_fn``.
        params: Model parameters (stacked along axis 0 for ensembles).
        tx: Optimizer transformation passed to ``TrainState.create``.

    Returns:
        The (possibly vmapped/batched) ``TrainState``.
    """
    n_models = check_for_ensemble(params)

    def create_single_train_state(params):
        state = train_state.TrainState.create(
            apply_fn=model,
            params=params,
            tx=tx,
        )
        return state

    if n_models > 1:
        train_state_fn = jax.vmap(create_single_train_state, axis_name="ensemble")
    else:
        train_state_fn = create_single_train_state

    return train_state_fn(params)


def create_params(model, rng_key, sample_input: tuple, n_models: int):
    """Initialize parameters for one model or a stacked ensemble of models.

    Args:
        model: Object with an ``init(rng, *inputs)`` method.
        rng_key: JAX PRNG key; it is split and a fresh key is returned.
        sample_input: Positional inputs forwarded to ``model.init``.
        n_models: Number of models to initialize; must be >= 1.

    Returns:
        Tuple ``(params, rng_key)``: frozen parameters (stacked along axis 0
        when ``n_models > 1``) and the updated PRNG key.

    Raises:
        ValueError: If ``n_models`` is smaller than 1.
    """
    # Validate before touching the RNG: jax.random.split fails with an
    # unrelated error for non-positive ``num``.
    if n_models < 1:
        raise ValueError(f"n_models should be a positive integer, found {n_models}")

    keys = jax.random.split(rng_key, num=n_models + 1)
    rng_key, model_rng = keys[0], keys[1:]

    log.info(f"initializing {n_models} models")

    if n_models == 1:
        params = model.init(model_rng[0], *sample_input)
    else:
        num_args = len(sample_input)
        # vmap only over parameters, not over any data from the input
        in_axes = (0, *[None] * num_args)
        params = jax.vmap(model.init, in_axes=in_axes)(model_rng, *sample_input)

    params = freeze(params)
    return params, rng_key


def load_state(state, ckpt_dir):
    """Restore a training state and epoch counter from ``ckpt_dir`` if present.

    Args:
        state: Freshly created state used as the restore target.
        ckpt_dir: Directory that checkpoints are written to.

    Returns:
        Tuple ``(state, start_epoch)``: the restored state and the epoch after
        the saved one, or the input state and 0 when nothing can be restored.
    """
    start_epoch = 0
    target = {"model": state, "epoch": 0}

    # The directory may exist without containing any checkpoint. In that case
    # flax's restore_checkpoint silently returns ``target`` unchanged, which
    # would resume at epoch 1 and log a restore that never happened.
    checkpoints_exist = (
        Path(ckpt_dir).is_dir()
        and checkpoints.latest_checkpoint(ckpt_dir) is not None
    )
    if checkpoints_exist:
        log.info("Loading checkpoint")
        raw_restored = checkpoints.restore_checkpoint(ckpt_dir, target=target, step=None)
        state = raw_restored["model"]
        start_epoch = raw_restored["epoch"] + 1
        log.info("Successfully restored checkpoint from epoch %d", raw_restored["epoch"])

    return state, start_epoch


class CheckpointManager:
    """Writes training checkpoints asynchronously via flax's ``AsyncManager``."""

    def __init__(self) -> None:
        self.async_manager = checkpoints.AsyncManager()

    def save_checkpoint(self, ckpt, epoch: int, path: Path) -> None:
        """Save ``ckpt`` as step ``epoch`` under ``path``.

        Only the two most recent checkpoints are kept. ``path`` may be a
        ``Path`` or a string.
        """
        checkpoints.save_checkpoint(
            ckpt_dir=Path(path).resolve(),
            target=ckpt,
            step=epoch,
            overwrite=True,
            keep=2,
            async_manager=self.async_manager,
        )
def stack_parameters(param_list: List[FrozenDict]) -> FrozenDict:
    """Combine a list of parameter sets into a stacked version.

    Used for model ensembles: every leaf of the result gains a new leading
    axis of length ``len(param_list)``.

    Args:
        param_list: Parameter sets with matching structure. The set of keys
            is taken from the first entry.

    Returns:
        A frozen parameter dict with each leaf stacked along axis 0.

    Raises:
        ValueError: If ``param_list`` is empty.
        KeyError: If a later parameter set lacks a key of the first one.
    """
    if not param_list:
        raise ValueError("stack_parameters requires at least one parameter set")

    flat_param_list = [flatten_dict(unfreeze(params)) for params in param_list]

    # Build a fresh dict rather than overwriting one of the inputs in place,
    # so no unstacked leaves from any single input can leak into the result.
    stacked_flat_params = {
        key: jnp.stack([flat_params[key] for flat_params in flat_param_list])
        for key in flat_param_list[0]
    }

    return freeze(unflatten_dict(stacked_flat_params))
def load_params(model_version_path: Path, best=True) -> FrozenDict:
    """Load model parameters from a flax checkpoint directory.

    Args:
        model_version_path: Directory of a trained model version.
        best: If True, read from the ``best`` subdirectory.

    Returns:
        The ``["model"]["params"]`` entry of the restored checkpoint, with
        every leaf converted to a jax array.

    Raises:
        FileNotFoundError: If no checkpoint can be found at the location.
    """
    model_version_path = Path(model_version_path)
    if best:
        model_version_path = model_version_path / "best"
    log.info(f"loading checkpoint from {model_version_path}")

    try:
        # keep try except block for zntrack load from rev
        raw_restored = checkpoints.restore_checkpoint(
            model_version_path, target=None, step=None
        )
    except FileNotFoundError as err:
        # Previously this only printed and then crashed with UnboundLocalError
        # because ``raw_restored`` was never assigned.
        raise FileNotFoundError(
            f"No checkpoint found at {model_version_path}"
        ) from err

    if raw_restored is None:
        raise FileNotFoundError(f"No checkpoint found at {model_version_path}")

    params = jax.tree_util.tree_map(jnp.asarray, raw_restored["model"]["params"])
    return params
def restore_single_parameters(model_dir: Path) -> Tuple[Config, FrozenDict]:
    """Read one trained model's config and load its parameters.

    ``data.directory`` of the parsed config is set to the resolved model
    directory when ``data.experiment`` is empty, and to the directory's
    resolved parent otherwise. The parameters are then loaded from the
    config's ``data.model_version_path``.

    Args:
        model_dir: Directory containing ``config.yaml``.

    Returns:
        Tuple ``(config, params)``.
    """
    model_dir = Path(model_dir)
    model_config = parse_config(model_dir / "config.yaml")

    has_no_experiment = model_config.data.experiment == ""
    base_dir = model_dir if has_no_experiment else model_dir.parent
    model_config.data.directory = base_dir.resolve().as_posix()

    return model_config, load_params(model_config.data.model_version_path)
def restore_parameters(
    model_dir: Union[Path, List[Path]],
) -> Tuple[Config, FrozenDict]:
    """Restore one or more model configs and parameters.

    For several models, the parameters are stacked for ensembling and the
    config of the last model in the sequence is returned.

    Args:
        model_dir: A single model directory (``Path`` or ``str``) or a
            non-empty list/tuple of model directories.

    Returns:
        Tuple ``(config, params)``.

    Raises:
        ValueError: If an empty list/tuple of paths is given.
        NotImplementedError: If ``model_dir`` has an unsupported type.
    """
    if isinstance(model_dir, (Path, str)):
        return restore_single_parameters(model_dir)

    if isinstance(model_dir, (list, tuple)):
        if not model_dir:
            raise ValueError("Please provide at least one path to a trained model")
        param_list = []
        for path in model_dir:
            config, params = restore_single_parameters(path)
            param_list.append(params)
        return config, stack_parameters(param_list)

    raise NotImplementedError(
        "Please provide either a path or list of paths to trained models"
    )
def canonicalize_energy_model_parameters(params):
    """Make EnergyDerivativeModel parameters loadable into an EnergyModel.

    If ``params["params"]`` contains an "energy_model" entry, a new frozen
    dict ``{"params": <that entry>}`` is returned. Any other top-level keys
    of ``params`` are not carried over. Otherwise ``params`` is returned
    unchanged.
    """
    inner = unfreeze(params)["params"]
    if "energy_model" not in inner:
        return params
    return freeze({"params": inner["energy_model"]})
def canonicalize_energy_grad_model_parameters(params):
    """Make EnergyModel parameters loadable into an EnergyDerivativeModel.

    If ``params["params"]`` has no "energy_model" entry, its contents are
    wrapped as ``{"params": {"energy_model": ...}}`` and returned frozen.
    Any other top-level keys of ``params`` are not carried over. Otherwise
    ``params`` is returned unchanged.
    """
    inner = unfreeze(params)["params"]
    if "energy_model" in inner:
        return params
    return freeze({"params": {"energy_model": inner}})