kd.metrics.AverageState
- class kauldron.metrics.AverageState(
    total: jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, ''],
    count: jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, ''],
    *,
    parent: kauldron.metrics.base_state._MetricT = _EMPTY_TYPE.EMPTY,
  )
Bases: kauldron.metrics.base_state.State[kauldron.metrics.base_state._MetricT]
Computes the average of a scalar or a batch of tensors.
Supports the following types of masks:
- A one-dimensional mask with the same leading dimension as the values, or
- A multi-dimensional mask with exactly the same dimensions as the values. This allows per-example masks for examples in a batch, as well as per-target masks for individual targets within each example.
The result is always a scalar.
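Read together, the two mask forms reduce to a weighted sum. Assuming that `total` is the masked sum of the values, `count` is the sum of the broadcast mask, and the result is their ratio (which the field names suggest but this page does not state explicitly), the computation is:

$$
\text{total} = \sum_{j} m_j \, x_j, \qquad
\text{count} = \sum_{j} m_j, \qquad
\text{result} = \frac{\text{total}}{\text{count}}
$$

Here $x$ is `values`, $m$ is the mask broadcast to the shape of $x$ (a one-dimensional mask is repeated along the trailing axes), and $j$ runs over every element. For example, with $x = [[1, 2], [3, 4]]$ and the per-example mask $[1, 0]$, total $= 3$, count $= 2$, and the result is $1.5$.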
- total: jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']
- count: jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']
- classmethod from_values(
    values: jaxtyping.Float[Array, 'b *any'] | jaxtyping.Float[ndarray, 'b *any'],
    *,
    mask: jaxtyping.Bool[Array, 'b *#any'] | jaxtyping.Bool[ndarray, 'b *#any'] | jaxtyping.Float[Array, 'b *#any'] | jaxtyping.Float[ndarray, 'b *#any'] | None = None,
  )
Factory to create the state from an array.
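The NumPy sketch below shows one way `from_values` could reduce values and a mask to the `total` and `count` fields. `from_values_sketch` is a hypothetical helper, not kauldron's implementation. Two behaviors are assumptions: a missing mask counts every element, and a one-dimensional mask counts each element of a kept example (rather than each example once).

```python
import numpy as np


def from_values_sketch(values, mask=None):
    """Hypothetical NumPy stand-in for AverageState.from_values.

    Returns a (total, count) pair; this is not the kauldron implementation.
    """
    values = np.asarray(values, dtype=np.float64)
    if mask is None:
        # Assumption: no mask means every element is counted.
        mask = np.ones_like(values)
    mask = np.asarray(mask, dtype=np.float64)
    if mask.shape[0] != values.shape[0]:
        raise ValueError("mask and values must share the leading (batch) dimension")
    # A 1-D mask of shape (b,) gets trailing singleton axes so that it
    # broadcasts over every element belonging to the same example.
    mask = mask.reshape(mask.shape + (1,) * (values.ndim - mask.ndim))
    mask = np.broadcast_to(mask, values.shape)
    total = float(np.sum(values * mask))
    count = float(np.sum(mask))
    return total, count


values = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # shape (b=3, 2)

# Per-example mask: drop the second example entirely.
total, count = from_values_sketch(values, mask=np.array([True, False, True]))
print(total, count, total / count)  # 14.0 4.0 3.5

# Per-target mask with the same shape as `values`.
per_target = np.array([[True, False], [True, True], [False, False]])
total, count = from_values_sketch(values, mask=per_target)
print(total, count, total / count)  # 8.0 3.0 2.6666666666666665
```

Keeping `total` and `count` separate (instead of storing the average directly) is what makes the state mergeable across batches.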
- classmethod empty() → kauldron.metrics.base_state.AverageState
Returns an empty instance (i.e. .merge(State.empty()) is a no-op).
- merge(other: kauldron.metrics.base_state.AverageState)
Returns a new state that is the accumulation of self and other.
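Because the state only carries a running total and count, `empty()` and `merge()` can be sketched with a frozen dataclass. `AverageSketch` is a hypothetical stand-in assumed to mirror the documented behavior (merge returns a new state; merging with the empty state is a no-op). It is not the kauldron class and omits the `parent` field.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AverageSketch:
    """Hypothetical stand-in for AverageState; not the kauldron class."""

    total: float
    count: float

    @classmethod
    def empty(cls) -> "AverageSketch":
        # Zero total and zero count: merging with this changes nothing.
        return cls(total=0.0, count=0.0)

    def merge(self, other: "AverageSketch") -> "AverageSketch":
        # Accumulation is plain addition of the two fields, so merging is
        # associative and commutative, and a new object is returned.
        return AverageSketch(self.total + other.total, self.count + other.count)


batch_a = AverageSketch(total=14.0, count=4.0)
batch_b = AverageSketch(total=8.0, count=3.0)

assert batch_a.merge(AverageSketch.empty()) == batch_a  # no-op, as documented
merged = batch_a.merge(batch_b)
print(merged)  # AverageSketch(total=22.0, count=7.0)
```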
- compute() → jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']
Computes final metrics from intermediate values.
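Accumulating `total` and `count` across batches before computing gives the average over all kept elements. When batches differ in size, this differs from averaging the per-batch means. The plain-Python sketch below assumes `compute()` returns `total / count`. `state_from`, `merge`, and `compute` are hypothetical helpers operating on `(total, count)` tuples.

```python
from functools import reduce

# Two evaluation batches of unequal size.
batches = [
    [1.0, 2.0, 3.0],  # 3 examples
    [10.0],           # 1 example
]


def state_from(values):
    # Hypothetical: unmasked (total, count) for one batch.
    return (sum(values), float(len(values)))


def merge(a, b):
    # Hypothetical: accumulate two (total, count) states.
    return (a[0] + b[0], a[1] + b[1])


def compute(state):
    # Assumption: the final value is total / count (undefined when count == 0).
    total, count = state
    return total / count


final = reduce(merge, (state_from(v) for v in batches), (0.0, 0.0))
print(compute(final))  # 4.0, the average of [1, 2, 3, 10]

mean_of_means = sum(sum(v) / len(v) for v in batches) / len(batches)
print(mean_of_means)  # 6.0, which over-weights the one-example batch
```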
- replace(**updates)
Returns a new object replacing the specified fields with new values.