kd.metrics.Std

class kauldron.metrics.Std(
    *,
    values: kauldron.kontext.Key = '__KEY_REQUIRED__',
    mask: kauldron.kontext.Key | None = None,
)

Bases: kauldron.metrics.base.Metric

Compute the standard deviation of float values.

values: kauldron.kontext.Key = '__KEY_REQUIRED__'
mask: kauldron.kontext.Key | None = None
class State(
    total: jnp.ndarray,
    sum_of_squares: jnp.ndarray,
    count: jnp.ndarray,
    *,
    parent: _MetricT = <_EMPTY_TYPE.EMPTY: 1>,
)

Bases: kauldron.metrics.stats.StdState
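The State fields suggest a streaming computation: only the sum, the sum of squares, and the count are kept, and the standard deviation is recovered from them at the end. The sketch below assumes that reading of the field names and a population standard deviation (ddof=0). StdStateSketch is a hypothetical NumPy stand-in for illustration, not kauldron's implementation.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class StdStateSketch:
    """Hypothetical stand-in for Std.State that stores running sums only."""

    total: np.ndarray           # assumed to hold sum(x)
    sum_of_squares: np.ndarray  # assumed to hold sum(x**2)
    count: np.ndarray           # assumed to hold the number of included values

    def compute(self) -> np.ndarray:
        mean = self.total / self.count
        # Population variance (ddof=0): E[x^2] - E[x]^2.
        variance = self.sum_of_squares / self.count - mean**2
        # Clip small negative values caused by floating-point cancellation.
        return np.sqrt(np.maximum(variance, 0.0))


x = np.array([1.0, 2.0, 3.0, 4.0])
state = StdStateSketch(
    total=x.sum(), sum_of_squares=(x**2).sum(), count=np.array(x.size)
)
print(state.compute())  # ≈ 1.118, the same as np.std(x)
```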

merge(
    other: kauldron.metrics.stats.StdState,
) -> kauldron.metrics.stats.StdState

Returns a new state that is the accumulation of self and other.

Parameters:

other – A State whose intermediate values should be accumulated onto the values of self.

Returns:

A new State that accumulates the values from both self and other.
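Because a sum, a sum of squares, and a count are all additive, merging two states can plausibly be done field by field. The minimal sketch below assumes exactly that; StdStateSketch and its from_values helper are hypothetical and only illustrate the idea. It checks that merging two per-batch states gives the same standard deviation as computing over all values at once.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class StdStateSketch:
    """Hypothetical stand-in for Std.State (not kauldron's class)."""

    total: np.ndarray
    sum_of_squares: np.ndarray
    count: np.ndarray

    @classmethod
    def from_values(cls, x):
        # Hypothetical helper: build the sufficient statistics for one batch.
        x = np.asarray(x, dtype=np.float64)
        return cls(total=x.sum(), sum_of_squares=(x**2).sum(), count=np.array(x.size))

    def merge(self, other):
        # Assumed merge rule: every field is a plain sum, so add field-wise.
        return type(self)(
            total=self.total + other.total,
            sum_of_squares=self.sum_of_squares + other.sum_of_squares,
            count=self.count + other.count,
        )

    def compute(self):
        mean = self.total / self.count
        return np.sqrt(np.maximum(self.sum_of_squares / self.count - mean**2, 0.0))


a = StdStateSketch.from_values([1.0, 2.0])
b = StdStateSketch.from_values([3.0, 4.0, 5.0])
merged = a.merge(b)  # a and b are left unchanged
print(merged.compute())  # ≈ 1.414, the same as np.std([1, 2, 3, 4, 5])
```

Storing raw sums rather than a running mean keeps this merge associative and commutative (up to floating-point rounding), so states from different batches can be combined in any order.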

replace(**updates)

Returns a new object replacing the specified fields with new values.
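replace leaves the original object unchanged and returns a modified copy. As an analogy only, the same behavior can be shown with the standard library's dataclasses.replace on a frozen dataclass; the StdStateSketch class here is hypothetical and is not how kauldron defines Std.State.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class StdStateSketch:
    """Hypothetical frozen state used only to illustrate replace()."""

    total: np.ndarray
    sum_of_squares: np.ndarray
    count: np.ndarray

    def replace(self, **updates):
        # Return a copy with only the named fields changed; self is untouched.
        return dataclasses.replace(self, **updates)


s = StdStateSketch(total=np.array(6.0), sum_of_squares=np.array(14.0), count=np.array(3))
t = s.replace(count=np.array(4))
print(s.count, t.count)  # 3 4
```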

get_state(
    values: jaxtyping.Float[Array, '*b n'] | jaxtyping.Float[ndarray, '*b n'],
    mask: jaxtyping.Bool[Array, '*b 1'] | jaxtyping.Bool[ndarray, '*b 1'] | jaxtyping.Float[Array, '*b 1'] | jaxtyping.Float[ndarray, '*b 1'] | None = None,
) -> kauldron.metrics.stats.Std.State
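The shape annotations read as leading batch dimensions *b plus a trailing axis n for values, and a bool or float mask whose size-1 trailing axis broadcasts over n. The sketch below assumes that reading, counts each unmasked element once, and reduces every axis to a scalar. get_state_sketch is a hypothetical helper; kauldron's real get_state may reduce or count differently.

```python
import numpy as np


def get_state_sketch(values, mask=None):
    """Hypothetical reduction of values (*b, n) with an optional mask (*b, 1)."""
    values = np.asarray(values, dtype=np.float64)
    if mask is None:
        # Assumed default: every element is included.
        mask = np.ones(values.shape[:-1] + (1,), dtype=np.float64)
    # Bool or float mask; broadcast (*b, 1) across the trailing axis n.
    weights = np.broadcast_to(np.asarray(mask, dtype=np.float64), values.shape)
    # Zero out excluded entries first so NaN/inf there cannot leak into the sums.
    safe = np.where(weights > 0, values, 0.0)
    total = np.sum(safe * weights)
    sum_of_squares = np.sum(safe**2 * weights)
    count = np.sum(weights)
    return total, sum_of_squares, count


values = np.array([[1.0, 2.0], [3.0, 4.0], [100.0, 200.0]])  # shape (3, 2)
mask = np.array([[True], [True], [False]])                    # exclude the last row
total, sq, count = get_state_sketch(values, mask)
mean = total / count
std = np.sqrt(sq / count - mean**2)
print(count, std)  # 4.0 and ≈ 1.118, the same as np.std([1, 2, 3, 4])
```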
empty() -> kauldron.metrics.base.Metric.State