import functools
import logging
import time
from functools import partial
from typing import Callable, Optional
import jax
import jax.numpy as jnp
import numpy as np
from clu import metrics
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
from tqdm import trange
from apax.data.input_pipeline import InMemoryDataset
from apax.train.checkpoints import CheckpointManager, load_state
log = logging.getLogger(__name__)
[docs]
def fit(
    state,
    train_ds: InMemoryDataset,
    loss_fn,
    Metrics: metrics.Collection,
    callbacks: list,
    n_epochs: int,
    ckpt_dir,
    ckpt_interval: int = 1,
    val_ds: Optional[InMemoryDataset] = None,
    sam_rho=0.0,
    patience: Optional[int] = None,
    disable_pbar: bool = False,
    disable_batch_pbar: bool = True,
    is_ensemble=False,
    data_parallel=True,
):
    """
    Trains the model using the provided training dataset.

    Parameters
    ----------
    state :
        The initial state of the model. Its ``apply_fn`` is used as the model
        and its ``params`` are used for validation. It is passed through
        ``load_state`` together with ``ckpt_dir / "latest"``, which also
        supplies the epoch to resume from.
    train_ds : InMemoryDataset
        The training dataset.
    loss_fn :
        The loss function to be minimized, called as
        ``loss_fn(inputs, labels, predictions)``.
    Metrics : metrics.Collection
        Collection of metrics to evaluate during training.
    callbacks : list
        Callback container. Despite the annotation, it must expose the hooks
        ``on_train_begin``, ``on_epoch_begin``, ``on_train_batch_begin``,
        ``on_train_batch_end``, ``on_epoch_end`` and ``on_train_end``.
    n_epochs : int
        Total number of epochs (including epochs restored from a checkpoint).
    ckpt_dir :
        Directory to save checkpoints. Must support ``/`` path joining; the
        ``latest`` and ``best`` subdirectories are used.
    ckpt_interval : int, default = 1
        Interval (in epochs) for saving the ``latest`` checkpoint.
    val_ds : InMemoryDataset, default = None
        Validation dataset. Best-model checkpointing and early stopping are
        driven by the validation loss and are skipped when this is None.
    sam_rho : float, default = 0.0
        Rho parameter for Sharpness-Aware Minimization.
    patience : int, default = None
        Number of epochs without validation-loss improvement before stopping.
        Has no effect when ``val_ds`` is None.
    disable_pbar : bool, default = False
        Whether to disable the progress bar for epochs.
    disable_batch_pbar : bool, default = True
        Whether to disable the progress bar for batches.
    is_ensemble : bool, default = False
        Whether the model is an ensemble.
    data_parallel : bool, default = True
        Whether to use data parallelism when more than one device is visible.

    Raises
    ------
    ValueError
        If the restored checkpoint epoch is already ``>= n_epochs``.
    """
    log.info("Beginning Training")
    callbacks.on_train_begin()

    latest_dir = ckpt_dir / "latest"
    best_dir = ckpt_dir / "best"
    ckpt_manager = CheckpointManager()

    train_step, val_step = make_step_fns(
        loss_fn, Metrics, model=state.apply_fn, sam_rho=sam_rho, is_ensemble=is_ensemble
    )
    if train_ds.n_jit_steps > 1:
        # Run several optimizer steps per compiled call by scanning train_step
        # over a stacked group of batches; the per-step losses come back stacked.
        train_step = jax.jit(functools.partial(jax.lax.scan, train_step))

    state, start_epoch = load_state(state, latest_dir)
    if start_epoch >= n_epochs:
        raise ValueError(
            f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})"
        )

    n_devices = len(jax.devices())
    if n_devices > 1 and data_parallel:
        sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))
        # Keep a full copy of the train state on every device.
        state = jax.device_put(state, sharding.replicate())
    else:
        sharding = None

    has_val = val_ds is not None
    train_steps_per_epoch = train_ds.steps_per_epoch()
    batch_train_ds = train_ds.shuffle_and_batch(sharding)
    if has_val:
        val_steps_per_epoch = val_ds.steps_per_epoch()
        batch_val_ds = val_ds.batch(sharding)

    best_loss = np.inf
    early_stopping_counter = 0
    epoch_pbar = trange(
        start_epoch, n_epochs, desc="Epochs", ncols=100, disable=disable_pbar, leave=True
    )
    try:
        for epoch in range(start_epoch, n_epochs):
            epoch_start_time = time.time()
            callbacks.on_epoch_begin(epoch=epoch + 1)

            state, train_loss, train_batch_metrics = _run_train_epoch(
                state,
                train_step,
                batch_train_ds,
                train_steps_per_epoch,
                Metrics,
                callbacks,
                disable_batch_pbar,
            )
            epoch_loss = {"train_loss": train_loss}
            epoch_metrics = {
                f"train_{key}": float(val)
                for key, val in train_batch_metrics.compute().items()
            }

            if has_val:
                val_loss, val_batch_metrics = _run_val_epoch(
                    state.params,
                    val_step,
                    batch_val_ds,
                    val_steps_per_epoch,
                    Metrics,
                    disable_batch_pbar,
                )
                epoch_loss["val_loss"] = val_loss
                epoch_metrics.update(
                    {
                        f"val_{key}": float(val)
                        for key, val in val_batch_metrics.compute().items()
                    }
                )

            epoch_metrics.update(epoch_loss)
            epoch_metrics["epoch_time"] = time.time() - epoch_start_time

            ckpt = {"model": state, "epoch": epoch}
            if epoch % ckpt_interval == 0:
                ckpt_manager.save_checkpoint(ckpt, epoch, latest_dir)

            # Best-model tracking and early stopping are driven by the
            # validation loss, so they only run when a validation set exists.
            if has_val:
                if epoch_metrics["val_loss"] < best_loss:
                    best_loss = epoch_metrics["val_loss"]
                    ckpt_manager.save_checkpoint(ckpt, epoch, best_dir)
                    early_stopping_counter = 0
                else:
                    early_stopping_counter += 1

            callbacks.on_epoch_end(epoch=epoch, logs=epoch_metrics)
            monitor_key = "val_loss" if has_val else "train_loss"
            epoch_pbar.set_postfix(**{monitor_key: epoch_metrics[monitor_key]})
            epoch_pbar.update()

            if patience is not None and early_stopping_counter >= patience:
                log.info(
                    "Early stopping patience exceeded. Stopping training after"
                    f" {epoch + 1} epochs."
                )
                break

        epoch_pbar.close()
        callbacks.on_train_end()
    finally:
        # Also runs when training raises. tqdm ignores repeated close() calls.
        epoch_pbar.close()
        train_ds.cleanup()
        if has_val:
            val_ds.cleanup()


