import logging
import os
import sys
from typing import List, Union
import jax
from apax.config import Config, LossConfig, parse_config
from apax.data.initialization import load_data_files
from apax.data.input_pipeline import dataset_dict
from apax.data.statistics import compute_scale_shift_parameters
from apax.model import ModelBuilder
from apax.optimizer import get_opt
from apax.train.callbacks import initialize_callbacks
from apax.train.checkpoints import create_params, create_train_state
from apax.train.loss import Loss, LossCollection
from apax.train.metrics import initialize_metrics
from apax.train.trainer import fit
from apax.transfer_learning import transfer_parameters
from apax.utils.random import seed_py_np_tf
log = logging.getLogger(__name__)
[docs]
def setup_logging(log_file, log_level):
"""
Setup logging configuration.
Parameters
----------
log_file : str
Path to the log file.
log_level : str
Logging level. Options: {'debug', 'info', 'warning', 'error', 'critical'}.
"""
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}
while len(logging.root.handlers) > 0:
logging.root.removeHandler(logging.root.handlers[-1])
logging.getLogger("absl").setLevel(logging.WARNING)
logging.basicConfig(
level=log_levels[log_level],
format="%(levelname)s | %(asctime)s | %(message)s",
datefmt="%H:%M:%S",
handlers=[logging.FileHandler(log_file), logging.StreamHandler(sys.stderr)],
)
[docs]
def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection:
"""
Initialize loss functions based on configuration.
Parameters
----------
loss_config_list : List[LossConfig]
List of loss configurations.
Returns
-------
LossCollection
Collection of initialized loss functions.
"""
log.info("Initializing Loss Function")
loss_funcs = []
for loss in loss_config_list:
loss_funcs.append(Loss(**loss.model_dump()))
return LossCollection(loss_funcs)
[docs]
def initialize_datasets(config: Config):
"""
Initialize training and validation datasets based on the provided configuration.
Parameters
----------
config : Config
Configuration object all parameters.
Returns
-------
train_ds : Dataset
Training dataset.
val_ds : Dataset
Validation dataset.
ds_stats : Dict[str, Tuple[float, float]]
Dictionary containing scale and shift parameters for normalization.
"""
train_raw_ds, val_raw_ds = load_data_files(config.data)
Dataset = dataset_dict[config.data.ds_type]
train_ds = Dataset(
train_raw_ds,
config.model.r_max,
config.data.batch_size,
config.n_epochs,
config.data.shuffle_buffer_size,
config.n_jitted_steps,
pos_unit=config.data.pos_unit,
energy_unit=config.data.energy_unit,
pre_shuffle=True,
cache_path=config.data.model_version_path,
)
val_ds = Dataset(
val_raw_ds,
config.model.r_max,
config.data.valid_batch_size,
config.n_epochs,
pos_unit=config.data.pos_unit,
energy_unit=config.data.energy_unit,
cache_path=config.data.model_version_path,
)
ds_stats = compute_scale_shift_parameters(
train_ds.inputs,
train_ds.labels,
config.data.shift_method,
config.data.scale_method,
config.data.shift_options,
config.data.scale_options,
)
return train_ds, val_ds, ds_stats
[docs]
def run(user_config: Union[str, os.PathLike, dict], log_level="error"):
"""
Starts the training of a model with parameters provided by a the config.
Parameters
----------
user_config : str | os.PathLike | dict
training config full exmaple can be finde :ref:`here <train_config>`:
"""
config = parse_config(user_config)
seed_py_np_tf(config.seed)
rng_key = jax.random.PRNGKey(config.seed)
config.data.model_version_path.mkdir(parents=True, exist_ok=True)
setup_logging(config.data.model_version_path / "train.log", log_level)
config.dump_config(config.data.model_version_path)
log.info(f"Running on {jax.devices()}")
callbacks = initialize_callbacks(config, config.data.model_version_path)
loss_fn = initialize_loss_fn(config.loss)
Metrics = initialize_metrics(config.metrics)
train_ds, val_ds, ds_stats = initialize_datasets(config)
log.info("Initializing Model")
sample_input, init_box = train_ds.init_input()
builder = ModelBuilder(config.model.get_dict())
model = builder.build_energy_derivative_model(
scale=ds_stats.elemental_scale,
shift=ds_stats.elemental_shift,
apply_mask=True,
init_box=init_box,
)
batched_model = jax.vmap(model.apply, in_axes=(None, 0, 0, 0, 0, 0))
params, rng_key = create_params(model, rng_key, sample_input, config.n_models)
# TODO rework optimizer initialization and lr keywords
steps_per_epoch = train_ds.steps_per_epoch()
n_epochs = config.n_epochs
transition_steps = steps_per_epoch * n_epochs - config.optimizer.transition_begin
tx = get_opt(
params,
transition_steps=transition_steps,
**config.optimizer.model_dump(),
)
state = create_train_state(batched_model, params, tx)
base_checkpoint = config.checkpoints.base_model_checkpoint
do_transfer_learning = base_checkpoint is not None
if do_transfer_learning:
state = transfer_parameters(state, config.checkpoints)
fit(
state,
train_ds,
loss_fn,
Metrics,
callbacks,
n_epochs,
ckpt_dir=config.data.model_version_path,
ckpt_interval=config.checkpoints.ckpt_interval,
val_ds=val_ds,
sam_rho=config.optimizer.sam_rho,
patience=config.patience,
disable_pbar=config.progress_bar.disable_epoch_pbar,
disable_batch_pbar=config.progress_bar.disable_batch_pbar,
is_ensemble=config.n_models > 1,
data_parallel=config.data_parallel,
)
log.info("Finished training")