kd.train.TrainState


class kauldron.train.TrainState(
*,
step: int,
params: _Params | None,
collections: _Collections | None,
opt_state: PyTree[Float['...']] | None,
)

Bases: kauldron.checkpoints.checkpoint_items.StandardCheckpointItem

Data structure holding the training state (step, model parameters, mutable collections and optimizer state) for checkpointing.

step

Current training step.

Type:

int

params

Model parameters.

Type:

Optional[_Params]

opt_state

Optimizer state.

Type:

Optional[PyTree[Float['...']]]

collections

Mutable Flax collections (e.g. 'batch_stats').

Type:

Optional[_Collections]

training_time_hours

Training time in hours.

step: int
params: _Params | None
collections: _Collections | None
opt_state: PyTree[Float['...']] | None
replace(**updates)

Returns a new object replacing the specified fields with new values.