def _run_train_epoch(
    state,
    train_step,
    batch_train_ds,
    n_steps,
    metrics_cls,
    callbacks,
    disable_batch_pbar,
):
    """Run ``n_steps`` training steps; return (state, mean loss, accumulated metrics)."""
    batch_metrics = metrics_cls.empty()
    loss_sum = 0.0
    batch_pbar = trange(
        0,
        n_steps,
        desc="Batches",
        ncols=100,
        mininterval=1.0,
        disable=disable_batch_pbar,
        leave=False,
    )
    for batch_idx in range(n_steps):
        callbacks.on_train_batch_begin(batch=batch_idx)
        batch = next(batch_train_ds)
        (state, batch_metrics), batch_loss = train_step((state, batch_metrics), batch)
        # jnp.mean collapses the stacked per-step losses of a scanned train_step.
        loss_sum += jnp.mean(batch_loss)
        callbacks.on_train_batch_end(batch=batch_idx)
        batch_pbar.update()
    batch_pbar.close()
    return state, float(loss_sum / n_steps), batch_metrics


def _run_val_epoch(
    params,
    val_step,
    batch_val_ds,
    n_steps,
    metrics_cls,
    disable_batch_pbar,
):
    """Evaluate ``n_steps`` validation batches; return (mean loss, accumulated metrics)."""
    batch_metrics = metrics_cls.empty()
    loss_sum = 0.0
    batch_pbar = trange(
        0,
        n_steps,
        desc="Batches",
        ncols=100,
        mininterval=1.0,
        disable=disable_batch_pbar,
        leave=False,
    )
    for _ in range(n_steps):
        batch = next(batch_val_ds)
        batch_loss, batch_metrics = val_step(params, batch, batch_metrics)
        loss_sum += batch_loss
        batch_pbar.update()
    batch_pbar.close()
    return float(loss_sum / n_steps), batch_metrics
