kd.ckpts.PartialKauldronLoader
- class kauldron.checkpoints.PartialKauldronLoader(
- *,
- workdir: str | os.PathLike,
- new_to_old: collections.abc.MutableMapping[str, str] = <factory>,
- step: int = -1,
- )
Bases: etils.epy.contextlib.ContextManager, kauldron.checkpoints.partial_loader.AbstractPartialLoader
Partial loader for Kauldron checkpoints.
Allows reusing pretrained weights from another Kauldron checkpoint.
Usage:
    cfg.init_transform = kd.ckpts.PartialKauldronLoader(
        workdir='/path/to/original/work_unit/',
        new_to_old={  # Mapping params
            # '<new_path>': '<source_path>'
            'params.decoder.layers_0': 'params.encoder',
        },
    )

    trainer = konfig.resolve(cfg)

    # When initializing the weights, the `init_transform` is applied
    init_state = trainer.init_state()

    # `init_state.params['decoder']['layers_0']` now contains the previous
    # encoder weights
- workdir
The work directory from which the checkpoint should be loaded (can be created from kd.ckpts.workdir_from_xid).
- Type:
str | os.PathLike
- new_to_old
Mapping from paths in the new state to the paths in the original checkpoint whose values are copied there. By default, all model params and collections are copied.
- Type:
collections.abc.MutableMapping[str, str]
- step
Which step to load (defaults to the last one).
- Type:
int
- workdir: str | os.PathLike
- new_to_old: collections.abc.MutableMapping[str, str]
- step: int = -1
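The new_to_old attribute described above maps a dotted path in the new state to a dotted path in the source checkpoint. The following is a minimal sketch of that idea on plain nested dicts. It is not Kauldron's implementation: the helpers `get_path`, `set_path` and `apply_new_to_old` are hypothetical names introduced for illustration, and lists stand in for arrays.

```python
import copy


def get_path(tree, path):
    """Return the subtree at a dotted path such as 'params.encoder'."""
    node = tree
    for key in path.split("."):
        node = node[key]
    return node


def set_path(tree, path, value):
    """Replace the subtree at a dotted path; the path must already exist."""
    *parents, last = path.split(".")
    node = tree
    for key in parents:
        node = node[key]
    if last not in node:
        raise KeyError(f"{path!r} is not present in the new state")
    node[last] = value


def apply_new_to_old(new_state, old_state, new_to_old):
    """Copy old_state[source_path] into new_state[new_path] for each entry."""
    result = copy.deepcopy(new_state)  # leave the caller's state untouched
    for new_path, old_path in new_to_old.items():
        set_path(result, new_path, copy.deepcopy(get_path(old_state, old_path)))
    return result


# Hypothetical toy states (not real Kauldron train states).
old_state = {"params": {"encoder": {"kernel": [1.0, 2.0]}}}
new_state = {"params": {"decoder": {"layers_0": {"kernel": [0.0, 0.0]}}}}

updated = apply_new_to_old(
    new_state, old_state, {"params.decoder.layers_0": "params.encoder"}
)
print(updated["params"]["decoder"]["layers_0"])  # {'kernel': [1.0, 2.0]}
```

Requiring the target path to exist already keeps the structure of the new state fixed, which matches the constraint that a loader should only replace values.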
- transform(
- state: kauldron.checkpoints.partial_loader._T,
- )
Transform the state by updating it with pre-trained values.
Notes:
- transform functions can modify the state values but should NOT modify its structure, shape or dtypes.
- transform should correctly propagate the sharding information from the given state.
- Parameters:
state – The state object to transform
- Returns:
The updated state
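The rule that transform may change values but not structure, shape or dtypes can be checked mechanically. The sketch below is one possible check written in plain Python, assuming the state is a nested dict whose leaves expose `shape` and `dtype` attributes (as NumPy and JAX arrays do). `FakeArray`, `leaf_specs` and `check_same_specs` are hypothetical names, not part of Kauldron, and the check does not cover sharding.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeArray:
    """Hypothetical stand-in for an array leaf; only shape and dtype matter."""
    shape: tuple
    dtype: str


def leaf_specs(tree, prefix=""):
    """Flatten a nested dict into {dotted_path: (shape, dtype)}."""
    if isinstance(tree, dict):
        specs = {}
        for key, value in tree.items():
            specs.update(leaf_specs(value, f"{prefix}{key}."))
        return specs
    return {prefix.rstrip("."): (tuple(tree.shape), str(tree.dtype))}


def check_same_specs(before, after):
    """Raise ValueError if the structure, shapes or dtypes differ."""
    old, new = leaf_specs(before), leaf_specs(after)
    if old.keys() != new.keys():
        raise ValueError(f"structure changed: {sorted(old.keys() ^ new.keys())}")
    changed = [path for path in old if old[path] != new[path]]
    if changed:
        raise ValueError(f"shape or dtype changed at: {changed}")


before = {"params": {"encoder": FakeArray((4, 8), "float32")}}
same = {"params": {"encoder": FakeArray((4, 8), "float32")}}
wrong_dtype = {"params": {"encoder": FakeArray((4, 8), "bfloat16")}}

check_same_specs(before, same)  # passes silently
try:
    check_same_specs(before, wrong_dtype)
except ValueError as err:
    print(err)
```

Running such a check on the state before and after a custom transform is a cheap way to catch an accidental cast or reshape early.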
- transform_after_optimizer(
- state: kauldron.checkpoints.partial_loader._T,
- )
Transformation applied after the optimizer has been restored.
The transform method is called before the optimizer has been restored. This allows the optimizer to depend on the pre-trained values (e.g. when using the optax decay_to_init and ema_weight_wrapper transforms).
However, sometimes the optimizer state also needs to be restored from a pre-trained checkpoint. This method makes that possible.
- Parameters:
state – The state object to transform
- Returns:
The updated state
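The ordering of the two hooks described above can be illustrated with a toy example. This is only a sketch under stated assumptions: `RecordingLoader` and `init_state` are hypothetical, the state is a plain dict, and the "optimizer" merely remembers the initial weight. It does not reproduce Kauldron's trainer code.

```python
class RecordingLoader:
    """Hypothetical loader that records the order in which its hooks run."""

    def __init__(self):
        self.calls = []

    def transform(self, state):
        # Runs first: overwrite params with (pretend) pre-trained values.
        self.calls.append("transform")
        return {**state, "params": {"w": 1.0}}

    def transform_after_optimizer(self, state):
        # Runs last: overwrite part of the optimizer state.
        self.calls.append("transform_after_optimizer")
        opt_state = {**state["opt_state"], "momentum": 0.5}
        return {**state, "opt_state": opt_state}


def init_state(loader):
    """Toy initialization: transform, then optimizer setup, then the second hook."""
    state = {"params": {"w": 0.0}, "opt_state": None}
    state = loader.transform(state)
    # The optimizer state is built from the already transformed params,
    # so it sees the pre-trained value rather than the random init.
    opt_state = {"momentum": 0.0, "init_w": state["params"]["w"]}
    state = {**state, "opt_state": opt_state}
    return loader.transform_after_optimizer(state)


loader = RecordingLoader()
state = init_state(loader)
print(loader.calls)
print(state["opt_state"])
```

In this sketch the optimizer records `init_w` as the pre-trained value because transform ran first, while the momentum is replaced afterwards by transform_after_optimizer.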
- close() -> None