Source code for apax.train.callbacks

import collections
import csv
import logging
from pathlib import Path

import numpy as np
import tensorflow as tf
from keras.callbacks import CSVLogger, TensorBoard

from apax.config.common import flatten
from apax.config.train_config import Config

log = logging.getLogger(__name__)


[docs] class CSVLoggerApax(CSVLogger): def __init__(self, filename, separator=",", append=False): super().__init__(filename, separator=",", append=False)
[docs] def on_test_batch_end(self, batch, logs=None): logs = logs or {} def handle_value(k): is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0 if isinstance(k, str): return k elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray: return f"\"[{', '.join(map(str, k))}]\"" else: return k if self.keys is None: self.keys = sorted(logs.keys()) if not self.writer: class CustomDialect(csv.excel): delimiter = self.sep fieldnames = ["batch"] + self.keys self.writer = csv.DictWriter( self.csv_file, fieldnames=fieldnames, dialect=CustomDialect ) if self.append_header: self.writer.writeheader() row_dict = collections.OrderedDict({"batch": batch}) row_dict.update((key, handle_value(logs.get(key, "NA"))) for key in self.keys) self.writer.writerow(row_dict) self.csv_file.flush()
def initialize_callbacks(config: Config, model_version_path: Path): callback_configs = config.callbacks log.info("Initializing Callbacks") dummy_model = tf.keras.Model() callback_dict = { "csv": { "class": CSVLoggerApax, "log_path": model_version_path / "log.csv", "path_arg_name": "filename", "kwargs": {"append": True}, "model": dummy_model, }, "tensorboard": { "class": TensorBoard, "log_path": model_version_path, "path_arg_name": "log_dir", "kwargs": {}, "model": dummy_model, }, } names = [conf.name for conf in callback_configs] if "csv" in names and "tensorboard" in names: msg = ( "Using both csv and tensorboard callbacks is not supported at the moment." " Rerun training with only one of the two." ) raise ValueError(msg) callbacks = [] for callback_config in callback_configs: if callback_config.name == "mlflow": try: import mlflow from mlflow.tensorflow import MLflowCallback except ImportError: log.warning("Make sure MLFlow is installed correctly") mlflow.login() mlflow.tensorflow.autolog() experiment = callback_config.experiment mlflow.set_experiment(experiment) params = config.model_dump() params = flatten(params) mlflow.log_params(params) callback = MLflowCallback() callback.set_model(dummy_model) else: callback_info = callback_dict[callback_config.name] path_arg_name = callback_info["path_arg_name"] path = {path_arg_name: callback_info["log_path"]} kwargs = callback_info["kwargs"] callback = callback_info["class"](**path, **kwargs) callback.set_model(callback_info["model"]) callbacks.append(callback) return tf.keras.callbacks.CallbackList([callback])