Source code for apax.train.run

import logging
import os
import sys
from typing import List, Union

import jax

from apax.config import Config, LossConfig, parse_config
from apax.data.initialization import load_data_files
from apax.data.input_pipeline import dataset_dict
from apax.data.statistics import compute_scale_shift_parameters
from apax.optimizer import get_opt
from apax.train.callbacks import initialize_callbacks
from apax.train.checkpoints import create_params, create_train_state
from apax.train.loss import Loss, LossCollection
from apax.train.metrics import initialize_metrics
from apax.train.parameters import EMAParameters
from apax.train.trainer import fit
from apax.transfer_learning import transfer_parameters
from apax.utils.random import seed_py_np_tf

log = logging.getLogger(__name__)


[docs] def setup_logging(log_file, log_level): """ Setup logging configuration. Parameters ---------- log_file : str Path to the log file. log_level : str Logging level. Options: {'debug', 'info', 'warning', 'error', 'critical'}. """ log_levels = { "debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING, "error": logging.ERROR, "critical": logging.CRITICAL, } while len(logging.root.handlers) > 0: logging.root.removeHandler(logging.root.handlers[-1]) logging.getLogger("absl").setLevel(logging.WARNING) logging.basicConfig( level=log_levels[log_level], format="%(levelname)s | %(asctime)s | %(message)s", datefmt="%H:%M:%S", handlers=[logging.FileHandler(log_file), logging.StreamHandler(sys.stderr)], )
[docs] def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection: """ Initialize loss functions based on configuration. Parameters ---------- loss_config_list : List[LossConfig] List of loss configurations. Returns ------- LossCollection Collection of initialized loss functions. """ log.info("Initializing Loss Function") loss_funcs = [] for loss in loss_config_list: loss_funcs.append(Loss(**loss.model_dump())) return LossCollection(loss_funcs)
def compute_property_shapes(config: Config): property_configs = [p.model_dump() for p in config.model.property_heads] additional_properties = [] if len(property_configs) == 0: return additional_properties loss_names = [loss.name for loss in config.loss] for pconf in property_configs: name = pconf["name"] if name not in loss_names: continue shape = [] if pconf["aggregation"] == "none": shape.append("natoms") feature_shapes = {"l0": [1], "l1": [3], "symmetric_traceless_l2": [3, 3]} shape.extend(feature_shapes[pconf["mode"]]) additional_properties.append((name, shape)) return additional_properties
[docs] def initialize_datasets(config: Config): """ Initialize training and validation datasets based on the provided configuration. Parameters ---------- config : Config Configuration object all parameters. Returns ------- train_ds : Dataset Training dataset. val_ds : Dataset Validation dataset. ds_stats : Dict[str, Tuple[float, float]] Dictionary containing scale and shift parameters for normalization. """ train_raw_ds, val_raw_ds = load_data_files(config.data) Dataset = dataset_dict[config.data.dataset.processing] dataset_kwargs = dict(config.data.dataset) processing = dataset_kwargs.pop("processing") if processing == "cached": dataset_kwargs["cache_path"] = config.data.model_version_path additional_properties = compute_property_shapes(config) train_ds = Dataset( train_raw_ds, config.model.basis.r_max, config.data.batch_size, config.n_epochs, pos_unit=config.data.pos_unit, energy_unit=config.data.energy_unit, pre_shuffle=True, additional_properties=additional_properties, **dataset_kwargs, ) val_ds = Dataset( val_raw_ds, config.model.basis.r_max, config.data.valid_batch_size, config.n_epochs, pos_unit=config.data.pos_unit, energy_unit=config.data.energy_unit, additional_properties=additional_properties, **dataset_kwargs, ) ds_stats = compute_scale_shift_parameters( train_ds.inputs, train_ds.labels, config.data.shift_method, config.data.scale_method, config.data.shift_options, config.data.scale_options, ) return train_ds, val_ds, ds_stats
[docs] def run(user_config: Union[str, os.PathLike, dict], log_level="error"): """ Starts the training of a model with parameters provided by a the config. Parameters ---------- user_config : str | os.PathLike | dict training config full example can be find :ref:`here <train_config>`: """ config = parse_config(user_config) seed_py_np_tf(config.seed) rng_key = jax.random.PRNGKey(config.seed) config.data.model_version_path.mkdir(parents=True, exist_ok=True) setup_logging(config.data.model_version_path / "train.log", log_level) config.dump_config(config.data.model_version_path) log.info(f"Running on {jax.devices()}") callbacks = initialize_callbacks(config, config.data.model_version_path) loss_fn = initialize_loss_fn(config.loss) Metrics = initialize_metrics(config.metrics) train_ds, val_ds, ds_stats = initialize_datasets(config) sample_input, init_box = train_ds.init_input() Builder = config.model.get_builder() builder = Builder(config.model.model_dump()) model = builder.build_energy_derivative_model( scale=ds_stats.elemental_scale, shift=ds_stats.elemental_shift, apply_mask=True, init_box=init_box, ) batched_model = jax.vmap(model.apply, in_axes=(None, 0, 0, 0, 0, 0)) if config.model.ensemble and config.model.ensemble.kind == "full": n_full_models = config.model.ensemble.n_members else: n_full_models = 1 params, rng_key = create_params(model, rng_key, sample_input, n_full_models) freeze_layers = [] do_transfer_learning = config.transfer_learning is not None if do_transfer_learning: freeze_layers = config.transfer_learning.freeze_layers # TODO rework optimizer initialization and lr keywords steps_per_epoch = train_ds.steps_per_epoch() tx = get_opt( params, config.n_epochs, steps_per_epoch, freeze_layers=freeze_layers, **config.optimizer.model_dump(), ) state = create_train_state(batched_model, params, tx) if do_transfer_learning: state = transfer_parameters(state, config.transfer_learning) if config.weight_average: ema_handler = EMAParameters( config.weight_average.ema_start, config.weight_average.alpha ) else: ema_handler = None fit( state, train_ds, loss_fn, Metrics, callbacks, config.n_epochs, ckpt_dir=config.data.model_version_path, ckpt_interval=config.ckpt_interval, val_ds=val_ds, patience=config.patience, patience_min_delta=config.patience_min_delta, disable_pbar=config.progress_bar.disable_epoch_pbar, disable_batch_pbar=config.progress_bar.disable_batch_pbar, is_ensemble=n_full_models > 1, data_parallel=config.data_parallel, ema_handler=ema_handler, ) log.info("Finished training")