kd.optim.named_chain

kauldron.optim.named_chain(
    **transforms: optax._src.base.GradientTransformation,
) -> optax._src.base.GradientTransformationExtraArgs

Wraps optax.named_chain, which takes (name, transformation) pairs as positional arguments, so that the transformations can be passed as keyword arguments instead.

Example usage:

```python
import optax
from kauldron import kd

# `cfg` is the experiment config; the steps run top to bottom.
cfg.optimizer = kd.optim.named_chain(**{
    "clip": optax.clip_by_global_norm(max_norm=1.0),
    "adam": optax.scale_by_adam(b1=0.95),
    "decay": optax.add_decayed_weights(weight_decay=0.1),
    "lr": kd.optim.scale_by_learning_rate(0.003),
})
```

The advantages of this over using optax.chain are:
  1. Readability of the config and sweeps: the path to a hyperparameter becomes “optimizer.adam.b1” rather than “optimizer[1].b1”.

  2. The optimizer state (as stored in the checkpoint and in the context) becomes a dictionary keyed by name instead of a tuple, which makes it easier to understand, access, and manipulate.

Parameters:

**transforms – The GradientTransformations to chain, passed as keyword arguments. Each keyword is used as the name of its transformation and as its key in the optimizer state.

Returns:

An optax.GradientTransformationExtraArgs (a subclass of optax.GradientTransformation) that applies the named transformations in sequence.