kd.ckpts.Checkpointer

kd.ckpts.Checkpointer

class kauldron.checkpoints.Checkpointer(
*,
_fake_refs: type[_FakeRefsUnset] | dict[str,
_FakeRootCfg] = <class 'kauldron.utils.config_util._FakeRefsUnset'>,
workdir: epath.PathLike = _FakeRootCfg('cfg.workdir'),
save_interval_steps: int,
max_to_keep: Optional[int] = 3,
keep_time_interval: Optional[datetime.timedelta] = None,
keep_period: Optional[int] = None,
save_on_steps: Optional[Sequence[int]] = None,
best_metric_path: Optional[str] = None,
best_mode: str = 'max',
multiprocessing_options: ocp.options.MultiprocessingOptions = <factory>,
fast: bool = True,
create: bool = True,
)[source]

Bases: kauldron.checkpoints.checkpointer.BaseCheckpointer

Wrapper around Orbax CheckpointManager.

workdir

Root directory of the task

Type:

str | os.PathLike

save_interval_steps

See ocp.CheckpointManagerOptions

Type:

int

max_to_keep

See ocp.CheckpointManagerOptions

Type:

int | None

keep_time_interval

See ocp.CheckpointManagerOptions

Type:

datetime.timedelta | None

keep_period

See ocp.CheckpointManagerOptions

Type:

int | None

save_on_steps

See ocp.CheckpointManagerOptions

Type:

Sequence[int] | None

best_metric_path

Path to evaluator’s metric for best checkpoint selection. Warning: If using a best_metric_path, the evaluator must be run inside the train loop and cannot be run as a separate job.

Type:

str | None

best_mode

See ocp.CheckpointManagerOptions

Type:

str

multiprocessing_options

See ocp.MultiprocessingOptions

Type:

orbax.checkpoint.options.MultiprocessingOptions

fast

(internal) Activate some optimizations

Type:

bool

create

(internal) Whether to create the checkpoint directory. This is set by Kauldron automatically based on whether the job is a training job (True) or an eval job (False).

Type:

bool

workdir: epath.PathLike = _FakeRootCfg('cfg.workdir')
save_interval_steps: int
max_to_keep: int | None = 3
keep_time_interval: datetime.timedelta | None = None
keep_period: int | None = None
save_on_steps: Sequence[int] | None = None
best_metric_path: str | None = None
best_mode: str = 'max'
multiprocessing_options: ocp.options.MultiprocessingOptions
fast: bool = True
create: bool = True
restore(
state: kauldron.checkpoints.checkpointer._StateT,
*,
step: int = -1,
noop_if_missing: bool = False,
donate: bool = True,
) kauldron.checkpoints.checkpointer._StateT[source]

Restore state.

Parameters:
  • state – The state object initialized from the trainer. If the state is not known, you can pass kd.ckpts.items.StandardCheckpointItem() to restore the nested dict of weights.

  • step – The training step of the checkpoint to restore. -1 means last step.

  • noop_if_missing – If False, raises an error when no checkpoint is found.

  • donate – Whether to delete the initial state (the state argument) to free up memory when restoring the checkpoint. This avoids 2x memory consumption. It is safe to donate the initial state if you no longer need it after restoring.

Returns:

The restored state.

Raises:

FileNotFoundError – Raised when no checkpoint is found and noop_if_missing is False.

should_save(step: int) bool[source]
delete(step: int) None[source]
save(
state: kauldron.checkpoints.checkpoint_items.CheckpointItem,
*,
step: int,
force: bool = False,
metrics: Any | None = None,
) bool[source]

Save state.

maybe_save(
state,
*,
step: int,
force: bool = False,
) bool[source]

Save state.

property latest_step: int | None
property all_steps: Sequence[int]
reload() None[source]

Refresh the cache.

For performance, the checkpointer caches the directory names. Calling this function resets the cache to allow scanning the checkpoint directory for new checkpoints.

item_metadata(
step: int = -1,
) dict[str, Any][source]

Returns the metadata (tree, shape, …) associated with the step.

iter_new_checkpoints(
*,
min_interval_secs: int = 0,
timeout: int | None = None,
timeout_fn: collections.abc.Callable[[], bool] | None = None,
) collections.abc.Iterator[int][source]

Wrapper around ocp.checkpoint_utils.checkpoints_iterator.

wait_until_finished() None[source]

Synchronizes the asynchronous checkpointing.

close() None[source]

Closes the checkpointer.