Source code for apax.optimizer.get_optimizer

import logging
from typing import Optional

import jax.numpy as jnp
import numpy as np
import optax
from flax import traverse_util
from flax.core.frozen_dict import freeze
from optax import contrib
from optax._src import base

log = logging.getLogger(__name__)


def sam(lr=1e-3, b1=0.9, b2=0.999, rho=0.001, sync_period=2):
    """Build a SAM optimizer whose outer optimizer is Adam.

    Args:
        lr: Learning rate (or schedule) handed to ``optax.adam``.
        b1: ``b1`` argument of ``optax.adam``.
        b2: ``b2`` argument of ``optax.adam``.
        rho: Learning rate of the SGD step used by the adversarial optimizer.
        sync_period: Forwarded unchanged to ``optax.contrib.sam``.

    Returns:
        The gradient transformation returned by ``optax.contrib.sam``.
    """
    # Adversarial optimizer: normalize the gradient, then take an SGD step
    # of size ``rho``.
    adversarial_steps = (contrib.normalize(), optax.sgd(rho))
    adversarial_optimizer = optax.chain(*adversarial_steps)

    outer_optimizer = optax.adam(lr, b1=b1, b2=b2)
    return contrib.sam(
        outer_optimizer,
        adversarial_optimizer,
        sync_period=sync_period,
    )
def cyclic_cosine_decay_schedule(
    init_value: float,
    steps_per_epoch,
    period: int,
    decay_factor: float = 0.9,
) -> base.Schedule:
    """Create a cyclic cosine learning rate schedule.

    One cycle spans ``period * steps_per_epoch`` steps. Within a cycle the
    value follows half a cosine wave, starting at the cycle's peak value and
    approaching zero at the end of the cycle. The peak of the first cycle is
    ``init_value``; each completed cycle multiplies the peak by
    ``decay_factor``.

    Args:
        init_value: Learning rate at step 0.
        steps_per_epoch: Number of optimizer steps per epoch.
        period: Length of one cycle, in epochs.
        decay_factor: Factor applied to the peak value once per completed
            cycle.

    Returns:
        A function mapping a step count to the learning rate at that step.
    """

    def schedule(count):
        cycle_length = period * steps_per_epoch
        completed_cycles = count // cycle_length
        position = jnp.mod(count, cycle_length)

        # Half cosine wave rescaled from [-1, 1] to [0, init_value].
        phase = np.pi * position / cycle_length
        cosine_lr = init_value / 2 * (jnp.cos(phase) + 1)

        return cosine_lr * (decay_factor**completed_cycles)

    return schedule
def get_schedule(
    lr: float,
    n_epochs: int,
    steps_per_epoch: int,
    schedule_kwargs: dict,
) -> optax._src.base.Schedule:
    """Build the learning rate schedule selected by ``schedule_kwargs["name"]``.

    Supported names:

    * ``"linear"``: ``optax.linear_schedule`` starting at ``lr`` with
      ``n_epochs * steps_per_epoch`` transition steps.
    * ``"cyclic_cosine"``: ``cyclic_cosine_decay_schedule`` starting at ``lr``.

    All entries of ``schedule_kwargs`` other than ``"name"`` are forwarded as
    keyword arguments to the selected schedule constructor.

    Args:
        lr: Initial learning rate.
        n_epochs: Number of training epochs (used by the linear schedule).
        steps_per_epoch: Number of optimizer steps per epoch.
        schedule_kwargs: Schedule options; must contain a ``"name"`` entry.
            The mapping itself is not modified.

    Returns:
        A function mapping a step count to a learning rate.

    Raises:
        KeyError: If ``"name"`` is missing or is not a supported schedule.
    """
    # Work on a copy so the caller's dict keeps its "name" entry.
    options = dict(schedule_kwargs)
    try:
        name = options.pop("name")
    except KeyError:
        # Same exception type as before, but with an actionable message
        # instead of a bare KeyError('name').
        raise KeyError(
            "schedule_kwargs must contain a 'name' entry "
            "('linear' or 'cyclic_cosine')"
        ) from None

    if name == "linear":
        return optax.linear_schedule(
            init_value=lr,
            transition_steps=n_epochs * steps_per_epoch,
            **options,
        )
    if name == "cyclic_cosine":
        return cyclic_cosine_decay_schedule(lr, steps_per_epoch, **options)

    raise KeyError(f"unknown learning rate schedule: {name}")
