kd.train.TrainStep


class kauldron.train.TrainStep(
    *,
    _fake_refs: type[_FakeRefsUnset] | dict[str, _FakeRootCfg] = <class 'kauldron.utils.config_util._FakeRefsUnset'>,
    model: nn.Module = _FakeRootCfg('cfg.model'),
    optimizer: optax.GradientTransformation = _FakeRootCfg('cfg.optimizer'),
    rng_streams: rngs_lib.RngStreams = _FakeRootCfg('cfg.rng_streams'),
    sharding: sharding_lib.ShardingStrategy = _FakeRootCfg('cfg.sharding'),
    init_transform: partial_loader.AbstractPartialLoader = _FakeRootCfg('cfg.init_transform'),
    aux: auxiliaries.Auxiliaries = <factory>,
)

Bases: kauldron.utils.config_util.UpdateFromRootCfg

Base training step.

Subclasses can override the _step method to implement custom training steps.
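This page does not document the arguments of _step, so the sketch below only illustrates the override pattern, under that limitation. It uses a hypothetical stand-in base class (BaseStepSketch, not part of Kauldron) whose public step() delegates to a _step() hook, and a hypothetical subclass that customizes training by replacing that hook alone. The real TrainStep._step may take different arguments and return different types.

```python
from typing import Any


class BaseStepSketch:
    """Hypothetical stand-in for TrainStep: the public step() delegates to a _step() hook."""

    def step(self, state: dict[str, Any], batch: list[float]) -> tuple[dict[str, Any], dict[str, float]]:
        # Logic shared by every subclass stays here; the core update is delegated.
        new_state, aux = self._step(state, batch)
        return {**new_state, "step": state["step"] + 1}, aux

    def _step(self, state: dict[str, Any], batch: list[float]) -> tuple[dict[str, Any], dict[str, float]]:
        raise NotImplementedError("Subclasses implement the actual update.")


class MeanTrackingStep(BaseStepSketch):
    """Hypothetical subclass that customizes training by overriding _step only."""

    def _step(self, state, batch):
        mean = sum(batch) / len(batch)
        return {**state, "last_mean": mean}, {"batch_mean": mean}


state, aux = MeanTrackingStep().step({"step": 0}, [1.0, 2.0, 3.0])
# state == {"step": 1, "last_mean": 2.0}, aux == {"batch_mean": 2.0}
```

Keeping the shared bookkeeping in the base class means a subclass cannot forget it; only the part that differs between training recipes is rewritten.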

model: nn.Module = _FakeRootCfg('cfg.model')
optimizer: optax.GradientTransformation = _FakeRootCfg('cfg.optimizer')
rng_streams: rngs_lib.RngStreams = _FakeRootCfg('cfg.rng_streams')
sharding: sharding_lib.ShardingStrategy = _FakeRootCfg('cfg.sharding')
init_transform: partial_loader.AbstractPartialLoader = _FakeRootCfg('cfg.init_transform')
aux: auxiliaries.Auxiliaries
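Defaults such as _FakeRootCfg('cfg.model') indicate that, unless set explicitly, these fields refer to the matching attributes of the root config (the base class is UpdateFromRootCfg). The resolution mechanism is not described on this page. The sketch below illustrates the idea with hypothetical names (RootRef, resolve_root_refs, StepConfigSketch). It is not Kauldron's implementation: each placeholder field is swapped for the value it names on a root config, and explicitly set fields are left alone.

```python
import dataclasses
from types import SimpleNamespace
from typing import Any


@dataclasses.dataclass(frozen=True)
class RootRef:
    """Hypothetical placeholder naming an attribute path on the root config, e.g. 'cfg.model'."""

    path: str


def resolve_root_refs(obj: Any, root_cfg: Any) -> Any:
    """Return a copy of dataclass obj with every RootRef field replaced by the value it names."""
    updates = {}
    for field in dataclasses.fields(obj):
        value = getattr(obj, field.name)
        if isinstance(value, RootRef):
            # Walk the dotted path, skipping the leading 'cfg' that names the root itself.
            target = root_cfg
            for attr in value.path.split(".")[1:]:
                target = getattr(target, attr)
            updates[field.name] = target
    return dataclasses.replace(obj, **updates)


@dataclasses.dataclass(frozen=True)
class StepConfigSketch:
    """Hypothetical stand-in for a TrainStep-like object with root-config defaults."""

    model: Any = RootRef("cfg.model")
    optimizer: Any = RootRef("cfg.optimizer")


root_cfg = SimpleNamespace(model="mlp", optimizer="adam")
resolved = resolve_root_refs(StepConfigSketch(model="cnn"), root_cfg)
# resolved.model stays "cnn" (set explicitly); resolved.optimizer becomes "adam".
```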

init(
    elem_spec: ElementSpec,
    *,
    model_method: str | None = None,
    skip_transforms: bool = False,
    skip_optimizer: bool = False,
) -> TrainState

Initialize the model and return the initial TrainState.

Parameters:
  • elem_spec – Structure of the input batch.

  • model_method – Name of the Flax model method to use (defaults to __call__).

  • skip_transforms – If False, apply init_transform to the state (e.g. to overwrite the weights with ones from another checkpoint); if True, skip it.

  • skip_optimizer – If True, do not initialize the optimizer.

Returns:

The initial training state.

Return type:

TrainState
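To make the two skip flags concrete, here is a pure-Python sketch of what they control according to the parameter descriptions above. StateSketch, init_sketch, the dict-of-floats parameters and the one-slot-per-parameter optimizer state are hypothetical simplifications. init_transform is modeled as a plain function rather than an AbstractPartialLoader. The real method works from elem_spec and the configured model, optimizer and init_transform, and returns a TrainState.

```python
import dataclasses
from typing import Callable, Optional


@dataclasses.dataclass(frozen=True)
class StateSketch:
    """Hypothetical stand-in for TrainState."""

    params: dict[str, float]
    opt_state: Optional[dict[str, float]]
    step: int = 0


def init_sketch(
    init_params: Callable[[], dict[str, float]],
    init_transform: Callable[[StateSketch], StateSketch],
    *,
    skip_transforms: bool = False,
    skip_optimizer: bool = False,
) -> StateSketch:
    params = init_params()
    # Optimizer state (here: one zero-initialized slot per parameter) is only created when requested.
    opt_state = None if skip_optimizer else {name: 0.0 for name in params}
    state = StateSketch(params=params, opt_state=opt_state)
    if not skip_transforms:
        # E.g. overwrite freshly initialized weights with ones restored from a checkpoint.
        state = init_transform(state)
    return state


def fresh_params() -> dict[str, float]:
    return {"w": 0.0, "b": 0.0}


def load_pretrained(state: StateSketch) -> StateSketch:
    # Hypothetical transform standing in for loading weights from another checkpoint.
    return dataclasses.replace(state, params={"w": 1.5, "b": -0.5})


state = init_sketch(fresh_params, load_pretrained)
# state.params == {"w": 1.5, "b": -0.5}; state.opt_state == {"w": 0.0, "b": 0.0}
```

Passing skip_optimizer=True is useful when only the model weights are needed, e.g. for evaluation, since no optimizer state is allocated.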

step(
    state: TrainState,
    batch: PyTree[Any],
    *,
    return_losses: bool = False,
    return_metrics: bool = False,
    return_summaries: bool = False,
    checkify_error_categories: frozenset[trainer_lib.CheckifyErrorCategory] = frozenset({}),
) -> tuple[TrainState, auxiliaries.AuxiliariesState]

Training step: forward, losses, gradients, update, and metrics.

Parameters:
  • state – The training state.

  • batch – The batch to use for the training step.

  • return_losses – Whether to return the losses.

  • return_metrics – Whether to return the metrics.

  • return_summaries – Whether to return the summaries.

  • checkify_error_categories – Categories of errors to check with checkify. If empty, checkify is not applied.

Returns:

A (state, auxiliaries) tuple: state is the updated training state, and auxiliaries is the auxiliaries state containing the losses, metrics and summaries states.

Return type:

tuple[TrainState, auxiliaries.AuxiliariesState]
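The five stages of a training step (forward, losses, gradients, update, metrics) can be illustrated without Kauldron or JAX. The pure-Python sketch below fits y = w * x + b with a mean-squared-error loss, analytic gradients and plain gradient descent. LinearState, AuxSketch, step_sketch, the lr argument and the metric and summary names are hypothetical. Mapping each return_* flag to an optional field is an assumption meant only to mirror the flags documented above. The real step uses the configured model, losses and optax optimizer, and supports checkify, which this sketch omits.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class LinearState:
    """Hypothetical stand-in for TrainState: parameters of y = w * x + b and a step counter."""

    w: float
    b: float
    step: int = 0


@dataclasses.dataclass(frozen=True)
class AuxSketch:
    """Hypothetical stand-in for AuxiliariesState; each field stays None unless requested."""

    losses: Optional[dict[str, float]] = None
    metrics: Optional[dict[str, float]] = None
    summaries: Optional[dict[str, float]] = None


def step_sketch(
    state: LinearState,
    batch: tuple[list[float], list[float]],
    *,
    lr: float = 0.1,
    return_losses: bool = False,
    return_metrics: bool = False,
    return_summaries: bool = False,
) -> tuple[LinearState, AuxSketch]:
    xs, ys = batch
    n = len(xs)
    # 1. Forward pass.
    preds = [state.w * x + state.b for x in xs]
    errors = [p - y for p, y in zip(preds, ys)]
    # 2. Loss: mean squared error.
    loss = sum(e * e for e in errors) / n
    # 3. Gradients of the MSE w.r.t. w and b (analytic here; JAX code would typically use jax.grad).
    grad_w = 2.0 * sum(e * x for e, x in zip(errors, xs)) / n
    grad_b = 2.0 * sum(errors) / n
    # 4. Update: plain gradient descent instead of a configured optax optimizer.
    new_state = LinearState(
        w=state.w - lr * grad_w, b=state.b - lr * grad_b, step=state.step + 1
    )
    # 5. Losses, metrics and summaries are only returned when the matching flag is set.
    aux = AuxSketch(
        losses={"mse": loss} if return_losses else None,
        metrics={"mean_abs_error": sum(abs(e) for e in errors) / n} if return_metrics else None,
        summaries={"grad_norm": (grad_w**2 + grad_b**2) ** 0.5} if return_summaries else None,
    )
    return new_state, aux


state = LinearState(w=0.0, b=0.0)
batch = ([1.0, 2.0], [2.0, 4.0])
new_state, aux = step_sketch(state, batch, return_losses=True)
# new_state.w == 1.0, aux.losses == {"mse": 10.0}, aux.metrics is None
```

As with the documented method, the caller unpacks a (state, auxiliaries) pair and feeds the new state into the next call.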