kd.optim.decay_to_init

kauldron.optim.decay_to_init(
    weight_decay: float | jax.Array,
    mask: Any | Callable[[optax.Params], Any] | None = None,
) -> optax.GradientTransformation

Adds (params - init_params), scaled by weight_decay, to the updates.

This acts as weight decay toward the model's initial parameters rather than toward zero, which is useful when fine-tuning pre-trained models.
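The arithmetic described above can be sketched in plain JAX. This is only an illustration of the formula, not Kauldron's implementation: the helper name `decay_to_init_update` and the toy parameter values are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def decay_to_init_update(updates, params, init_params, weight_decay):
    """Hypothetical helper: add weight_decay * (params - init_params) per leaf."""
    return jax.tree_util.tree_map(
        lambda u, p, p0: u + weight_decay * (p - p0),
        updates, params, init_params,
    )


# Pre-trained values, captured before fine-tuning starts (toy numbers).
init_params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([0.5])}
# Current values after some fine-tuning steps.
params = {"w": jnp.array([1.5, 1.0]), "b": jnp.array([0.5])}
# Zero gradients, so that only the decay term shows up in the result.
grads = jax.tree_util.tree_map(jnp.zeros_like, params)

new_updates = decay_to_init_update(grads, params, init_params, weight_decay=0.1)
```

Leaves that have not moved from their initial value (here `b`) receive no extra update, while leaves that have drifted receive a term proportional to the drift. Ordinary weight decay would instead add `weight_decay * params`, pulling every leaf toward zero.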

Parameters:
  • weight_decay – A scalar weight decay rate.

  • mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a PyTree given the params/updates. The leaves should be booleans: True for leaves/subtrees the transformation should apply to, and False for those to skip.
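As a rough illustration of the three mask forms described above, the sketch below builds a full boolean tree, a prefix tree, and a callable. The parameter layout (`dense`, `norm`) and the rule "decay only leaves with two or more dimensions" are assumptions chosen for the example, not something the library prescribes.

```python
import jax
import jax.numpy as jnp

# Toy parameter tree (hypothetical layout).
params = {
    "dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))},
    "norm": {"scale": jnp.ones((3,))},
}

# Form 1: a boolean tree with the same structure as params.
mask_tree = {
    "dense": {"kernel": True, "bias": False},
    "norm": {"scale": False},
}

# Form 2: a prefix of the params tree; each boolean covers a whole subtree.
prefix_mask = {"dense": True, "norm": False}


# Form 3: a callable that derives the boolean tree from the params.
def mask_fn(params):
    # Select only leaves with 2+ dimensions (kernels); skip biases and scales.
    return jax.tree_util.tree_map(lambda p: p.ndim > 1, params)


# Any of the three would be passed as the `mask` argument, e.g.:
#   kd.optim.decay_to_init(weight_decay=1e-4, mask=mask_fn)
```

A callable is often the more convenient choice, since it adapts automatically when the parameter tree changes.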

Returns:

A GradientTransformation object.