Optimizers

apax.optimizer.get_optimizer.cyclic_cosine_decay_schedule(init_value: float, steps_per_epoch, period: int, decay_factor: float = 0.9) → Callable[[Array | ndarray | bool | number | float | int], Array | ndarray | bool | number | float | int]

Returns a function which implements cyclic cosine learning rate decay.

Parameters:

init_value – An initial value for the learning rate.

Returns:

schedule – A function that maps step counts to values.
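The idea of a cyclic cosine decay can be sketched in plain Python. This is an illustration under stated assumptions, not the apax implementation: it assumes that `period` is measured in epochs, that the learning rate follows a half cosine within each cycle, and that the peak value is multiplied by `decay_factor` after every completed cycle. The name `cyclic_cosine_decay` is hypothetical.

```python
import math


def cyclic_cosine_decay(init_value, steps_per_epoch, period, decay_factor=0.9):
    """Return a schedule that maps a step count to a learning rate.

    Assumption: `period` is given in epochs, so one cycle spans
    period * steps_per_epoch optimizer steps.
    """
    steps_per_cycle = period * steps_per_epoch

    def schedule(step):
        cycle = step // steps_per_cycle          # index of the current cycle
        step_in_cycle = step % steps_per_cycle   # position inside that cycle
        # Half cosine that starts at 1 and falls towards 0 within one cycle.
        cosine = 0.5 * (1.0 + math.cos(math.pi * step_in_cycle / steps_per_cycle))
        # Assumed: the peak shrinks by decay_factor after each completed cycle.
        return init_value * cosine * decay_factor**cycle

    return schedule


schedule = cyclic_cosine_decay(init_value=0.01, steps_per_epoch=10, period=2)
```

In this sketch, `schedule(0)` returns the initial value, the value falls to half of the peak midway through a cycle, and each new cycle restarts at a peak that is smaller by `decay_factor`.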

apax.optimizer.get_optimizer.get_opt(params, n_epochs: int, steps_per_epoch: int, emb_lr: float = 0.02, nn_lr: float = 0.03, scale_lr: float = 0.001, shift_lr: float = 0.05, zbl_lr: float = 0.001, rep_scale_lr: float = 0.001, rep_prefactor_lr: float = 0.0001, gradient_clipping=1000.0, freeze_layers: list[str] = [], name: str = 'adam', kwargs: dict = {}, schedule: dict = {}) → GradientTransformation

Builds an optimizer with different learning rates for each parameter group. Several optax optimizers are supported.
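The concept of one learning rate per parameter group can be illustrated with a small plain-Python sketch. It is a conceptual example only and does not reproduce apax or optax: the group names, the plain SGD update, the element-wise clipping, and the treatment of frozen groups as "no update" are all assumptions made for illustration, and `grouped_sgd_step` is a hypothetical helper.

```python
def clip(value, limit):
    """Clamp a single gradient entry to the interval [-limit, limit]."""
    return max(-limit, min(limit, value))


def grouped_sgd_step(params, grads, group_lrs, gradient_clipping=1000.0, frozen=()):
    """Apply one SGD step with a separate learning rate for each group.

    Assumptions: params and grads are dicts mapping a group name to a
    list of floats; frozen groups are left unchanged.
    """
    new_params = {}
    for group, values in params.items():
        # A frozen group gets a learning rate of zero in this sketch.
        lr = 0.0 if group in frozen else group_lrs[group]
        new_params[group] = [
            p - lr * clip(g, gradient_clipping)
            for p, g in zip(values, grads[group])
        ]
    return new_params


# Hypothetical parameter groups, loosely named after the *_lr arguments.
params = {"emb": [1.0, 2.0], "nn": [0.5], "scale": [1.0]}
grads = {"emb": [10.0, -10.0], "nn": [2000.0], "scale": [1.0]}
group_lrs = {"emb": 0.02, "nn": 0.03, "scale": 0.001}

updated = grouped_sgd_step(params, grads, group_lrs, frozen=("scale",))
```

Here the `nn` gradient of 2000 is clipped to 1000 before the update, and the frozen `scale` group keeps its value.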

apax.optimizer.get_optimizer.get_schedule(lr: float, n_epochs: int, steps_per_epoch: int, schedule_kwargs: dict) → Callable[[Array | ndarray | bool | number | float | int], Array | ndarray | bool | number | float | int]

Builds a linear learning rate schedule.
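A linear schedule of this kind can be sketched as follows. This is a minimal plain-Python illustration, not the apax code: it assumes the learning rate is interpolated from `lr` to an end value over `n_epochs * steps_per_epoch` steps and then held constant. The function name `linear_schedule` and the `end_value` parameter are hypothetical.

```python
def linear_schedule(lr, n_epochs, steps_per_epoch, end_value=0.0):
    """Return a schedule that moves linearly from lr to end_value.

    Assumption: the transition spans the whole training run,
    i.e. n_epochs * steps_per_epoch optimizer steps.
    """
    total_steps = n_epochs * steps_per_epoch

    def schedule(step):
        # Fraction of training completed, capped at 1 after the last step.
        fraction = min(step, total_steps) / total_steps
        return lr + (end_value - lr) * fraction

    return schedule


linear = linear_schedule(lr=0.03, n_epochs=10, steps_per_epoch=5)
```

With these values the schedule starts at 0.03, reaches half of that at step 25, and stays at the end value once step 50 has been passed.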