import collections
import csv
import logging
import typing as t
from pathlib import Path
import numpy as np
import tensorflow as tf
from keras.callbacks import Callback, CSVLogger, TensorBoard
from apax.config.optuna_config import get_pruner
from apax.config.train_config import Config
# MLflow logging is optional: when the integration module cannot be imported
# the name is bound to None instead of failing at import time.
try:
    from apax.train.mlflow import MLFlowLogger
except ImportError:
    MLFlowLogger = None

# Optuna is optional as well; code that needs it checks for None first.
try:
    import optuna
except ImportError:
    optuna = None

# Import only evaluated by static type checkers.
if t.TYPE_CHECKING:
    import optuna

# Module-level logger named after this module.
log = logging.getLogger(__name__)
class CallbackCollection:
    """Fan out training-loop events to a list of callbacks.

    Each ``on_*`` method calls the method of the same name on every callback,
    in list order, and forwards all arguments (including ``logs``) unchanged.

    Args:
        callbacks: Objects implementing the Keras-style callback hooks used
            below (``on_train_begin``, ``on_epoch_end``, ...).
    """

    def __init__(self, callbacks: list) -> None:
        self.callbacks = callbacks

    def on_train_begin(self, logs=None):
        """Notify every callback that training starts."""
        for cb in self.callbacks:
            cb.on_train_begin(logs)

    def on_epoch_begin(self, epoch, logs=None):
        """Notify every callback that ``epoch`` starts."""
        for cb in self.callbacks:
            # ``logs`` is forwarded so callbacks see what the caller passed.
            cb.on_epoch_begin(epoch, logs)

    def on_train_batch_begin(self, batch, logs=None):
        """Notify every callback that training batch ``batch`` starts."""
        for cb in self.callbacks:
            cb.on_train_batch_begin(batch, logs)

    def on_train_batch_end(self, batch, logs=None):
        """Notify every callback that training batch ``batch`` finished."""
        for cb in self.callbacks:
            cb.on_train_batch_end(batch, logs)

    def on_epoch_end(self, epoch, logs=None):
        """Notify every callback that ``epoch`` finished."""
        for cb in self.callbacks:
            cb.on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        """Notify every callback that training finished."""
        for cb in self.callbacks:
            cb.on_train_end(logs)

    def on_test_batch_end(self, batch, logs=None):
        """Notify every callback that test batch ``batch`` finished."""
        for cb in self.callbacks:
            cb.on_test_batch_end(batch, logs)
def format_str(k):
    """Return ``k`` rendered in fixed-point notation with five decimals."""
    return format(k, ".5f")
class CSVLoggerApax(CSVLogger):
    """Keras ``CSVLogger`` variant that can also log per test batch.

    ``on_epoch_end`` writes one row per epoch and ``on_test_batch_end`` writes
    one row per test batch. Both share the same value formatting and the same
    lazily created ``csv.DictWriter``; the column set is fixed by whichever
    hook runs first.

    Args:
        filename: Path of the CSV file, passed to ``CSVLogger``.
        separator: Field delimiter, passed to ``CSVLogger``.
        append: Whether to append to an existing file, passed to ``CSVLogger``.
    """

    def __init__(self, filename, separator=",", append=False):
        super().__init__(filename, separator=separator, append=append)

    @staticmethod
    def _format_value(value):
        """Render a single log value as the text of one CSV cell.

        Strings pass through unchanged. Iterables other than zero-dimensional
        arrays become a quoted, bracketed, comma-separated list. Everything
        else is formatted with ``format_str``.
        """
        if isinstance(value, str):
            return value
        is_zero_dim_ndarray = isinstance(value, np.ndarray) and value.ndim == 0
        if isinstance(value, collections.abc.Iterable) and not is_zero_dim_ndarray:
            return f'"[{", ".join(map(format_str, value))}]"'
        return format_str(value)

    def _write_row(self, index_name, index_value, logs):
        """Write one row keyed by ``index_name`` and flush the file.

        The writer is created on first use with ``index_name`` as the leading
        column followed by ``self.keys``. Keys missing from ``logs`` are
        written as ``"NA"``.
        """
        if not self.writer:

            class CustomDialect(csv.excel):
                delimiter = self.sep

            self.writer = csv.DictWriter(
                self.csv_file,
                fieldnames=[index_name] + self.keys,
                dialect=CustomDialect,
            )
            if self.append_header:
                self.writer.writeheader()

        row = {index_name: index_value}
        row.update((key, self._format_value(logs.get(key, "NA"))) for key in self.keys)
        self.writer.writerow(row)
        self.csv_file.flush()

    def on_epoch_end(self, epoch, logs=None):
        """Append a row for ``epoch`` containing the values in ``logs``."""
        logs = logs or {}

        if self.keys is None:
            self.keys = sorted(logs.keys())
            # When validation_freq > 1, `val_` keys are not in the first epoch
            # logs. Add them so that they are part of the writer's fieldnames.
            if not any(key.startswith("val_") for key in self.keys):
                self.keys.extend(["val_" + k for k in self.keys])

        self._write_row("epoch", epoch, logs)

    def on_test_batch_end(self, batch, logs=None):
        """Append a row for test batch ``batch`` containing ``logs``."""
        logs = logs or {}

        if self.keys is None:
            self.keys = sorted(logs.keys())

        self._write_row("batch", batch, logs)
