kd.ckpts.MultiTransform

class kauldron.checkpoints.MultiTransform(
**transforms: kauldron.checkpoints.partial_loader.AbstractPartialLoader,
)

Bases: kauldron.checkpoints.partial_loader.AbstractPartialLoader

Transform which applies multiple transformations sequentially.
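The sequential behaviour can be illustrated with a minimal, self-contained sketch. `SequentialTransforms`, `load_encoder` and `scale_encoder` are hypothetical stand-ins written for this example and are not part of kauldron. The sketch also assumes that transforms run in the order in which the keyword arguments are given, which the description above does not state explicitly.

```python
from typing import Any, Callable, Dict

State = Dict[str, Any]


class SequentialTransforms:
    """Hypothetical stand-in for illustration; not the kauldron class."""

    def __init__(self, **transforms: Callable[[State], State]):
        # Keyword arguments preserve their call order in Python 3.7+.
        self._transforms = transforms

    def transform(self, state: State) -> State:
        # Feed the output of each transform into the next one.
        for fn in self._transforms.values():
            state = fn(state)
        return state


def load_encoder(state: State) -> State:
    # Pretend to load pre-trained encoder values (hypothetical data).
    return {**state, "encoder": [1.0, 2.0]}


def scale_encoder(state: State) -> State:
    # Second step: operates on the values produced by the first step.
    return {**state, "encoder": [2 * v for v in state["encoder"]]}


init_state = {"encoder": [0.0, 0.0], "decoder": [0.0]}
combined = SequentialTransforms(encoder=load_encoder, scale=scale_encoder)
new_state = combined.transform(init_state)
```

Here the second transform sees the values written by the first, which is the point of applying them one after another rather than independently.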

transform(state)

Transform the state by updating it with pre-trained values.

Notes:

  • transform functions can modify the state values but should NOT modify its structure, shape or dtypes.

  • transform should correctly propagate the sharding information from the given state.

Parameters:

state – The state object to transform

Returns:

The updated state
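The first note above (change values, not structure) can be sketched in plain Python. This is only an illustration under simplifying assumptions: the state is a dict of float lists, list length stands in for array shape, element type stands in for dtype, and sharding is not modelled at all. `signature` and `load_pretrained` are hypothetical helpers, not kauldron functions.

```python
from typing import Any, Dict, List

State = Dict[str, List[float]]


def signature(state: State) -> Dict[str, Any]:
    # Per entry, record length and element type: a crude analogue of
    # shape and dtype for this toy state.
    return {k: (len(v), type(v[0])) for k, v in state.items()}


def load_pretrained(state: State, pretrained: State) -> State:
    # Overwrite only keys that already exist in the state, so no keys
    # are added or dropped. Extra keys in `pretrained` are ignored.
    new_state = {k: list(pretrained.get(k, v)) for k, v in state.items()}
    # Refuse to return a state whose structure differs from the input.
    if signature(new_state) != signature(state):
        raise ValueError("transform changed the state structure")
    return new_state


state = {"encoder": [0.0, 0.0], "head": [0.0]}
pretrained = {"encoder": [0.5, -0.5], "unused": [9.0]}
restored = load_pretrained(state, pretrained)
```

Checking the structure inside the transform makes a mismatched checkpoint fail early instead of surfacing later as a shape error during training.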

transform_after_optimizer(
state: kauldron.checkpoints.partial_loader._T,
) -> kauldron.checkpoints.partial_loader._T

Transformation applied after the optimizer has been restored.

The transform method is called before the optimizer has been restored. This allows the optimizer to depend on the pre-trained values (e.g. when using optax decay_to_init and ema_weight_wrapper transforms).

However, the optimizer state sometimes also needs to be restored from a pre-trained checkpoint. This method handles that case.

Parameters:

state – The state object to transform

Returns:

The updated state
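The ordering of the two hooks can be sketched as follows. Everything here is a hypothetical stand-in: `PartialLoaderSketch` is not the kauldron base class, `init_optimizer` only mimics an optimizer that snapshots the parameters when it is created (as a decay-to-init style transform might), and the step value 100 is made up for the example.

```python
from typing import Any, Dict

State = Dict[str, Any]


class PartialLoaderSketch:
    """Hypothetical two-phase loader; not the kauldron base class."""

    def transform(self, state: State) -> State:
        # Phase 1: write pre-trained parameter values into the state.
        return {**state, "params": [1.0, 2.0]}

    def transform_after_optimizer(self, state: State) -> State:
        # Phase 2: overwrite part of the optimizer state, keeping the
        # snapshot the optimizer took of the parameters.
        opt_state = {**state["opt_state"], "step": 100}
        return {**state, "opt_state": opt_state}


def init_optimizer(state: State) -> State:
    # Stand-in optimizer setup that remembers the parameters it was
    # created with.
    opt_state = {"step": 0, "init": list(state["params"])}
    return {**state, "opt_state": opt_state}


loader = PartialLoaderSketch()
state = {"params": [0.0, 0.0]}
state = loader.transform(state)                  # before the optimizer
state = init_optimizer(state)                    # sees pre-trained params
state = loader.transform_after_optimizer(state)  # after the optimizer
```

Because `transform` runs first, the optimizer snapshot holds the pre-trained values rather than the random initial ones, and the second hook can still adjust optimizer fields afterwards.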