kd.optim

Optimizer utilities: optax gradient transformations and parameter-mask helpers.

Symbols

Class

kd.optim.UseEmaParams

Use the EMA parameters stored by the ema_params transform.
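The entry above does not say where the EMA copy is kept. As a minimal sketch, assuming the EMA parameters sit in a hypothetical `EmaState` node somewhere inside a nested optimizer state (kauldron's real state layout may differ), locating them for use in place of the live parameters could look like this:

```python
from typing import Any, NamedTuple


class EmaState(NamedTuple):
  """Hypothetical optimizer-state node holding the EMA copy of the params."""
  count: int
  ema: Any


def find_ema_params(opt_state: Any) -> Any:
  """Depth-first search of a nested optimizer state for an EmaState node."""
  if isinstance(opt_state, EmaState):
    return opt_state.ema
  # Chained optax states are tuples/lists; named chains use dicts.
  if isinstance(opt_state, (tuple, list)):
    children = opt_state
  elif isinstance(opt_state, dict):
    children = opt_state.values()
  else:
    return None
  for child in children:
    found = find_ema_params(child)
    if found is not None:
      return found
  return None
```

`find_ema_params` returns `None` when no EMA node is present, so a caller can fall back to the live parameters in that case.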

Function#

kd.optim.decay_to_init

Add weight_decay * (params - init_params) to the updates, so that parameters decay towards their initial values rather than towards zero.

kd.optim.ema_params

Store an EMA version of model parameters.

kd.optim.exclude

Create a mask which selects all nodes except the ones matching the pattern.
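A minimal sketch of such a mask over nested dicts, assuming the pattern is a regular expression searched against "/"-joined key paths such as `encoder/w`. The function name and pattern syntax are assumptions for illustration; kauldron's own pattern format may differ.

```python
import re
from typing import Any


def exclude_mask(tree: Any, pattern: str) -> Any:
  """Returns a bool tree shaped like `tree`: False where the leaf path matches."""
  regex = re.compile(pattern)

  def walk(node: Any, path: str) -> Any:
    if isinstance(node, dict):
      # Recurse into sub-dicts, extending the "/"-separated path.
      return {
          k: walk(v, f"{path}/{k}" if path else str(k))
          for k, v in node.items()
      }
    # Leaf: keep it (True) unless its path matches the pattern.
    return regex.search(path) is None

  return walk(tree, "")
```

A mask of this shape (same structure as the parameters, boolean leaves) is what `optax.masked` expects, e.g. to skip weight decay on bias terms with `exclude_mask(params, r"/b$")`.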

kd.optim.named_chain

Wraps optax.named_chain and allows passing transformations as kwargs.

kd.optim.partial_updates

Applies the optimizer to a subset of the parameters.

kd.optim.select

Create a mask which selects only the sub-pytree matching the pattern.
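The complementary mask can be sketched the same way, again assuming a regular expression searched against "/"-joined key paths of nested dicts (the function name and pattern syntax are illustrative assumptions):

```python
import re
from typing import Any


def select_mask(tree: Any, pattern: str) -> Any:
  """Returns a bool tree shaped like `tree`: True where the leaf path matches."""
  regex = re.compile(pattern)

  def walk(node: Any, path: str) -> Any:
    if isinstance(node, dict):
      # Recurse into sub-dicts, extending the "/"-separated path.
      return {
          k: walk(v, f"{path}/{k}" if path else str(k))
          for k, v in node.items()
      }
    # Leaf: selected only if its path matches the pattern.
    return regex.search(path) is not None

  return walk(tree, "")
```

With `r"^encoder/"`, only the leaves under `encoder` are selected, which is the kind of mask an optimizer restricted to a sub-pytree would consume.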