kd.ckpts.PartialKauldronLoader

class kauldron.checkpoints.PartialKauldronLoader(
    *,
    workdir: str | os.PathLike,
    new_to_old: collections.abc.MutableMapping[str, str] = <factory>,
    step: int = -1,
)

Bases: etils.epy.contextlib.ContextManager, kauldron.checkpoints.partial_loader.AbstractPartialLoader

Partial loader for Kauldron checkpoints.

Allows using pretrained weights from another Kauldron checkpoint.

Usage:

cfg.init_transform = kd.ckpts.PartialKauldronLoader(
    workdir='/path/to/original/work_unit/',
    new_to_old={  # Mapping params
        # '<new_path>':            '<source_path>'
        'params.decoder.layers_0': 'params.encoder',
    },
)

trainer = konfig.resolve(cfg)

# When initializing the weights, the `init_transform` is applied
init_state = trainer.init_state()

# `init_state.params['decoder']['layers_0']` now contains the previous encoder
# weights

workdir

The work directory from which the checkpoint should be loaded (can be created with kd.ckpts.workdir_from_xid).

Type:

str | os.PathLike

new_to_old

Mapping from paths in the new state to the paths in the original checkpoint whose pytrees should be copied there. By default, all model params and collections are copied.

Type:

collections.abc.MutableMapping[str, str]
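
To make the new_to_old semantics concrete, here is a minimal sketch in plain Python that uses nested dicts in place of real pytrees. The helpers get_path, set_path and apply_new_to_old are hypothetical names introduced for illustration only; they are not part of Kauldron, and the actual loader may resolve paths differently.

```python
import copy


def get_path(tree, path):
    """Return the subtree found by following a dot-separated path."""
    node = tree
    for key in path.split("."):
        node = node[key]
    return node


def set_path(tree, path, value):
    """Overwrite the subtree at a dot-separated path (the path must exist)."""
    *parents, last = path.split(".")
    node = tree
    for key in parents:
        node = node[key]
    if last not in node:
        raise KeyError(f"{path!r} does not exist in the new state")
    node[last] = value


def apply_new_to_old(new_state, old_state, new_to_old):
    """Copy old_state[old_path] into new_state[new_path] for each entry."""
    result = copy.deepcopy(new_state)  # Leave the inputs untouched.
    for new_path, old_path in new_to_old.items():
        set_path(result, new_path, copy.deepcopy(get_path(old_state, old_path)))
    return result


# Toy states: plain lists stand in for weight arrays.
old_state = {"params": {"encoder": {"w": [1.0, 2.0]}}}
new_state = {
    "params": {
        "decoder": {
            "layers_0": {"w": [0.0, 0.0]},
            "layers_1": {"w": [0.0, 0.0]},
        }
    }
}

restored = apply_new_to_old(
    new_state,
    old_state,
    {"params.decoder.layers_0": "params.encoder"},
)
```

In this sketch only the subtree named on the left-hand side of the mapping is overwritten; every other entry of the new state keeps its freshly initialized value.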

step

Which step to load (defaults to the last one).

Type:

int

workdir: str | os.PathLike
new_to_old: collections.abc.MutableMapping[str, str]
step: int = -1

transform(
    state: kauldron.checkpoints.partial_loader._T,
) -> kauldron.checkpoints.partial_loader._T

Transform the state by updating it with pre-trained values.

Notes:

  • transform functions can modify the state values but should NOT modify its structure, shape or dtypes.

  • transform should correctly propagate the sharding information from the given state.

Parameters:

state – The state object to transform

Returns:

The updated state
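
The first note above can be checked mechanically. The following is a rough sketch, not Kauldron code: assert_same_spec is a hypothetical helper that walks two nested-dict states and compares the shape and dtype attributes of their leaves. It assumes dict-based trees and does not check sharding.

```python
from types import SimpleNamespace


def assert_same_spec(before, after, path="state"):
    """Raise ValueError if `after` differs in structure, shape or dtype."""
    if isinstance(before, dict):
        if not isinstance(after, dict) or before.keys() != after.keys():
            raise ValueError(f"{path}: tree structure changed")
        for key in before:
            assert_same_spec(before[key], after[key], f"{path}.{key}")
        return
    # Leaf: compare only the metadata, never the values themselves.
    for attr in ("shape", "dtype"):
        if getattr(before, attr, None) != getattr(after, attr, None):
            raise ValueError(f"{path}: {attr} changed")


# Stand-ins for arrays: only `shape` and `dtype` matter for this check.
before = {"params": {"w": SimpleNamespace(shape=(2, 3), dtype="float32")}}
same = {"params": {"w": SimpleNamespace(shape=(2, 3), dtype="float32")}}
bad = {"params": {"w": SimpleNamespace(shape=(3, 2), dtype="float32")}}

assert_same_spec(before, same)  # Same structure, shape and dtype: no error.

try:
    assert_same_spec(before, bad)
except ValueError as e:
    error_message = str(e)
```

A check of this kind could be run on the state before and after a custom transform while debugging; the values are free to differ, only the metadata has to match.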

transform_after_optimizer(
    state: kauldron.checkpoints.partial_loader._T,
) -> kauldron.checkpoints.partial_loader._T

Transformation applied after the optimizer has been restored.

The transform method is called before the optimizer has been restored. This allows the optimizer to depend on the pre-trained values (e.g. when using optax decay_to_init and ema_weight_wrapper transforms).

However, the optimizer state sometimes also needs to be restored from a pre-trained checkpoint. This method can be used to do so.

Parameters:

state – The state object to transform

Returns:

The updated state
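
The ordering of the two hooks can be illustrated with a toy sketch. Everything here is hypothetical and is not Kauldron's implementation: ToyLoader and init_optimizer are stand-ins, and the optimizer simply snapshots the params it sees, mimicking a transform that depends on the initial values.

```python
class ToyLoader:
    """Hypothetical loader exposing the two hooks described above."""

    def __init__(self, pretrained_params, pretrained_count):
        self.pretrained_params = pretrained_params
        self.pretrained_count = pretrained_count

    def transform(self, state):
        # Runs first: only the params are replaced with pre-trained values.
        return {**state, "params": dict(self.pretrained_params)}

    def transform_after_optimizer(self, state):
        # Runs last: the optimizer state exists now, so one of its values
        # can be overwritten without changing its structure.
        opt_state = {**state["opt_state"], "count": self.pretrained_count}
        return {**state, "opt_state": opt_state}


def init_optimizer(state):
    # Toy stand-in for the optimizer set-up step: it snapshots the params
    # it is given, so it must see the pre-trained values.
    opt_state = {"init_params": dict(state["params"]), "count": 0}
    return {**state, "opt_state": opt_state}


loader = ToyLoader(pretrained_params={"w": 1.5}, pretrained_count=1000)

state = {"params": {"w": 0.0}}
state = loader.transform(state)                  # 1. Load pre-trained params.
state = init_optimizer(state)                    # 2. Optimizer sees them.
state = loader.transform_after_optimizer(state)  # 3. Patch optimizer state.
```

Because step 1 runs before step 2 in this sketch, the optimizer's snapshot holds the pre-trained weights rather than the random initialization, and step 3 only touches values that the optimizer has already created.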

close() -> None