class KerasPruningCallback(Callback):
    """Adapted from https://optuna-integration.readthedocs.io/en/latest/_modules/optuna_integration/keras/keras.html#KerasPruningCallback

    Keras callback to prune unpromising trials.

    See `the example <https://github.com/optuna/optuna-examples/blob/main/
    keras/keras_integration.py>`__
    if you want to add a pruning callback which observes validation accuracy.

    Args:
        study_name:
            Name of the study to load from the journal storage.
        trial_id:
            id of current trial
        study_log_file:
            path to log file
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` and
            ``val_accuracy``. Please refer to `keras.Callback reference
            <https://keras.io/callbacks/#callback>`_ for further details.
        interval:
            Check if trial should be pruned every n-th epoch. By default ``interval=1`` and
            pruning is performed after every epoch. Increase ``interval`` to run several
            epochs faster before applying pruning.

    Raises:
        ImportError: If optuna could not be imported.
    """  # noqa: E501

    def __init__(
        self,
        study_name: str,
        trial_id: int,
        study_log_file: str | Path,
        monitor: str = "val_loss",
        interval: int = 1,
    ) -> None:
        super().__init__()
        if optuna is None:
            # Instances have no ``__name__``; the class name lives on the type.
            raise ImportError(
                f"{type(self).__name__} requires optuna, but could not import it"
            )

        storage = optuna.storages.JournalStorage(
            optuna.storages.journal.JournalFileBackend(study_log_file)
        )
        study = optuna.load_study(
            study_name=study_name, storage=storage, sampler=None, pruner=None
        )

        # The pruner is rebuilt from the description stored in the study's
        # user attributes, when one is present.
        if "pruner" in study.user_attrs:
            pruner_class = get_pruner(study.user_attrs["pruner"]["name"])
            study.pruner = pruner_class(**study.user_attrs["pruner"]["kwargs"])
        else:
            study.pruner = None

        self._monitor = monitor
        self._interval = interval
        self._trial = optuna.trial.Trial(study, trial_id)

    def on_epoch_end(self, epoch: int, logs: dict[str, float] | None = None) -> None:
        """Report the monitored metric and prune the trial if requested.

        Epochs that are not a multiple of ``interval`` are skipped. A missing
        metric is logged as a warning and otherwise ignored.

        Raises:
            optuna.TrialPruned: If the trial should be pruned at this epoch.
        """
        if epoch % self._interval != 0:
            return

        logs = logs or {}
        current_score = logs.get(self._monitor)
        if current_score is None:
            message = (
                f"The metric '{self._monitor}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name."
            )
            log.warning(message)
            return

        self._trial.report(float(current_score), step=epoch)
        if self._trial.should_prune():
            message = f"Trial was pruned at epoch {epoch}."
            log.info(message)
            raise optuna.TrialPruned(message)
def initialize_callbacks(config: Config, model_version_path: Path):
    """Create the callbacks listed in ``config.callbacks``.

    Args:
        config: Training configuration. ``config.callbacks`` lists the
            callbacks to build; ``config.data.experiment`` is used as the
            MLflow run name.
        model_version_path: Directory that receives ``log.csv`` and the
            TensorBoard logs.

    Returns:
        A ``CallbackCollection`` wrapping the callbacks in configuration order.

    Raises:
        ImportError: If an ``mlflow`` callback is requested but
            ``MLFlowLogger`` could not be imported.
        KeyError: If a callback name is neither ``mlflow``, ``pruning`` nor a
            key of the file-based callback table.
    """
    callback_configs = config.callbacks
    log.info("Initializing Callbacks")

    # Compiled placeholder model handed to the file-based callbacks through
    # ``set_model``.
    dummy_model = tf.keras.Model()
    dummy_model.compile(loss="mse", optimizer="adam")

    callback_dict = {
        "csv": {
            "class": CSVLoggerApax,
            "log_path": model_version_path / "log.csv",
            "path_arg_name": "filename",
            "kwargs": {"append": True},
            "model": dummy_model,
        },
        "tensorboard": {
            "class": TensorBoard,
            "log_path": model_version_path,
            "path_arg_name": "log_dir",
            # ``write_graph`` must live in kwargs to actually reach TensorBoard.
            "kwargs": {"write_graph": False},
            "model": dummy_model,
        },
    }

    callbacks = []
    for callback_config in callback_configs:
        if callback_config.name == "mlflow":
            if MLFlowLogger is None:
                raise ImportError(
                    "The 'mlflow' callback requires MLFlowLogger, "
                    "but it could not be imported"
                )
            callback = MLFlowLogger(
                experiment=callback_config.experiment, run_name=config.data.experiment
            )
        elif callback_config.name == "pruning":
            callback = KerasPruningCallback(
                study_name=callback_config.study_name,
                trial_id=callback_config.trial_id,
                study_log_file=callback_config.study_log_file,
                monitor=callback_config.monitor,
                interval=callback_config.interval,
            )
        else:
            callback_info = callback_dict[callback_config.name]
            path = {callback_info["path_arg_name"]: callback_info["log_path"]}
            callback = callback_info["class"](**path, **callback_info["kwargs"])
            callback.set_model(callback_info["model"])
        callbacks.append(callback)

    return CallbackCollection(callbacks)