Optimizers

apax.optimizer.get_optimizer.cyclic_cosine_decay_schedule(init_value: float, steps_per_epoch, period: int, decay_factor: float = 0.9) → Callable[[Array | ndarray | bool_ | number | float | int], Array | ndarray | bool_ | number | float | int]

Returns a function which implements cyclic cosine learning rate decay.

Parameters:

init_value – An initial value for the learning rate.

Returns:

schedule – A function that maps step counts to values.
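Only init_value is documented above, so the following is a minimal sketch of one plausible cyclic cosine decay, not apax's implementation. It assumes that period is measured in epochs (one cycle spans period * steps_per_epoch steps), that the rate follows a half-cosine from its peak toward zero within each cycle, and that decay_factor scales the peak once per completed cycle. It uses jax.numpy so the schedule stays traceable inside jit, as optax schedules must be.

```python
import jax.numpy as jnp


def cyclic_cosine_decay_schedule(init_value, steps_per_epoch, period, decay_factor=0.9):
    # Assumption: `period` counts epochs, so one cycle has this many steps.
    cycle_steps = period * steps_per_epoch

    def schedule(count):
        count = jnp.asarray(count)
        cycle = count // cycle_steps          # index of the current cycle
        step_in_cycle = count % cycle_steps   # position inside that cycle
        # Half-cosine: 1 at the start of a cycle, approaching 0 at its end.
        cosine = 0.5 * (1.0 + jnp.cos(jnp.pi * step_in_cycle / cycle_steps))
        # Assumption: each completed cycle multiplies the peak by decay_factor.
        return init_value * cosine * decay_factor ** cycle

    return schedule


# Usage: 100 steps per epoch, restarting every 10 epochs (1000 steps).
sched = cyclic_cosine_decay_schedule(init_value=1e-3, steps_per_epoch=100, period=10)
lr_start, lr_mid, lr_restart = sched(0), sched(500), sched(1000)
```

With these assumptions the rate restarts at a lower peak after every cycle, which combines warm restarts with a slow global decay.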

apax.optimizer.get_optimizer.get_opt(params, n_epochs: int, steps_per_epoch: int, emb_lr: float = 0.02, nn_lr: float = 0.03, scale_lr: float = 0.001, shift_lr: float = 0.05, zbl_lr: float = 0.001, name: str = 'adam', kwargs: dict = {}, schedule: dict = {}) → GradientTransformation

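Builds an optax optimizer with a separate learning rate for each parameter group (embedding, neural network, scale, shift, and ZBL parameters). Several optax optimizers are supported; the optimizer is chosen by the name argument (default 'adam').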

apax.optimizer.get_optimizer.get_schedule(lr: float, n_epochs: int, steps_per_epoch: int, schedule_kwargs: dict) → Callable[[Array | ndarray | bool_ | number | float | int], Array | ndarray | bool_ | number | float | int]

Builds a linear learning rate schedule.

apax.optimizer.get_optimizer.sam(lr=0.001, b1=0.9, b2=0.999, rho=0.001, sync_period=2)

A Sharpness-Aware Minimization (SAM) optimizer that uses Adam as the outer optimizer.