kd.ckpts.AbstractPartialLoader


class kauldron.checkpoints.AbstractPartialLoader

Bases: abc.ABC

Abstract class for partial checkpoint loaders.

During state initialization, the order is as follows:

  1. Initialize the model params (model.init())

  2. Apply the init_transform.transform() to the state

  3. Initialize the optimizer

  4. Apply the init_transform.transform_after_optimizer() to the state

This order makes it possible to:

  • Have the optimizer depend on the pre-trained values (e.g. when using optax decay_to_init and ema_weight_wrapper transforms).

  • Restore the optimizer state from a pre-trained checkpoint.
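The initialization order above can be sketched in plain Python. This is a minimal, hypothetical illustration, not Kauldron's actual trainer code: AbstractPartialLoader is re-declared locally as a stand-in, and ToyState, init_state, and LoadPretrainedW are invented names. The "optimizer" here simply records the initial value of w, loosely mimicking a transform such as decay_to_init that depends on the starting params. The no-op default for transform_after_optimizer is also an assumption.

```python
import abc
import dataclasses
from typing import TypeVar

_T = TypeVar("_T")


class AbstractPartialLoader(abc.ABC):
  """Local stand-in for kd.ckpts.AbstractPartialLoader (not the real class)."""

  @abc.abstractmethod
  def transform(self, state: _T) -> _T:
    """Updates the state with pre-trained values (before the optimizer)."""

  def transform_after_optimizer(self, state: _T) -> _T:
    # Assumption: the default implementation leaves the state unchanged.
    return state


@dataclasses.dataclass(frozen=True)
class ToyState:  # Hypothetical training state.
  params: dict
  opt_state: dict


def init_state(init_transform: AbstractPartialLoader) -> ToyState:
  """Hypothetical helper following the documented initialization order."""
  # 1. Initialize the model params (stand-in for model.init()).
  state = ToyState(params={"w": 0.0}, opt_state={})
  # 2. Overwrite params with pre-trained values.
  state = init_transform.transform(state)
  # 3. Initialize the optimizer. It sees the pre-trained params, so it can
  #    record them (as a decay-to-init style optimizer would).
  state = dataclasses.replace(state, opt_state={"init_w": state.params["w"]})
  # 4. Give the loader a chance to restore the optimizer state too.
  return init_transform.transform_after_optimizer(state)


class LoadPretrainedW(AbstractPartialLoader):
  """Hypothetical loader that sets `w` to a fixed pre-trained value."""

  def transform(self, state: ToyState) -> ToyState:
    # Only an existing key is overwritten, so the structure is unchanged.
    return dataclasses.replace(state, params={**state.params, "w": 1.5})


state = init_state(LoadPretrainedW())
print(state.params["w"], state.opt_state["init_w"])  # 1.5 1.5
```

Because step 2 runs before step 3, the optimizer state is built from the pre-trained value (1.5) rather than the freshly initialized one (0.0).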

abstractmethod transform(state: kauldron.checkpoints.partial_loader._T) → kauldron.checkpoints.partial_loader._T

Transform the state by updating it with pre-trained values.

Notes:

  • transform may modify the state values but should NOT modify its structure, shapes, or dtypes.

  • transform should correctly propagate the sharding information from the given state.

Parameters:

state – The state object to transform

Returns:

The updated state
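As a sketch of the notes above, the hypothetical DictPartialLoader below copies values from an in-memory dict of NumPy arrays into matching keys of the state. It rejects unknown keys and mismatched shapes, and casts to the existing dtype, so the structure, shapes, and dtypes stay the same. The base class is a local stand-in and ToyState is a toy state; a real Kauldron loader would also have to preserve the sharding of JAX arrays, which plain NumPy cannot show.

```python
import abc
import dataclasses
from typing import TypeVar

import numpy as np

_T = TypeVar("_T")


class AbstractPartialLoader(abc.ABC):
  """Local stand-in for kd.ckpts.AbstractPartialLoader (not the real class)."""

  @abc.abstractmethod
  def transform(self, state: _T) -> _T:
    """Updates the state with pre-trained values."""


@dataclasses.dataclass(frozen=True)
class ToyState:  # Hypothetical state holding a flat dict of params.
  params: dict


@dataclasses.dataclass(frozen=True)
class DictPartialLoader(AbstractPartialLoader):
  """Hypothetical loader copying values from an in-memory dict of arrays."""

  pretrained: dict

  def transform(self, state: ToyState) -> ToyState:
    new_params = dict(state.params)
    for name, value in self.pretrained.items():
      if name not in state.params:
        # Adding a key would change the state structure.
        raise KeyError(f"{name!r} is not in the state")
      old = state.params[name]
      value = np.asarray(value)
      if value.shape != old.shape:
        raise ValueError(f"{name!r}: shape {value.shape} != {old.shape}")
      # Cast to the existing dtype so the dtypes are preserved.
      new_params[name] = value.astype(old.dtype)
    return dataclasses.replace(state, params=new_params)


state = ToyState(
    params={"w": np.zeros((2,), np.float32), "b": np.zeros((), np.float32)}
)
loader = DictPartialLoader(pretrained={"w": [1.0, 2.0]})  # float64 values
new_state = loader.transform(state)
print(new_state.params["w"], new_state.params["w"].dtype)  # [1. 2.] float32
```

Parameters absent from the pre-trained dict (here b) are carried over untouched, which is what makes the loading "partial".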

transform_after_optimizer(state: kauldron.checkpoints.partial_loader._T) → kauldron.checkpoints.partial_loader._T

Transformation applied after the optimizer has been initialized.

The transform method is called before the optimizer is initialized. This allows the optimizer to depend on the pre-trained values (e.g. when using optax decay_to_init and ema_weight_wrapper transforms).

However, the optimizer state sometimes also needs to be restored from a pre-trained checkpoint. This can be done in this method.

Parameters:

state – The state object to transform

Returns:

The updated state