def make_optimizer(opt, lr, n_epochs, steps_per_epoch, kwargs, schedule):
    """Build the optimizer for one parameter group.

    Args:
        opt: Optimizer factory called as ``opt(lr_schedule, **kwargs)``.
        lr: Initial learning rate of the group.
        n_epochs: Number of training epochs, passed on to ``get_schedule``.
        steps_per_epoch: Optimizer steps per epoch, passed on to
            ``get_schedule``.
        kwargs: Extra keyword arguments for ``opt``.
        schedule: Schedule options, passed on to ``get_schedule``.

    Returns:
        ``optax.set_to_zero()`` when ``lr <= 1e-7``; otherwise ``opt``
        applied to the schedule built by ``get_schedule``.
    """
    # Learning rates at or below 1e-7 get the set_to_zero transformation
    # instead of a scheduled optimizer.
    if lr <= 1e-7:
        return optax.set_to_zero()

    lr_schedule = get_schedule(lr, n_epochs, steps_per_epoch, schedule)
    return opt(lr_schedule, **kwargs)
def get_opt(
    params,
    n_epochs: int,
    steps_per_epoch: int,
    emb_lr: float = 0.02,
    nn_lr: float = 0.03,
    scale_lr: float = 0.001,
    shift_lr: float = 0.05,
    zbl_lr: float = 0.001,
    name: str = "adam",
    kwargs: Optional[dict] = None,
    schedule: Optional[dict] = None,
) -> optax._src.base.GradientTransformation:
    """Build an optimizer with a separate learning rate per parameter group.

    Every leaf of ``params`` is labeled with the last element of its path,
    and ``optax.multi_transform`` routes it to the optimizer registered for
    that label.

    Args:
        params: Parameter tree to be optimized.
        n_epochs: Number of training epochs.
        steps_per_epoch: Number of optimizer steps per epoch.
        emb_lr: Learning rate for ``atomic_type_embedding``.
        nn_lr: Learning rate for ``w`` and ``b``.
        scale_lr: Learning rate for ``scale_per_element``.
        shift_lr: Learning rate for ``shift_per_element``.
        zbl_lr: Learning rate for ``a_exp``, ``a_num``, ``coefficients``,
            ``exponents`` and ``rep_scale``.
        name: ``"sam"`` for the local ``sam`` optimizer, otherwise the name
            of an attribute of ``optax`` (e.g. ``"adam"``).
        kwargs: Extra keyword arguments for the optimizer factory.
            ``None`` is treated as an empty dict.
        schedule: Schedule options forwarded to ``get_schedule``. ``None``
            is treated as an empty dict; note that ``get_schedule`` needs a
            ``"name"`` entry for every group whose learning rate is above
            1e-7.

    Returns:
        The combined ``optax.multi_transform`` gradient transformation.
    """
    # None sentinels replace the former mutable ``{}`` defaults.
    kwargs = {} if kwargs is None else kwargs
    schedule = {} if schedule is None else schedule

    log.info("Initializing Optimizer")
    opt = sam if name == "sam" else getattr(optax, name)

    def build(lr):
        # All groups share the optimizer factory, kwargs and schedule options;
        # only the learning rate differs.
        return make_optimizer(opt, lr, n_epochs, steps_per_epoch, kwargs, schedule)

    nn_opt = build(nn_lr)
    emb_opt = build(emb_lr)
    scale_opt = build(scale_lr)
    shift_opt = build(shift_lr)
    zbl_opt = build(zbl_lr)

    partition_optimizers = {
        "w": nn_opt,
        "b": nn_opt,
        "atomic_type_embedding": emb_opt,
        "scale_per_element": scale_opt,
        "shift_per_element": shift_opt,
        "a_exp": zbl_opt,
        "a_num": zbl_opt,
        "coefficients": zbl_opt,
        "exponents": zbl_opt,
        "rep_scale": zbl_opt,
    }

    # Label each leaf with the last key of its path so it can be matched
    # against the keys of ``partition_optimizers``.
    param_partitions = freeze(
        traverse_util.path_aware_map(lambda path, v: path[-1], params)
    )
    return optax.multi_transform(partition_optimizers, param_partitions)