kd.train.TrainStep
- class kauldron.train.TrainStep(
      *,
      _fake_refs: type[_FakeRefsUnset] | dict[str, _FakeRootCfg] = <class 'kauldron.utils.config_util._FakeRefsUnset'>,
      model: nn.Module = _FakeRootCfg('cfg.model'),
      optimizer: optax.GradientTransformation = _FakeRootCfg('cfg.optimizer'),
      rng_streams: rngs_lib.RngStreams = _FakeRootCfg('cfg.rng_streams'),
      sharding: sharding_lib.ShardingStrategy = _FakeRootCfg('cfg.sharding'),
      init_transform: partial_loader.AbstractPartialLoader = _FakeRootCfg('cfg.init_transform'),
      aux: auxiliaries.Auxiliaries = <factory>,
  )
Bases: kauldron.utils.config_util.UpdateFromRootCfgBase

Training step.
Subclasses can override the _step method to implement custom training steps.
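A minimal sketch of the override pattern described above, assuming (as that sentence suggests) that the public step method delegates the actual computation to _step. ToyTrainStep, ClippedTrainStep, the dict-based state and the _step(state, batch) signature are hypothetical stand-ins, not Kauldron's API; check the TrainStep source for the real _step signature before overriding it.

```python
from dataclasses import dataclass


@dataclass
class ToyTrainStep:
    """Hypothetical stand-in for TrainStep: `step` wraps the overridable `_step`."""

    learning_rate: float = 0.1

    def step(self, state: dict, batch: list[float]) -> dict:
        # Public entry point: shared bookkeeping (step counter) lives here,
        # the actual update logic lives in `_step`.
        new_state = self._step(state, batch)
        new_state["step"] = state["step"] + 1
        return new_state

    def _step(self, state: dict, batch: list[float]) -> dict:
        # Default update: one SGD step on the loss (w - mean(batch))**2.
        mean = sum(batch) / len(batch)
        grad = 2.0 * (state["w"] - mean)
        return {**state, "w": state["w"] - self.learning_rate * grad}


@dataclass
class ClippedTrainStep(ToyTrainStep):
    """Custom step: only `_step` is overridden, `step` is inherited unchanged."""

    max_grad: float = 1.0

    def _step(self, state: dict, batch: list[float]) -> dict:
        mean = sum(batch) / len(batch)
        grad = 2.0 * (state["w"] - mean)
        # Clip the gradient to [-max_grad, max_grad] before updating.
        grad = max(-self.max_grad, min(self.max_grad, grad))
        return {**state, "w": state["w"] - self.learning_rate * grad}
```

Keeping the bookkeeping in step and the math in _step means a subclass changes only the update rule while inheriting everything else.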
- model: nn.Module = _FakeRootCfg('cfg.model')
- optimizer: optax.GradientTransformation = _FakeRootCfg('cfg.optimizer')
- rng_streams: rngs_lib.RngStreams = _FakeRootCfg('cfg.rng_streams')
- sharding: sharding_lib.ShardingStrategy = _FakeRootCfg('cfg.sharding')
- init_transform: partial_loader.AbstractPartialLoader = _FakeRootCfg('cfg.init_transform')
- aux: auxiliaries.Auxiliaries
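The _FakeRootCfg('cfg.model') style defaults suggest that these fields act as placeholders that are filled in from the matching attribute of the root config (for example cfg.model) instead of being set by hand, while aux is an ordinary field with a default factory. The sketch below mimics that idea with plain dataclasses; FakeRef, ToyStep, RootCfg and resolve_from_root are hypothetical names and this is not Kauldron's implementation.

```python
import dataclasses
from dataclasses import dataclass, field
from typing import Any


@dataclass(frozen=True)
class FakeRef:
    """Hypothetical placeholder pointing at a dotted path of the root config."""

    path: str  # e.g. "cfg.model"


@dataclass(frozen=True)
class ToyStep:
    model: Any = FakeRef("cfg.model")
    optimizer: Any = FakeRef("cfg.optimizer")
    aux: dict = field(default_factory=dict)  # ordinary field, not a placeholder


@dataclass(frozen=True)
class RootCfg:
    model: str = "my_model"
    optimizer: str = "adam"


def resolve_from_root(obj, root_cfg):
    """Return a copy of `obj` with every FakeRef field replaced by the root value."""
    updates = {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if isinstance(value, FakeRef):
            target = root_cfg
            # Walk the dotted path, skipping the leading "cfg" prefix.
            for attr in value.path.split(".")[1:]:
                target = getattr(target, attr)
            updates[f.name] = target
    # Fields that were set explicitly (not FakeRef) are left untouched.
    return dataclasses.replace(obj, **updates)
```

For example, resolve_from_root(ToyStep(), RootCfg()).model gives "my_model", while a value passed explicitly, as in ToyStep(model="explicit"), is kept as-is.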
- init(
      elem_spec: ElementSpec,
      *,
      model_method: str | None = None,
      skip_transforms: bool = False,
      skip_optimizer: bool = False,
  )
Initialize the model and return the initial TrainState.
- Parameters:
elem_spec – Structure of the input batch.
model_method – Name of the Flax model method (defaults to __call__).
skip_transforms – If False, apply init_transform to the state (e.g. to overwrite the weights with ones loaded from another checkpoint).
skip_optimizer – If True, do not initialize the optimizer.
- Returns:
The training state
- Return type:
TrainState
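A rough, framework-free sketch of the order of operations the init docstring describes: build parameters from the input spec, initialize the optimizer state unless skip_optimizer=True, then apply init_transform unless skip_transforms=True. ToyState, init_params, toy_init, the dict-based spec and the callable init_transform are all hypothetical stand-ins; the real method works with Flax modules, optax optimizers and Kauldron's TrainState.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class ToyState:
    step: int
    params: dict
    opt_state: Optional[dict]


def init_params(elem_spec: dict) -> dict:
    # Hypothetical "model init": one zero weight per input feature in the spec.
    return {"w": [0.0] * elem_spec["num_features"], "b": 0.0}


def toy_init(
    elem_spec: dict,
    *,
    init_transform: Callable[[ToyState], ToyState] = lambda s: s,
    skip_transforms: bool = False,
    skip_optimizer: bool = False,
) -> ToyState:
    params = init_params(elem_spec)
    # Optimizer state mirrors the params (e.g. a momentum buffer), unless skipped.
    if skip_optimizer:
        opt_state = None
    else:
        opt_state = {"momentum": {"w": [0.0] * len(params["w"]), "b": 0.0}}
    state = ToyState(step=0, params=params, opt_state=opt_state)
    # skip_transforms=False means the transform IS applied
    # (e.g. to overwrite weights with pretrained ones).
    if not skip_transforms:
        state = init_transform(state)
    return state
```

Note the inverted flag: the transform runs by default, and skip_transforms=True is the opt-out, which matches the parameter description above.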
- step(
      state: TrainState,
      batch: PyTree[Any],
      *,
      return_losses: bool = False,
      return_metrics: bool = False,
      return_summaries: bool = False,
      checkify_error_categories: frozenset[trainer_lib.CheckifyErrorCategory] = frozenset({}),
  )
Run one training step: forward pass, losses, gradients, parameter update, and metrics.
- Parameters:
state – The training state
batch – The batch to use for the training step
return_losses – Whether to return the losses
return_metrics – Whether to return the metrics
return_summaries – Whether to return the summaries
checkify_error_categories – Categories of errors to check with JAX's checkify transformation (jax.experimental.checkify). If empty (the default), no checkify is performed.
- Returns:
A tuple (state, auxiliaries):
state – The updated training state.
auxiliaries – Auxiliaries containing the losses, metrics and summaries states.
- Return type:
tuple of (TrainState, auxiliaries)
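A framework-free sketch of what one call does, following the docstring: forward pass, loss, gradient, parameter update, and an auxiliary output holding only the losses, metrics or summaries requested through the return_* flags. The one-weight linear model, the hand-written gradient, ToyState, toy_step and the dict used for auxiliaries are hypothetical stand-ins; in Kauldron the gradients come from JAX and the update from the configured optax optimizer.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    step: int
    w: float


def toy_step(
    state: ToyState,
    batch: list[tuple[float, float]],
    *,
    learning_rate: float = 0.1,
    return_losses: bool = False,
    return_metrics: bool = False,
    return_summaries: bool = False,
) -> tuple[ToyState, dict]:
    n = len(batch)
    # Forward pass: predictions of the model y_hat = w * x.
    preds = [state.w * x for x, _ in batch]
    # Loss: mean squared error.
    loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, batch)) / n
    # Gradient of the loss w.r.t. w: mean of 2 * (w*x - y) * x.
    grad = sum(2 * (p - y) * x for p, (x, y) in zip(preds, batch)) / n
    # Update: plain SGD, returning a new (immutable) state.
    new_state = replace(state, step=state.step + 1, w=state.w - learning_rate * grad)
    # Auxiliaries: only what the caller asked for.
    aux = {}
    if return_losses:
        aux["losses"] = {"mse": loss}
    if return_metrics:
        mae = sum(abs(p - y) for p, (_, y) in zip(preds, batch)) / n
        aux["metrics"] = {"mean_abs_error": mae}
    if return_summaries:
        aux["summaries"] = {"grad": grad}
    return new_state, aux
```

With state ToyState(0, 0.0) and batch [(1.0, 2.0), (2.0, 4.0)], the loss is 10.0, the gradient is -10.0 and the updated weight is 1.0.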