kd.ckpts.items.StandardCheckpointItem


class kauldron.checkpoints.checkpoint_items.StandardCheckpointItem

Bases: kauldron.checkpoints.checkpoint_items.CheckpointItem

Standard checkpoint item for arbitrary pytrees of jax.Array.

Inheriting from this class adds checkpointing support. Usage:

@flax.struct.dataclass
class MyState(StandardCheckpointItem):
  params: Tree[jax.Array]

Passing this base class to Checkpointer.restore allows restoring the state without knowing its structure in advance.
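The two restore modes can be pictured with a stdlib-only toy model. Everything below is hypothetical: the classes are stand-ins that only reuse the documented names, and `save_item` and `restore_item` are invented helpers, not Kauldron's Checkpointer API. The sketch assumes that a checkpoint stores the item as a plain nested dict. Under that assumption, passing the base class can only return the raw tree, while passing a concrete subclass lets its fields be rebuilt.

```python
import dataclasses
from typing import Any


class CheckpointItem:
    """Toy stand-in for the base checkpoint-item class (hypothetical)."""


class StandardCheckpointItem(CheckpointItem):
    """Toy stand-in: subclasses are dataclasses whose fields hold nested trees."""


@dataclasses.dataclass(frozen=True)
class MyState(StandardCheckpointItem):
    # A plain dict plays the role of the jax.Array pytree in this toy.
    params: dict[str, Any]


def save_item(item: StandardCheckpointItem) -> dict[str, Any]:
    # Hypothetical helper: flatten the dataclass fields into a plain nested
    # dict, i.e. a structure-free payload like one stored on disk.
    return dataclasses.asdict(item)


def restore_item(target: type, payload: dict[str, Any]) -> Any:
    # Hypothetical helper mirroring the two restore modes described above.
    if target is StandardCheckpointItem:
        # Base class passed: the concrete structure is unknown, so the raw
        # nested dict is returned unchanged.
        return payload
    if issubclass(target, StandardCheckpointItem):
        # Concrete subclass passed: rebuild the typed state from the payload.
        return target(**payload)
    raise TypeError(f"Unsupported restore target: {target!r}")


payload = save_item(MyState(params={"w": [1.0, 2.0], "b": 0.5}))
raw = restore_item(StandardCheckpointItem, payload)  # plain nested dict
typed = restore_item(MyState, payload)  # MyState instance
```

In this toy, restoring with the base class needs no knowledge of `MyState`, which is convenient for inspecting a checkpoint, at the cost of getting back untyped data.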