kd.metrics.AverageState

class kauldron.metrics.AverageState(
total: jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, ''],
count: jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, ''],
*,
parent: kauldron.metrics.base_state._MetricT = _EMPTY_TYPE.EMPTY,
)

Bases: kauldron.metrics.base_state.State[kauldron.metrics.base_state._MetricT]

Computes the average of a scalar or a batch of tensors.

Supports the following types of masks:

  • A one-dimensional mask whose length matches the leading (batch) dimension of the values, or

  • A multi-dimensional mask with exactly the same dimensions as the values. This allows per-example masks (one entry per example in a batch) as well as per-target masks (one entry per target of each example in a batch).

The result is always a scalar.
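
The two mask shapes described above can be illustrated with a minimal NumPy sketch. It does not call Kauldron; `masked_total_and_count` is a hypothetical helper, and the assumption that a masked average is the sum of the kept entries divided by the number of kept entries is an illustration, not something this page confirms. Only boolean masks are shown, since the handling of non-binary float masks is not described here.

```python
import numpy as np


def masked_total_and_count(values, mask=None):
    """Hypothetical helper: returns (total, count) for a masked average."""
    values = np.asarray(values, dtype=np.float32)
    if mask is None:
        mask = np.ones_like(values)
    mask = np.asarray(mask, dtype=np.float32)
    # Append trailing singleton axes so that a 1-D (batch,) mask broadcasts
    # against values of shape (batch, ...).
    mask = mask.reshape(mask.shape + (1,) * (values.ndim - mask.ndim))
    mask = np.broadcast_to(mask, values.shape)
    total = (values * mask).sum()  # sum of the entries that are kept
    count = mask.sum()             # number of entries that are kept
    return total, count


values = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Per-example mask: drop the last example entirely.
total, count = masked_total_and_count(values, mask=np.array([True, True, False]))
per_example_avg = total / count  # (1 + 2 + 3 + 4) / 4

# Per-target mask: drop individual entries.
full_mask = np.array([[True, False], [True, True], [False, False]])
total, count = masked_total_and_count(values, mask=full_mask)
per_target_avg = total / count  # (1 + 3 + 4) / 3
```

In both cases the result is a zero-dimensional value, which matches the statement that the result is always a scalar.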

total: jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']
count: jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']
classmethod from_values(
values: jaxtyping.Float[Array, 'b *any'] | jaxtyping.Float[ndarray, 'b *any'],
*,
mask: jaxtyping.Bool[Array, 'b *#any'] | jaxtyping.Bool[ndarray, 'b *#any'] | jaxtyping.Float[Array, 'b *#any'] | jaxtyping.Float[ndarray, 'b *#any'] | None = None,
) → kauldron.metrics.base_state.AverageState

Factory to create the state from an array.

classmethod empty() → kauldron.metrics.base_state.AverageState

Returns an empty instance, i.e. one for which state.merge(State.empty()) is a no-op.

merge(
other: kauldron.metrics.base_state.AverageState,
) → kauldron.metrics.base_state.AverageState

Returns a new state that is the accumulation of self and other.

Parameters:

other – A State whose intermediate values should be accumulated onto the values of self.

Returns:

A new State that accumulates the value from both self and other.
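
The accumulation described above can be sketched with a small stand-in class. `ToyAverageState` is hypothetical and is not the Kauldron implementation. The sketch assumes that merging adds the `total` and `count` fields element-wise, which is consistent with an empty state being a no-op under `merge` but is not spelled out on this page.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyAverageState:
    """Hypothetical stand-in for illustration; not the Kauldron class."""

    total: float
    count: float

    @classmethod
    def empty(cls):
        # Zero total and zero count contribute nothing when merged.
        return cls(total=0.0, count=0.0)

    def merge(self, other):
        # Assumption: accumulation is field-wise addition.
        return ToyAverageState(
            total=self.total + other.total,
            count=self.count + other.count,
        )


a = ToyAverageState(total=10.0, count=4.0)
b = ToyAverageState(total=8.0, count=3.0)

merged = a.merge(b)                           # total=18.0, count=7.0
unchanged = a.merge(ToyAverageState.empty())  # equal to a
```

Because the stand-in is frozen, `merge` returns a new object and leaves `a` and `b` untouched, mirroring the "returns a new state" wording above.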

compute() → jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']

Computes the final average from the accumulated total and count.
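
One plausible reason for storing `total` and `count` separately, rather than a running mean, is that batches of different sizes are then weighted correctly. The NumPy sketch below assumes the final value is `total / count`. That formula is an assumption made for illustration and is not stated on this page.

```python
import numpy as np

# Two batches of different sizes.
batches = [np.array([1.0, 2.0, 3.0, 4.0]), np.array([10.0])]

# Accumulate totals and counts across batches, then divide once at the end.
total = sum(float(b.sum()) for b in batches)
count = sum(b.size for b in batches)
accumulated_avg = total / count  # 20 / 5

# Averaging the per-batch means instead over-weights the small batch.
mean_of_means = float(np.mean([b.mean() for b in batches]))  # (2.5 + 10) / 2
```

Here `accumulated_avg` equals the mean over all five values, while `mean_of_means` does not.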

replace(**updates)

Returns a new object replacing the specified fields with new values.