kd.train.AuxiliariesState

class kauldron.train.AuxiliariesState(
    loss_states: typing.Mapping[str, kauldron.metrics.base_state.State] = <factory>,
    metric_states: typing.Mapping[str, kauldron.metrics.base_state.State] = <factory>,
    summary_states: typing.Mapping[str, kauldron.metrics.base_state.State] = <factory>,
    error: jax._src.checkify.Error = Error(_pred={}, _code={}, _metadata={}, _payload={}),
)

Bases: kauldron.checkpoints.checkpoint_items.StandardCheckpointItem

Auxiliaries (intermediate loss, metric, and summary states to be accumulated).

loss_states: Mapping[str, kauldron.metrics.base_state.State]
metric_states: Mapping[str, kauldron.metrics.base_state.State]
summary_states: Mapping[str, kauldron.metrics.base_state.State]
error: jax._src.checkify.Error = Error(_pred={}, _code={}, _metadata={}, _payload={})
replace(**updates)

Returns a new object replacing the specified fields with new values.
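The <factory> defaults in the signature suggest AuxiliariesState is a dataclass, so replace most likely behaves like the standard library's dataclasses.replace: it builds a new object and leaves the original untouched. The sketch below illustrates those copy-on-update semantics with ToyAuxiliaries, a hypothetical stand-in class that is not part of kauldron.

```python
import dataclasses
from typing import Mapping


@dataclasses.dataclass(frozen=True)
class ToyAuxiliaries:
    """Hypothetical stand-in reusing two of AuxiliariesState's field names."""

    loss_states: Mapping[str, float] = dataclasses.field(default_factory=dict)
    metric_states: Mapping[str, float] = dataclasses.field(default_factory=dict)


aux = ToyAuxiliaries(loss_states={"l2": 0.5})

# dataclasses.replace copies every field, overriding only the ones passed in.
updated = dataclasses.replace(aux, metric_states={"accuracy": 0.9})
# `aux` itself is unchanged: its metric_states is still the empty default.
```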

merge(
    other: kauldron.train.auxiliaries.AuxiliariesState | None,
) -> kauldron.train.auxiliaries.AuxiliariesState

Accumulates this auxiliary state with other (which may be None) and returns the merged state.
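A minimal sketch of the accumulation pattern that merge supports, using hypothetical MeanState and ToyAuxiliaries classes (neither is a kauldron API): each batch yields its own state, and states are combined key by key. Treating a None argument as a no-op is an assumption made for this sketch, not documented behavior.

```python
import dataclasses
from typing import Mapping, Optional


@dataclasses.dataclass(frozen=True)
class MeanState:
    """Hypothetical metric state holding a running sum and count."""

    total: float
    count: int

    def merge(self, other: "MeanState") -> "MeanState":
        # Sums and counts add up, so merging is associative.
        return MeanState(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        return self.total / self.count


@dataclasses.dataclass(frozen=True)
class ToyAuxiliaries:
    """Hypothetical stand-in for AuxiliariesState (metric_states only)."""

    metric_states: Mapping[str, MeanState]

    def merge(self, other: Optional["ToyAuxiliaries"]) -> "ToyAuxiliaries":
        # Assumption: merging with None returns the state unchanged.
        if other is None:
            return self
        merged = dict(self.metric_states)
        for name, state in other.metric_states.items():
            # Merge states sharing a key; keep keys present on only one side.
            merged[name] = merged[name].merge(state) if name in merged else state
        return ToyAuxiliaries(metric_states=merged)


# One state per batch, accumulated in order.
batches = [
    ToyAuxiliaries({"loss": MeanState(total=2.0, count=2)}),
    ToyAuxiliaries({"loss": MeanState(total=4.0, count=2)}),
]
acc = batches[0]
for aux in batches[1:]:
    acc = acc.merge(aux)
```

Accumulating raw sums and counts, rather than per-batch averages, keeps the final mean correct even when batches have different sizes.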

finalize() -> kauldron.train.auxiliaries.AuxiliariesState

Finalizes the auxiliary state.

compute(
    *,
    flatten: bool = True,
) -> kauldron.train.auxiliaries.AuxiliariesOutput

Compute losses and metrics.
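Accumulated states are usually reduced to final values only once, after all merging is done. The sketch below illustrates that last step with hypothetical MeanState and ToyOutput classes. ToyOutput's field names are illustrative and are not the actual fields of AuxiliariesOutput, and the flatten flag is not modeled.

```python
import dataclasses
from typing import Mapping


@dataclasses.dataclass(frozen=True)
class MeanState:
    """Hypothetical accumulated state holding a running sum and count."""

    total: float
    count: int

    def compute(self) -> float:
        return self.total / self.count


@dataclasses.dataclass(frozen=True)
class ToyOutput:
    """Hypothetical stand-in for AuxiliariesOutput: final values per group."""

    loss_values: Mapping[str, float]
    metric_values: Mapping[str, float]


def compute(
    loss_states: Mapping[str, MeanState],
    metric_states: Mapping[str, MeanState],
) -> ToyOutput:
    # Reduce each accumulated state to its final scalar value.
    return ToyOutput(
        loss_values={name: s.compute() for name, s in loss_states.items()},
        metric_values={name: s.compute() for name, s in metric_states.items()},
    )


out = compute(
    loss_states={"l2": MeanState(total=3.0, count=4)},
    metric_states={"accuracy": MeanState(total=9.0, count=10)},
)
```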