kd.optim.decay_to_init
kauldron.optim.decay_to_init(
    weight_decay: float | jax.Array,
    mask: Any | Callable[[optax.Params], Any] | None = None,
)
Adds (params - init_params), scaled by weight_decay, to the updates.
This acts as weight decay towards the model's initial parameters rather than towards zero, which is useful when fine-tuning pre-trained models.
Parameters:
weight_decay – A scalar weight decay rate.
mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a tree given the params/updates. The leaves should be booleans: True for leaves/subtrees the transformation should apply to, False for those to skip.
Returns:
A GradientTransformation object.
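As the mask parameter allows, a mask can be given as a callable that maps the params tree to a tree of booleans. The sketch below assumes Flax-style parameter dicts whose leaves are named `kernel` and `bias`; `no_bias_mask` is a hypothetical helper that enables the decay for kernels only.

```python
import jax
import jax.numpy as jnp


def no_bias_mask(params):
    """Hypothetical mask: True for leaves named "kernel", False otherwise."""

    def is_kernel(path, _leaf):
        # For dict-based params the last path entry is a DictKey with `.key`;
        # other key types (e.g. SequenceKey) lack `.key` and map to False.
        return bool(path) and getattr(path[-1], "key", None) == "kernel"

    return jax.tree_util.tree_map_with_path(is_kernel, params)


params = {
    "dense": {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros(2)},
    "out": {"kernel": jnp.ones((2, 1)), "bias": jnp.zeros(1)},
}
mask = no_bias_mask(params)
```

With kauldron installed, the helper would be passed as `kd.optim.decay_to_init(weight_decay=1e-4, mask=no_bias_mask)`, where `1e-4` is an arbitrary example rate.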