Optimizers

apax.optimizer.get_optimizer.get_opt(params, transition_begin: int, transition_steps: int, emb_lr: float = 0.02, nn_lr: float = 0.03, scale_lr: float = 0.001, shift_lr: float = 0.05, zbl_lr: float = 0.001, opt_name: str = 'adam', opt_kwargs: dict = {}, **kwargs) → GradientTransformation

Builds an optimizer with different learning rates for each parameter group. Several optax optimizers are supported.
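The idea of one learning rate per parameter group can be illustrated with a minimal, dependency-free sketch. It is not the apax implementation: it swaps in plain gradient descent for the optax optimizer, and it assumes that parameters are stored in a flat dict keyed by group name. The group names and rates mirror the `emb_lr`, `nn_lr`, `scale_lr`, `shift_lr` and `zbl_lr` defaults in the signature above; `GROUP_LRS` and `sgd_step` are hypothetical names.

```python
# Default learning rates copied from the get_opt signature.
# The flat "one entry per group" layout is an assumption for illustration.
GROUP_LRS = {"emb": 0.02, "nn": 0.03, "scale": 0.001, "shift": 0.05, "zbl": 0.001}


def sgd_step(params: dict, grads: dict, group_lrs: dict) -> dict:
    """One plain gradient-descent step with a separate learning rate per group.

    Hypothetical helper: a real optimizer (e.g. Adam) would also keep state.
    """
    return {
        name: value - group_lrs[name] * grads[name]
        for name, value in params.items()
    }


params = {name: 1.0 for name in GROUP_LRS}
grads = {name: 1.0 for name in GROUP_LRS}
new_params = sgd_step(params, grads, GROUP_LRS)
```

With identical gradients, the `shift` entry moves furthest and the `scale` and `zbl` entries move least, which is the effect of per-group learning rates.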

apax.optimizer.get_optimizer.get_schedule(lr: float, transition_begin: int, transition_steps: int) → Callable[[Array | ndarray | bool_ | number | float | int], Array | ndarray | bool_ | number | float | int]

Builds a linear learning rate schedule.
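A linear schedule of this shape can be sketched in plain Python as follows. This is a stand-in for illustration, not the apax code: the helper name `linear_schedule` and the explicit `end_value` argument are assumptions, since the signature above does not say which final value is used.

```python
from typing import Callable


def linear_schedule(
    init_value: float,
    end_value: float,  # assumption: the final value is not part of get_schedule's signature
    transition_begin: int,
    transition_steps: int,
) -> Callable[[int], float]:
    """Return a function mapping a step count to a learning rate.

    The rate stays at init_value until transition_begin, then moves linearly
    to end_value over transition_steps steps, and stays there afterwards.
    """

    def schedule(step: int) -> float:
        # Clamp the elapsed steps into [0, transition_steps].
        elapsed = min(max(step - transition_begin, 0), transition_steps)
        fraction = elapsed / transition_steps
        return init_value + fraction * (end_value - init_value)

    return schedule


schedule = linear_schedule(
    init_value=0.03, end_value=0.0, transition_begin=10, transition_steps=100
)
rates = [schedule(step) for step in (0, 10, 60, 110, 500)]
```

Here the rate is constant for the first 10 steps, halves by step 60, and reaches the end value at step 110.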

apax.optimizer.get_optimizer.map_nested_fn(fn: Callable[[str, Any], dict]) → Callable[[dict], dict]

Recursively apply fn to the key-value pairs of a nested dict. See https://optax.readthedocs.io/en/latest/api.html?highlight=multitransform#multi-transform
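A minimal sketch of such a helper, following the pattern described in the linked optax documentation, might look like the following. The body is an assumption about how the recursion is written, not a copy of the apax source, and the example parameter tree is made up.

```python
from typing import Any, Callable


def map_nested_fn(fn: Callable[[str, Any], Any]) -> Callable[[dict], dict]:
    """Recursively apply `fn` to the key-value pairs of a nested dict."""

    def map_fn(nested_dict: dict) -> dict:
        return {
            # Descend into sub-dicts; call fn(key, value) on the leaves.
            key: (map_fn(value) if isinstance(value, dict) else fn(key, value))
            for key, value in nested_dict.items()
        }

    return map_fn


# Label every leaf with its own key, keeping the tree structure intact.
label_fn = map_nested_fn(lambda key, _value: key)
params = {"layer": {"w": 1.0, "b": 2.0}, "scale": 3.0}  # hypothetical parameter tree
labels = label_fn(params)
```

A label tree of this kind has the same structure as the parameters, which is what allows each leaf to be routed to its own transformation or learning rate.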