kd.train.AuxiliariesState
- class kauldron.train.AuxiliariesState(
      loss_states: typing.Mapping[str, kauldron.metrics.base_state.State] = <factory>,
      metric_states: typing.Mapping[str, kauldron.metrics.base_state.State] = <factory>,
      summary_states: typing.Mapping[str, kauldron.metrics.base_state.State] = <factory>,
      error: jax._src.checkify.Error = Error(_pred={}, _code={}, _metadata={}, _payload={}),
  )
Bases: object

Auxiliaries: the intermediate loss, metric, and summary states (plus a checkify error) to be accumulated.
- loss_states: Mapping[str, kauldron.metrics.base_state.State]
- metric_states: Mapping[str, kauldron.metrics.base_state.State]
- summary_states: Mapping[str, kauldron.metrics.base_state.State]
- error: jax._src.checkify.Error = Error(_pred={}, _code={}, _metadata={}, _payload={})
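The `<factory>` defaults in the signature above are how Python's `dataclasses` module renders fields declared with `default_factory`: each instance presumably gets its own fresh, empty mapping instead of sharing one. The minimal sketch below assumes `AuxiliariesState` behaves like a standard frozen dataclass. `ToyState` and `ToyAuxiliaries` are hypothetical stand-ins, and the `error` field is omitted because it requires JAX.

```python
import dataclasses
import inspect
from typing import Mapping


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for kauldron.metrics.base_state.State."""
    total: float = 0.0


@dataclasses.dataclass(frozen=True)
class ToyAuxiliaries:
    """Hypothetical mirror of the AuxiliariesState field layout (no `error`)."""
    # default_factory=dict gives every instance its own empty mapping;
    # the generated __init__ signature renders this default as `<factory>`.
    loss_states: Mapping[str, ToyState] = dataclasses.field(default_factory=dict)
    metric_states: Mapping[str, ToyState] = dataclasses.field(default_factory=dict)
    summary_states: Mapping[str, ToyState] = dataclasses.field(default_factory=dict)


a = ToyAuxiliaries()
b = ToyAuxiliaries()
print(inspect.signature(ToyAuxiliaries))  # each default is shown as <factory>
print(a.loss_states == {}, a.loss_states is b.loss_states)  # True False
```

Using a factory rather than a literal `{}` default avoids the classic shared-mutable-default pitfall. `dataclasses` in fact rejects a plain `dict` default outright.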
- replace(**updates)
Returns a new object replacing the specified fields with new values.
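The description of `replace(**updates)` matches the contract of the standard library's `dataclasses.replace`: the original object is left untouched, and a copy with the named fields swapped in is returned. The sketch below demonstrates that contract on a hypothetical frozen `ToyAuxiliaries`. It is an illustration under that assumption, not the kauldron implementation.

```python
import dataclasses
from typing import Mapping


@dataclasses.dataclass(frozen=True)
class ToyAuxiliaries:
    """Hypothetical stand-in with two of the AuxiliariesState fields."""
    loss_states: Mapping[str, float] = dataclasses.field(default_factory=dict)
    metric_states: Mapping[str, float] = dataclasses.field(default_factory=dict)

    def replace(self, **updates):
        # Build a new instance; fields not named in `updates` are copied over.
        return dataclasses.replace(self, **updates)


old = ToyAuxiliaries(loss_states={"l2": 0.5})
new = old.replace(metric_states={"accuracy": 0.9})
print(old.metric_states)  # {}
print(new.loss_states, new.metric_states)  # {'l2': 0.5} {'accuracy': 0.9}
```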
- merge(other: kauldron.train.auxiliaries.AuxiliariesState | None)
Accumulate auxiliaries, i.e. merge these states with those of other (which may be None).
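One possible shape of this accumulation is sketched below. The sketch rests on assumptions that this entry does not state: states are combined key by key using each state's own `merge`, keys present on only one side are kept as-is, and `other=None` leaves the object unchanged. `SumState`, `ToyAuxiliaries`, and `_merge_dicts` are hypothetical.

```python
from __future__ import annotations

import dataclasses
from typing import Mapping


@dataclasses.dataclass(frozen=True)
class SumState:
    """Hypothetical metric state: a running sum and count."""
    total: float
    count: int

    def merge(self, other: SumState) -> SumState:
        return SumState(self.total + other.total, self.count + other.count)


def _merge_dicts(a: Mapping[str, SumState], b: Mapping[str, SumState]) -> dict[str, SumState]:
    # Merge states that share a key; keep keys found on only one side.
    merged = dict(a)
    for key, state in b.items():
        merged[key] = merged[key].merge(state) if key in merged else state
    return merged


@dataclasses.dataclass(frozen=True)
class ToyAuxiliaries:
    loss_states: Mapping[str, SumState] = dataclasses.field(default_factory=dict)
    metric_states: Mapping[str, SumState] = dataclasses.field(default_factory=dict)

    def merge(self, other: ToyAuxiliaries | None) -> ToyAuxiliaries:
        if other is None:  # Nothing to accumulate (assumed behavior).
            return self
        return ToyAuxiliaries(
            loss_states=_merge_dicts(self.loss_states, other.loss_states),
            metric_states=_merge_dicts(self.metric_states, other.metric_states),
        )


step1 = ToyAuxiliaries(loss_states={"l2": SumState(1.0, 1)})
step2 = ToyAuxiliaries(
    loss_states={"l2": SumState(3.0, 1)},
    metric_states={"acc": SumState(1.0, 2)},
)
acc = step1.merge(step2).merge(None)
print(acc.loss_states["l2"])  # SumState(total=4.0, count=2)
```

Returning a new object instead of mutating in place keeps the container compatible with a frozen, functional style, which suits values passed through JAX transformations.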
- compute(*, flatten: bool = True)
Compute losses and metrics.
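The sketch below shows the general idea of computing final values from accumulated states. It assumes that each state exposes a `compute()` that reduces it to a number (here a mean) and that results are returned per key, grouped by kind. `MeanState` and `ToyAuxiliaries` are hypothetical, and the `flatten` flag is not modeled because this entry does not describe its effect on the output.

```python
import dataclasses
from typing import Mapping


@dataclasses.dataclass(frozen=True)
class MeanState:
    """Hypothetical accumulated state: running sum and count."""
    total: float
    count: int

    def compute(self) -> float:
        # Reduce the accumulated statistics to the final value.
        return self.total / self.count


@dataclasses.dataclass(frozen=True)
class ToyAuxiliaries:
    loss_states: Mapping[str, MeanState] = dataclasses.field(default_factory=dict)
    metric_states: Mapping[str, MeanState] = dataclasses.field(default_factory=dict)

    def compute(self) -> dict[str, dict[str, float]]:
        # Turn every loss and metric state into a plain value, grouped by kind.
        return {
            "losses": {k: s.compute() for k, s in self.loss_states.items()},
            "metrics": {k: s.compute() for k, s in self.metric_states.items()},
        }


aux = ToyAuxiliaries(
    loss_states={"l2": MeanState(total=4.0, count=2)},
    metric_states={"accuracy": MeanState(total=3.0, count=4)},
)
print(aux.compute())  # {'losses': {'l2': 2.0}, 'metrics': {'accuracy': 0.75}}
```

Deferring the reduction to `compute()` lets sums and counts be merged exactly across steps. Averaging per-step means directly would be wrong whenever the steps have different counts.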