kd.optim.ema_params

kauldron.optim.ema_params(
    *,
    decay: float,
    debias: bool = True,
    accumulator_dtype: Any | None = None,
) -> optax.GradientTransformation

Store an EMA version of model parameters.

Unlike optax.ema, this transformation does not alter the gradient updates. Instead, it maintains a separate copy of the model parameters that is an exponential moving average (EMA) over training steps. These EMA weights can then be used, e.g., during evaluation.

NOTE: This transformation must be the last one in the chain (e.g., at the end of optax.chain), because it applies the incoming updates to the parameters and uses the resulting parameters to update the EMA parameters.

Example usage:

from kauldron import kd
import optax

cfg.optimizer = kd.optim.named_chain(**{
    "adam": optax.scale_by_adam(b1=0.95),
    "ema_params": kd.optim.ema_params(decay=0.999),
})

cfg.evals = {
    "ema_eval": kd.evals.Evaluator(
        init_transform=kd.optim.UseEmaParams(),
    )
}
Parameters:
  • decay – Decay rate for the exponential moving average.

  • debias – Whether to apply bias correction to the EMA of the parameters.

  • accumulator_dtype – Optional dtype to use for the accumulator; if None, the dtype is inferred from params and updates.
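The effect of `debias` can be checked with a few lines of plain Python. This assumes, as optax.ema does, that the accumulator starts at zero and that debiasing divides by 1 - decay**t after t steps; `ema_trace` is a hypothetical helper, not part of kauldron.

```python
def ema_trace(values, decay, debias=True):
    """EMA after each step for a sequence of scalar parameter values."""
    ema = 0.0  # zero-initialized accumulator (assumption)
    trace = []
    for t, value in enumerate(values, start=1):
        ema = decay * ema + (1.0 - decay) * value
        # The zero init absorbs weight decay**t, so the weights on the actual
        # values sum to 1 - decay**t; dividing by it renormalizes them.
        trace.append(ema / (1.0 - decay**t) if debias else ema)
    return trace


# A parameter held at 1.0 for 10 steps, with the decay from the example.
values = [1.0] * 10
raw = ema_trace(values, decay=0.999, debias=False)
debiased = ema_trace(values, decay=0.999, debias=True)
```

Without debiasing, the EMA after 10 steps is 1 - 0.999**10, roughly 0.01, far from the true value 1.0; with debiasing it is 1.0. With decay=0.999 the raw EMA stays noticeably biased toward zero for thousands of steps.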

Returns:

A GradientTransformation object.