Source code for apax.transfer_learning.parameter_transfer

import logging
from typing import Union

from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.training.train_state import TrainState
from flax.traverse_util import flatten_dict, unflatten_dict

from apax.config.train_config import TransferLearningConfig
from apax.train.checkpoints import load_params

log = logging.getLogger(__name__)


def black_list_param_transfer(
    source_params: Union[FrozenDict, dict],
    target_params: Union[FrozenDict, dict],
    param_black_list: list[str],
) -> FrozenDict:
    """Overwrite target parameters with source parameters, except black-listed ones.

    Both parameter trees are flattened so that every leaf is addressed by a
    key path. A source leaf is copied into the target under the same path
    unless the second-to-last element of its path appears in
    ``param_black_list``. Source paths that do not exist in the target are
    added to it; target leaves without a source counterpart are kept as they
    are.

    Args:
        source_params (Union[FrozenDict, dict]): parameters to copy from.
        target_params (Union[FrozenDict, dict]): parameters to copy into.
        param_black_list (list[str]): names compared against the
            second-to-last path element of each source leaf; matching
            leaves are not transferred.

    Returns:
        FrozenDict: the frozen, re-nested target tree after the transfer.
    """
    incoming = flatten_dict(unfreeze(source_params))
    merged = flatten_dict(unfreeze(target_params))

    for path, value in incoming.items():
        # The black list is matched against the key directly enclosing the
        # leaf, not against the leaf's own name.
        parent_key = path[-2]
        if parent_key in param_black_list:
            continue
        merged[path] = value
        log.info("Transferring parameter: %s", parent_key)

    return freeze(unflatten_dict(merged))
def transfer_parameters(
    state: TrainState, ckpt_config: TransferLearningConfig
) -> TrainState:
    """Load checkpoint parameters and merge them into a train state.

    Parameters are loaded from ``ckpt_config.base_model_checkpoint`` and
    merged into ``state.params`` with :func:`black_list_param_transfer`,
    using ``ckpt_config.reset_layers`` as the black list.

    Args:
        state (TrainState): train state whose parameters are to be updated.
        ckpt_config (TransferLearningConfig): transfer learning
            configuration providing the checkpoint location and the layers
            to exclude from the transfer.

    Returns:
        TrainState: the result of ``state.replace`` with the merged
        parameters as its ``params`` attribute.
    """
    checkpoint_path = ckpt_config.base_model_checkpoint
    loaded_params = load_params(checkpoint_path)
    log.info("Transferring parameters from %s", checkpoint_path)

    merged_params = black_list_param_transfer(
        loaded_params, state.params, ckpt_config.reset_layers
    )
    return state.replace(params=merged_params)