kd.ckpts.items.StandardCheckpointItem
class kauldron.checkpoints.checkpoint_items.StandardCheckpointItem
Bases: kauldron.checkpoints.checkpoint_items.CheckpointItem
Standard checkpoint item, for an arbitrary pytree of jax.Array values.
Inheriting from this class adds checkpointing support to your state class. Usage:
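
    import flax
    import jax
    from kauldron.checkpoints.checkpoint_items import StandardCheckpointItem

    @flax.struct.dataclass
    class MyState(StandardCheckpointItem):
      # `Tree` is the pytree annotation from the original example (its import is not shown).
      params: Tree[jax.Array]
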
Passing this base class to Checkpointer.restore makes it possible to restore the state without knowing its structure.
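To make the "restore without knowing the structure" idea concrete, here is a minimal, self-contained sketch. It is not Kauldron's implementation and does not use its Checkpointer: the class `SketchCheckpointItem`, its `save`/`restore` methods, and the choice of NumPy `.npz` files as storage are all hypothetical, chosen only to illustrate the pattern. Restoring through a subclass rebuilds the typed state object, while restoring through the base class returns the raw nested dict of arrays.

```python
import dataclasses
import os
import tempfile

import numpy as np


def _flatten(tree, prefix=""):
    """Flattens a nested dict of arrays into {"a/b": array} form."""
    flat = {}
    for key, value in tree.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(_flatten(value, prefix=path + "/"))
        else:
            flat[path] = np.asarray(value)
    return flat


def _unflatten(flat):
    """Inverse of `_flatten`: rebuilds the nested dict from "a/b" keys."""
    tree = {}
    for path, value in flat.items():
        *parents, leaf = path.split("/")
        node = tree
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return tree


class SketchCheckpointItem:
    """Hypothetical stand-in for a standard checkpoint item (not Kauldron's API)."""

    def save(self, path):
        # Dataclass fields become top-level keys; nested dicts become "a/b" keys.
        np.savez(path, **_flatten(dataclasses.asdict(self)))

    @classmethod
    def restore(cls, path):
        with np.load(path) as data:
            tree = _unflatten({k: data[k] for k in data.files})
        if cls is SketchCheckpointItem:
            # Called on the base class: structure unknown, return the raw tree.
            return tree
        # Called on a subclass: rebuild the concrete state object.
        return cls(**tree)


@dataclasses.dataclass
class MyState(SketchCheckpointItem):
    params: dict
    step: np.ndarray


state = MyState(
    params={"dense": {"w": np.ones((2, 3)), "b": np.zeros(3)}},
    step=np.array(10),
)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ckpt.npz")
    state.save(path)
    typed = MyState.restore(path)  # structure known: returns a MyState
    raw = SketchCheckpointItem.restore(path)  # structure-agnostic: nested dict
```

The design choice mirrored here is that the saved file carries enough structure (the nested key paths) to be reloaded on its own, so the concrete class is only needed when you want a typed object back.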