[docs]
def global_norm(updates) -> jnp.ndarray:
    """
    Returns the l2 norm of every leaf of a pytree.

    Despite the name, the norms are *not* combined across leaves. The result
    has the same tree structure as ``updates``, with each array replaced by
    the scalar l2 norm of that array. Callers that pair the result leaf-wise
    with ``updates`` (e.g. via ``tree_map``) rely on this structure.

    Parameters
    ----------
    updates: A pytree of ndarrays representing the gradient.

    Returns
    -------
    A pytree matching ``updates`` whose leaves are scalar l2 norms.
    """
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
    return jax.tree_util.tree_map(lambda u: jnp.sqrt(jnp.sum(jnp.square(u))), updates)
def calc_loss(params, inputs, labels, loss_fn, model):
    """
    Evaluate the model on one batch and score it with ``loss_fn``.

    The model is called positionally with the batch entries ``positions``,
    ``numbers``, ``idx``, ``box`` and ``offsets`` (in that order).

    Returns
    -------
    tuple
        ``(loss, predictions)``. The predictions are returned as the auxiliary
        output for ``jax.value_and_grad(..., has_aux=True)``.
    """
    model_args = tuple(
        inputs[key] for key in ("positions", "numbers", "idx", "box", "offsets")
    )
    predictions = model(params, *model_args)
    return loss_fn(inputs, labels, predictions), predictions
def make_ensemble_update(update_fn: Callable) -> Callable:
    """
    Wrap a single-model update step so it runs over a stacked ensemble.

    Parameters
    ----------
    update_fn : Callable
        ``update_fn(state, inputs, labels) -> (loss, predictions, state)`` for
        a single model.

    Returns
    -------
    Callable
        ``ensemble_update_fn(state, inputs, labels)``, where ``state`` carries
        a leading ensemble axis. Every member is updated on the same batch.
        The returned loss and predictions are averaged over members; the
        updated state keeps its ensemble axis.
    """
    # Map over axis 0 of the state only; the batch is shared by all members.
    v_update_fn = jax.vmap(update_fn, in_axes=(0, None, None), out_axes=(0, 0, 0))

    def ensemble_update_fn(state, inputs, labels):
        loss, predictions, state = v_update_fn(state, inputs, labels)
        mean_predictions = jax.tree_util.tree_map(
            lambda x: jnp.mean(x, axis=0), predictions
        )
        # TODO: consider also reporting the spread (std) across members.
        return jnp.mean(loss), mean_predictions, state

    return ensemble_update_fn
def make_ensemble_eval(update_fn: Callable) -> Callable:
    """
    Wrap a single-model loss/eval function so it runs over a stacked ensemble.

    Parameters
    ----------
    update_fn : Callable
        ``update_fn(params, inputs, labels) -> (loss, predictions)`` for a
        single model.

    Returns
    -------
    Callable
        ``ensemble_eval_fn(params, inputs, labels)``, where ``params`` carries
        a leading ensemble axis. It returns the loss and predictions averaged
        over members.
    """
    # Map over axis 0 of the parameters only; the batch is shared by all members.
    v_eval_fn = jax.vmap(update_fn, in_axes=(0, None, None), out_axes=(0, 0))

    def ensemble_eval_fn(state, inputs, labels):
        loss, predictions = v_eval_fn(state, inputs, labels)
        mean_predictions = jax.tree_util.tree_map(
            lambda x: jnp.mean(x, axis=0), predictions
        )
        return jnp.mean(loss), mean_predictions

    return ensemble_eval_fn
