Source code for apax.optimizer.get_optimizer

import logging
from typing import Any, Callable, Optional

import optax
from flax import traverse_util
from flax.core.frozen_dict import freeze

log = logging.getLogger(__name__)


def map_nested_fn(fn: Callable[[str, Any], dict]) -> Callable[[dict], dict]:
    """
    Build a function that applies `fn` to every leaf of a nested dict.

    The returned callable walks the dict it is given: a value that is itself
    a `dict` is descended into, any other value is replaced by
    `fn(key, value)`. Keys and nesting are preserved and a new dict is
    returned at every level.
    See https://optax.readthedocs.io/en/latest/api.html?highlight=multitransform#multi-transform
    """

    def apply_to_leaves(tree):
        mapped = {}
        for key, value in tree.items():
            if isinstance(value, dict):
                # Sub-dict: recurse so that only leaves are passed to `fn`.
                mapped[key] = apply_to_leaves(value)
            else:
                mapped[key] = fn(key, value)
        return mapped

    return apply_to_leaves
def get_schedule(
    lr: float,
    transition_begin: int,
    transition_steps: int,
    end_value: float = 1e-6,
) -> optax.Schedule:
    """
    Build a linear learning rate schedule.

    Thin wrapper around `optax.linear_schedule` that interpolates from `lr`
    to `end_value`.

    Parameters
    ----------
    lr : float
        Initial value of the schedule (`init_value`).
    transition_begin : int
        Forwarded to `optax.linear_schedule` as `transition_begin`.
    transition_steps : int
        Forwarded to `optax.linear_schedule` as `transition_steps`.
    end_value : float, default = 1e-6
        Final value of the schedule. The default reproduces the value that
        was previously hard-coded.

    Returns
    -------
    optax.Schedule
        The schedule returned by `optax.linear_schedule`.
    """
    # NOTE(review): if `lr` is smaller than `end_value` the schedule ramps
    # *up* towards `end_value`; confirm callers never rely on that.
    return optax.linear_schedule(
        init_value=lr,
        end_value=end_value,
        transition_begin=transition_begin,
        transition_steps=transition_steps,
    )
def make_optimizer(opt, lr, transition_begin, transition_steps, opt_kwargs):
    """
    Create the gradient transformation for one parameter group.

    A learning rate of at most 1e-7 selects `optax.set_to_zero()` instead of
    a real optimizer. Otherwise `opt` is called with a linear schedule built
    by `get_schedule` as its first positional argument and `opt_kwargs` as
    keyword arguments.

    Parameters
    ----------
    opt
        Optimizer factory; called as `opt(schedule, **opt_kwargs)`.
    lr
        Initial learning rate of the group.
    transition_begin, transition_steps
        Forwarded unchanged to `get_schedule`.
    opt_kwargs
        Mapping of extra keyword arguments for `opt`.
    """
    # Guard clause: (near-)zero learning rates get the zeroing transform.
    if lr <= 1e-7:
        return optax.set_to_zero()

    lr_schedule = get_schedule(lr, transition_begin, transition_steps)
    return opt(lr_schedule, **opt_kwargs)
def get_opt(
    params,
    transition_begin: int,
    transition_steps: int,
    emb_lr: float = 0.02,
    nn_lr: float = 0.03,
    scale_lr: float = 0.001,
    shift_lr: float = 0.05,
    zbl_lr: float = 0.001,
    opt_name: str = "adam",
    opt_kwargs: Optional[dict] = None,
    **kwargs,
) -> optax.GradientTransformation:
    """
    Build an optimizer with a separate learning rate per parameter group.

    One optimizer is created per group via `make_optimizer` and the groups
    are combined with `optax.multi_transform`. Every parameter is labelled
    with the last element of its path in `params`, and that label selects
    the optimizer.

    Parameters
    ----------
    params
        Nested parameter container understood by
        `flax.traverse_util.path_aware_map`.
    transition_begin : int
        Forwarded to `make_optimizer` for every group.
    transition_steps : int
        Forwarded to `make_optimizer` for every group.
    emb_lr : float, default = 0.02
        Learning rate for parameters labelled `atomic_type_embedding`.
    nn_lr : float, default = 0.03
        Learning rate for parameters labelled `w` or `b`.
    scale_lr : float, default = 0.001
        Learning rate for parameters labelled `scale_per_element`.
    shift_lr : float, default = 0.05
        Learning rate for parameters labelled `shift_per_element`.
    zbl_lr : float, default = 0.001
        Learning rate for parameters labelled `a_exp`, `a_num`,
        `coefficients`, `exponents` or `rep_scale`.
    opt_name : str, default = "adam"
        Attribute name looked up on the `optax` module to obtain the
        optimizer factory.
    opt_kwargs : dict, optional
        Extra keyword arguments passed to the optimizer factory for every
        group. `None` (the default) means no extra arguments.
    **kwargs
        Accepted but not used by this function.

    Returns
    -------
    optax.GradientTransformation
        The transformation returned by `optax.multi_transform`.

    Raises
    ------
    AttributeError
        If `optax` has no attribute named `opt_name`.
    """
    log.info("Initializing Optimizer")

    # Use a fresh dict per call instead of a shared mutable default argument.
    if opt_kwargs is None:
        opt_kwargs = {}

    opt = getattr(optax, opt_name)

    def build(lr):
        # All groups share everything except the initial learning rate.
        return make_optimizer(
            opt, lr, transition_begin, transition_steps, opt_kwargs
        )

    nn_opt = build(nn_lr)
    emb_opt = build(emb_lr)
    scale_opt = build(scale_lr)
    shift_opt = build(shift_lr)
    zbl_opt = build(zbl_lr)

    # Maps a parameter label (leaf name) to the optimizer of its group.
    partition_optimizers = {
        "w": nn_opt,
        "b": nn_opt,
        "atomic_type_embedding": emb_opt,
        "scale_per_element": scale_opt,
        "shift_per_element": shift_opt,
        "a_exp": zbl_opt,
        "a_num": zbl_opt,
        "coefficients": zbl_opt,
        "exponents": zbl_opt,
        "rep_scale": zbl_opt,
    }

    # Label every parameter with the last component of its path.
    # NOTE(review): a leaf name missing from `partition_optimizers`
    # presumably makes `optax.multi_transform` fail later; verify.
    param_partitions = freeze(
        traverse_util.path_aware_map(lambda path, v: path[-1], params)
    )
    return optax.multi_transform(partition_optimizers, param_partitions)