kd.optim.named_chain
- kauldron.optim.named_chain(**transforms: optax._src.base.GradientTransformation)
Wraps optax.named_chain and allows passing transformations as kwargs.
Example usage:

    cfg.optimizer = kd.optim.named_chain(**{
        "clip": optax.clip_by_global_norm(max_norm=1.0),
        "adam": optax.scale_by_adam(b1=0.95),
        "decay": optax.add_decayed_weights(weight_decay=0.1),
        "lr": kd.optim.scale_by_learning_rate(0.003),
    })
The advantages over optax.chain are:
- Configs and sweeps are more readable, because the path to a hyperparameter becomes “optimizer.adam.b1” rather than “optimizer[1].b1”.
- The optimizer state (as stored in the checkpoint and in the context) becomes a dictionary keyed by transformation name instead of a tuple, so it is easier to understand, access and manipulate.
- Parameters:
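**transforms – GradientTransformations passed as keyword arguments, where each keyword is the name of the transformation. They are applied in the order in which they are passed.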
- Returns:
An optax.GradientTransformation that corresponds to applying the list of transformations in sequence.