kd.ckpts.AbstractPartialLoader
- class kauldron.checkpoints.AbstractPartialLoader
Bases: abc.ABC
Abstract class for partial checkpoint loaders.
During state initialization, the order is as follows:
1. Initialize the model params (model.init()).
2. Apply init_transform.transform() to the state.
3. Initialize the optimizer.
4. Apply init_transform.transform_after_optimizer() to the state.
This order allows:
- The optimizer to depend on the pre-trained values (e.g. when using the optax decay_to_init and ema_weight_wrapper transforms).
- The optimizer state to be restored from a pre-trained checkpoint.
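The initialization order above can be sketched in plain Python. This is a minimal sketch, not Kauldron code: PartialLoader, State, init_state, PretrainedParams and init_optimizer are hypothetical stand-ins, params are plain floats instead of arrays, and the default of transform_after_optimizer (returning the state unchanged) is an assumption. It shows why the optimizer sees the pre-trained values: it is initialized after transform has run.

```python
import abc
import dataclasses
from typing import Any, Callable, Dict, Optional, TypeVar

_T = TypeVar("_T")


class PartialLoader(abc.ABC):
  """Hypothetical stand-in mirroring the documented AbstractPartialLoader."""

  @abc.abstractmethod
  def transform(self, state: _T) -> _T:
    """Update state values with pre-trained ones."""

  def transform_after_optimizer(self, state: _T) -> _T:
    # Assumed default: leave the state unchanged.
    return state


@dataclasses.dataclass(frozen=True)
class State:
  params: Dict[str, float]
  opt_state: Optional[Dict[str, Any]] = None


def init_state(
    init_params: Callable[[], Dict[str, float]],
    init_optimizer: Callable[[Dict[str, float]], Dict[str, Any]],
    loader: PartialLoader,
) -> State:
  """Hypothetical helper following the documented initialization order."""
  state = State(params=init_params())  # 1. stand-in for model.init()
  state = loader.transform(state)  # 2. load pre-trained values
  # 3. The optimizer is initialized from the already-transformed params.
  state = dataclasses.replace(state, opt_state=init_optimizer(state.params))
  return loader.transform_after_optimizer(state)  # 4. optional extra step


class PretrainedParams(PartialLoader):
  """Hypothetical loader overwriting params whose names are in `pretrained`."""

  def __init__(self, pretrained: Dict[str, float]):
    self.pretrained = pretrained

  def transform(self, state: State) -> State:
    # Iterate over the state's own keys so the structure never changes.
    params = {k: self.pretrained.get(k, v) for k, v in state.params.items()}
    return dataclasses.replace(state, params=params)


def init_optimizer(params: Dict[str, float]) -> Dict[str, Any]:
  # Toy optimizer state that snapshots the params it starts from, loosely
  # similar in spirit to a decay-to-init anchor.
  return {"anchor": dict(params), "step": 0}


state = init_state(
    lambda: {"w": 0.0, "b": 0.0}, init_optimizer, PretrainedParams({"w": 0.5})
)
print(state.opt_state["anchor"])  # {'w': 0.5, 'b': 0.0}
```

Because step 3 runs after step 2, the toy optimizer's anchor holds the pre-trained value of w rather than the random initialization.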
- abstractmethod transform(state: kauldron.checkpoints.partial_loader._T)
Transform the state by updating it with pre-trained values.
Notes:
transform functions can modify the state values but should NOT modify the state's structure, shapes or dtypes. transform should also correctly propagate the sharding information from the given state.
- Parameters:
state – The state object to transform
- Returns:
The updated state
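The constraints in the notes above (change values only, keep structure, shapes and dtypes) can be illustrated with a small sketch. This is an assumption-laden stand-in, not a Kauldron implementation: PartialLoader, State, signature and LoadMatching are hypothetical names, and stdlib array.array stands in for device arrays (its length plays the role of the shape, its typecode the role of the dtype). Sharding propagation is out of scope for this stdlib sketch.

```python
import abc
import array
import dataclasses
from typing import Dict, Tuple, TypeVar

_T = TypeVar("_T")


class PartialLoader(abc.ABC):
  """Hypothetical stand-in for the documented AbstractPartialLoader."""

  @abc.abstractmethod
  def transform(self, state: _T) -> _T:
    """Return a state with the same structure but pre-trained values."""


@dataclasses.dataclass(frozen=True)
class State:
  # array.array stands in for device arrays: len() ~ shape, typecode ~ dtype.
  params: Dict[str, array.array]


def signature(params: Dict[str, array.array]) -> Dict[str, Tuple[int, str]]:
  """Structure (keys), shape (length) and dtype (typecode) of each leaf."""
  return {name: (len(a), a.typecode) for name, a in params.items()}


class LoadMatching(PartialLoader):
  """Hypothetical loader copying pre-trained values for matching names."""

  def __init__(self, pretrained: Dict[str, array.array]):
    self.pretrained = pretrained

  def transform(self, state: State) -> State:
    new_params = {}
    for name, current in state.params.items():
      source = self.pretrained.get(name)
      if source is None:
        new_params[name] = current  # keep the freshly initialized value
        continue
      if len(source) != len(current):
        raise ValueError(f"{name}: shape {len(source)} != {len(current)}")
      # Re-create the array with the state's typecode to keep its dtype.
      new_params[name] = array.array(current.typecode, list(source))
    new_state = dataclasses.replace(state, params=new_params)
    # Values may change; structure, shapes and dtypes must not.
    assert signature(new_state.params) == signature(state.params)
    return new_state


state = State(params={
    "encoder": array.array("f", [0.0, 0.0]),
    "head": array.array("f", [0.0]),
})
pretrained = {"encoder": array.array("d", [0.5, 0.25])}
new_state = LoadMatching(pretrained).transform(state)
print(new_state.params["encoder"])  # array('f', [0.5, 0.25])
```

The pre-trained encoder values arrive as double precision ("d") but are cast back to the state's single-precision typecode, and the head, which has no pre-trained counterpart, keeps its initial value.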
- transform_after_optimizer(state: kauldron.checkpoints.partial_loader._T)
Transformation applied after the optimizer has been restored.
The transform method is called before the optimizer has been restored. This allows the optimizer to depend on the pre-trained values (e.g. when using the optax decay_to_init and ema_weight_wrapper transforms). However, sometimes the optimizer state also needs to be restored from a pre-trained checkpoint; this method is where that can be done.
- Parameters:
state – The state object to transform
- Returns:
The updated state
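A loader that restores both the params and the optimizer state might split the work across the two methods as in the sketch below. It is a duck-typed illustration under stated assumptions, not Kauldron code: State and RestoreParamsAndOptimizer are hypothetical names, the checkpoint is assumed to be already loaded in memory, and the optimizer state is a toy dict with a step counter and a momentum entry.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class State:
  params: Dict[str, float]
  opt_state: Dict[str, Any]


class RestoreParamsAndOptimizer:
  """Hypothetical loader that restores params, then optimizer state."""

  def __init__(self, checkpoint: State):
    # Stand-in for a pre-trained checkpoint already loaded into memory.
    self.checkpoint = checkpoint

  def transform(self, state: State) -> State:
    # Called before the optimizer exists: only params are restored here,
    # iterating over the state's own keys to keep its structure.
    params = {k: self.checkpoint.params.get(k, v)
              for k, v in state.params.items()}
    return dataclasses.replace(state, params=params)

  def transform_after_optimizer(self, state: State) -> State:
    # Called once the optimizer state exists: overwrite its values while
    # refusing to change its structure.
    if state.opt_state.keys() != self.checkpoint.opt_state.keys():
      raise ValueError("optimizer state structure mismatch")
    return dataclasses.replace(state, opt_state=dict(self.checkpoint.opt_state))


checkpoint = State(params={"w": 0.5}, opt_state={"step": 1000, "mu": {"w": 0.01}})
loader = RestoreParamsAndOptimizer(checkpoint)

state = State(params={"w": 0.0}, opt_state={})
state = loader.transform(state)
# Toy optimizer init, computed from the already-restored params.
state = dataclasses.replace(
    state, opt_state={"step": 0, "mu": {k: 0.0 for k in state.params}}
)
state = loader.transform_after_optimizer(state)
print(state.opt_state)  # {'step': 1000, 'mu': {'w': 0.01}}
```

Keeping the optimizer restore in a separate method means the freshly initialized optimizer state (step 0, zero momentum) is only a placeholder that gets replaced, while optimizers that merely need the pre-trained params can rely on transform alone.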