kd.train.TrainState
- class kauldron.train.TrainState(*, step: int, params: _Params | None, collections: _Collections | None, opt_state: PyTree[Float['...']] | None)
Bases: kauldron.checkpoints.checkpoint_items.StandardCheckpointItem
Data structure for checkpointing the model.
- step (int): Current training step.
- params (Optional[_Params]): Model parameters.
- opt_state (Optional[PyTree[Float['...']]]): Optimizer state.
- collections (Optional[_Collections]): Mutable Flax collections (e.g. 'batch_stats').
- training_time_hours: Training time in hours.
- replace(**updates)
Returns a new object replacing the specified fields with new values.