kd.ckpts.MultiTransform
- class kauldron.checkpoints.MultiTransform(**transforms: kauldron.checkpoints.partial_loader.AbstractPartialLoader)
Bases: kauldron.checkpoints.partial_loader.AbstractPartialLoader
Transform which applies multiple transformations sequentially.
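To illustrate what "applies multiple transformations sequentially" means, here is a minimal plain-Python sketch. ToyMultiTransform, load_encoder, load_decoder and double_encoder are hypothetical names, not part of kauldron, and the sketch assumes the transforms run in the order the keyword arguments are given, each receiving the output of the previous one.

```python
from typing import Any, Callable, Dict

State = Dict[str, Any]


class ToyMultiTransform:
    """Hypothetical stand-in for MultiTransform: applies each transform in order."""

    def __init__(self, **transforms: Callable[[State], State]) -> None:
        # Keyword arguments keep their insertion order (Python 3.7+), so the
        # transforms run in the order they were passed (an assumption here).
        self.transforms = transforms

    def transform(self, state: State) -> State:
        # Each transform sees the state produced by the previous one.
        for transform in self.transforms.values():
            state = transform(state)
        return state


def load_encoder(state: State) -> State:
    # Hypothetical: overwrite the encoder weights with "pre-trained" values.
    return {**state, "encoder": [0.5, 0.5]}


def load_decoder(state: State) -> State:
    # Hypothetical: overwrite the decoder weights with "pre-trained" values.
    return {**state, "decoder": [1.0, 2.0]}


def double_encoder(state: State) -> State:
    # Hypothetical: post-process whatever encoder weights are currently present.
    return {**state, "encoder": [2 * x for x in state["encoder"]]}


multi = ToyMultiTransform(encoder=load_encoder, decoder=load_decoder)
state = multi.transform({"encoder": [0.0, 0.0], "decoder": [0.0, 0.0]})
print(state)  # {'encoder': [0.5, 0.5], 'decoder': [1.0, 2.0]}

# Order matters: scaling after loading acts on the loaded values.
chained = ToyMultiTransform(load=load_encoder, scale=double_encoder)
print(chained.transform({"encoder": [0.0, 0.0]}))  # {'encoder': [1.0, 1.0]}
```

Because every step is an ordinary state-to-state function, composing them is just a fold over the list of transforms; swapping the order of load and scale in this sketch would leave the encoder at [0.5, 0.5].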
- transform(state)
Transform the state by updating it with pre-trained values.
Notes:
transform functions can modify the state values but should NOT modify its structure, shape or dtypes. transform should correctly propagate the sharding information from the given state.
- Parameters:
state – The state object to transform
- Returns:
The updated state
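The constraint in the notes (change values, never structure, shape or dtype) can be made concrete with a small check. This is a plain-Python sketch under stated assumptions: the state is modelled as nested dicts whose leaves are flat lists of numbers, a leaf's length stands in for its shape, and its element types stand in for its dtype. layout, same_layout, overwrite and checked are hypothetical helpers, not kauldron APIs, and sharding is not modelled at all.

```python
from typing import Any, Callable

Tree = Any  # nested dicts whose leaves are flat lists of numbers


def layout(tree: Tree) -> Any:
    """Summarise a tree as its keys, leaf lengths ("shape") and element types ("dtype")."""
    if isinstance(tree, dict):
        return {key: layout(value) for key, value in tree.items()}
    return (len(tree), frozenset(type(x) for x in tree))


def same_layout(a: Tree, b: Tree) -> bool:
    # True when both trees have identical keys, leaf lengths and element types.
    return layout(a) == layout(b)


def overwrite(tree: Tree, source: Tree) -> Tree:
    """Copy values from `source` into `tree` wherever `source` has a matching key."""
    if isinstance(tree, dict):
        return {
            key: overwrite(value, source[key]) if key in source else value
            for key, value in tree.items()
        }
    return list(source)


def checked(transform: Callable[[Tree], Tree]) -> Callable[[Tree], Tree]:
    """Wrap a transform so it raises if it changes structure, shape or dtype."""

    def wrapper(state: Tree) -> Tree:
        new_state = transform(state)
        if not same_layout(state, new_state):
            raise ValueError("transform changed the state structure, shape or dtype")
        return new_state

    return wrapper


state = {"params": {"w": [0.0, 0.0], "b": [0.0]}}
pretrained = {"params": {"w": [0.3, -0.1]}}  # only "w" is restored

restore = checked(lambda s: overwrite(s, pretrained))
print(restore(state))  # {'params': {'w': [0.3, -0.1], 'b': [0.0]}}
```

A transform that, for example, returned a three-element "w" or integer values would be rejected by checked, because its layout no longer matches the input state.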
- transform_after_optimizer(state: kauldron.checkpoints.partial_loader._T)
Transformation applied after the optimizer has been restored.
The transform method is called before the optimizer has been restored. This allows the optimizer to depend on the pre-trained values (e.g. when using the optax decay_to_init and ema_weight_wrapper transforms). However, sometimes the optimizer state also needs to be restored from a pre-trained checkpoint; this method handles that case.
- Parameters:
state – The state object to transform
- Returns:
The updated state
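The two-phase ordering described above can be sketched in plain Python. Everything here is hypothetical: ToyState, PRETRAINED, init_optimizer and the "init_params"/"momentum" entries are illustrative stand-ins, not kauldron or optax APIs. The toy optimizer records the parameters it sees at init time, in the spirit of the decay_to_init example, so it must be initialized after transform has loaded the pre-trained params and before transform_after_optimizer restores optimizer entries.

```python
from dataclasses import dataclass, field, replace
from typing import Dict, List


@dataclass(frozen=True)
class ToyState:
    # Hypothetical training state: model params plus optimizer state.
    params: Dict[str, List[float]]
    opt_state: Dict[str, List[float]] = field(default_factory=dict)


# Hypothetical checkpoint contents.
PRETRAINED = ToyState(
    params={"w": [0.3, -0.1]},
    opt_state={"momentum": [0.01, 0.02]},
)


def transform(state: ToyState) -> ToyState:
    # Phase 1: restore parameter values before the optimizer exists.
    return replace(state, params={k: list(v) for k, v in PRETRAINED.params.items()})


def init_optimizer(state: ToyState) -> ToyState:
    # Toy optimizer whose state depends on the current params: it stores them
    # as "init_params" and starts momentum at zero.
    w = state.params["w"]
    return replace(state, opt_state={"init_params": list(w), "momentum": [0.0] * len(w)})


def transform_after_optimizer(state: ToyState) -> ToyState:
    # Phase 2: overwrite selected optimizer entries from the checkpoint.
    restored = {**state.opt_state, "momentum": list(PRETRAINED.opt_state["momentum"])}
    return replace(state, opt_state=restored)


state = ToyState(params={"w": [0.0, 0.0]})
state = transform(state)
state = init_optimizer(state)
state = transform_after_optimizer(state)
print(state.opt_state)  # {'init_params': [0.3, -0.1], 'momentum': [0.01, 0.02]}
```

In this sketch, calling transform_after_optimizer before init_optimizer would not work: init_optimizer rebuilds opt_state from scratch and would discard the restored momentum, which is why the optimizer-state restore is a separate, later step.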