def make_step_fns(loss_fn, Metrics, model, sam_rho, is_ensemble):
    """
    Build the jitted training and validation step functions.

    Parameters
    ----------
    loss_fn :
        Loss called as ``loss_fn(inputs, labels, predictions)``.
    Metrics :
        Metrics collection class providing ``single_from_model_output`` and
        whose instances provide ``merge``.
    model :
        Model apply function, called as ``model(params, R, Z, idx, box, offsets)``.
    sam_rho : float
        Sharpness-Aware Minimization radius. SAM is disabled when it is <= 1e-6.
    is_ensemble : bool
        If True, the state/params carry a leading ensemble axis that is vmapped
        over, and losses/predictions are averaged across members.

    Returns
    -------
    tuple
        ``(train_step, val_step)``:

        * ``train_step(carry, batch) -> (new_carry, loss)`` with
          ``carry = (state, batch_metrics)`` and ``batch = (inputs, labels)``.
          This signature is compatible with ``jax.lax.scan``.
        * ``val_step(params, batch, batch_metrics) -> (loss, batch_metrics)``.
    """
    loss_calculator = partial(calc_loss, loss_fn=loss_fn, model=model)
    grad_fn = jax.value_and_grad(loss_calculator, 0, has_aux=True)
    rho = sam_rho

    def update_step(state, inputs, labels):
        (loss, predictions), grads = grad_fn(state.params, inputs, labels)

        if rho > 1e-6:
            # SAM: move the parameters by rho along the normalized gradient and
            # take the gradient at that perturbed point. global_norm yields one
            # norm per parameter leaf, so each leaf is normalized separately.
            grad_norm = global_norm(grads)
            # A leaf with zero (or underflowed) norm would give 0/0 = NaN or
            # g/0 = inf and poison every parameter. Dividing by 1 instead gives
            # it a zero or negligible perturbation; non-zero norms are unchanged.
            eps = jax.tree_util.tree_map(
                lambda g, n: g * rho / jnp.where(n > 0.0, n, 1.0), grads, grad_norm
            )
            params_eps = jax.tree_util.tree_map(lambda p, e: p + e, state.params, eps)
            # Loss and gradients now come from the perturbed point; predictions
            # are still those of the unperturbed parameters.
            (loss, _), grads = grad_fn(params_eps, inputs, labels)

        state = state.apply_gradients(grads=grads)
        return loss, predictions, state

    if is_ensemble:
        update_fn = make_ensemble_update(update_step)
        eval_fn = make_ensemble_eval(loss_calculator)
    else:
        update_fn = update_step
        eval_fn = loss_calculator

    @jax.jit
    def train_step(carry, batch):
        state, batch_metrics = carry
        inputs, labels = batch
        loss, predictions, state = update_fn(state, inputs, labels)

        new_batch_metrics = Metrics.single_from_model_output(
            label=labels, prediction=predictions
        )
        batch_metrics = batch_metrics.merge(new_batch_metrics)

        new_carry = (state, batch_metrics)
        return new_carry, loss

    @jax.jit
    def val_step(params, batch, batch_metrics):
        inputs, labels = batch
        loss, predictions = eval_fn(params, inputs, labels)

        new_batch_metrics = Metrics.single_from_model_output(
            label=labels, prediction=predictions
        )
        batch_metrics = batch_metrics.merge(new_batch_metrics)
        return loss, batch_metrics

    return train_step, val_step