kd.metrics.CollectFirstState

class kauldron.metrics.CollectFirstState(
*,
parent: kauldron.metrics.base_state._MetricT = _EMPTY_TYPE.EMPTY,
keep_first: int,
)

Bases: kauldron.metrics.base_state.State[kauldron.metrics.base_state._MetricT]

Collects the first keep_first outputs, possibly across multiple steps, without reducing them.

Example:

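import flax
import jax.numpy as jnp
from kauldron import kd
from kauldron.typing import Float


@flax.struct.dataclass
class FirstNImages(kd.metrics.CollectFirstState):
  images: Float['N h w 3']


# Two steps, each contributing a batch of 4 images.
state0 = FirstNImages(images=jnp.zeros((4, 16, 16, 3)), keep_first=5)
state1 = FirstNImages(images=jnp.ones((4, 16, 16, 3)), keep_first=5)
final_state = state0.merge(state1)
# Only the first 5 of the 8 collected images are kept.
assert final_state.compute().images.shape == (5, 16, 16, 3)

keep_first: int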
classmethod empty() → kauldron.metrics.base_state._SelfT

Returns an empty instance (i.e. .merge(State.empty()) is a no-op).

merge(
other: kauldron.metrics.base_state._SelfT,
) → kauldron.metrics.base_state._SelfT

Returns a new state that is the accumulation of self and other.

Parameters:

other – A State whose intermediate values should be accumulated onto the values of self.

Returns:

A new State that accumulates the value from both self and other.
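The keep-first accumulation can be illustrated with a small pure-Python sketch. ToyCollectFirst is a hypothetical stand-in, not Kauldron code: it uses tuples instead of arrays, and it assumes truncation to keep_first happens inside merge, which the documentation above does not specify.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyCollectFirst:
  """Hypothetical stand-in for a keep-first state (not the Kauldron class)."""

  items: tuple  # Collected outputs, in arrival order.
  keep_first: int  # Maximum number of outputs to retain.

  def merge(self, other: "ToyCollectFirst") -> "ToyCollectFirst":
    # Assumption: concatenate in order, then keep the first `keep_first`.
    combined = (self.items + other.items)[: self.keep_first]
    # Return a new object; neither `self` nor `other` is modified.
    return dataclasses.replace(self, items=combined)

  def compute(self) -> tuple:
    # No reduction: the collected values are returned as they are.
    return self.items


state0 = ToyCollectFirst(items=(0, 0, 0, 0), keep_first=5)
state1 = ToyCollectFirst(items=(1, 1, 1, 1), keep_first=5)
final_state = state0.merge(state1)
assert final_state.compute() == (0, 0, 0, 0, 1)
```

All four values from the first state survive and only one from the second does, which matches the (5, 16, 16, 3) shape in the class example.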

compute() → kauldron.metrics.base_state._SelfT

Returns the concatenated values.

replace(**updates)

Returns a new object replacing the specified fields with new values.
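A rough analogy for replace, using the standard library's dataclasses.replace on a frozen dataclass. ToyState is a hypothetical class for illustration only; on a real state you would call state.replace(**updates) directly, and this sketch assumes that method behaves like dataclasses.replace.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyState:
  """Hypothetical immutable state, used only to illustrate `replace`."""

  items: tuple
  keep_first: int


original = ToyState(items=(1, 2, 3), keep_first=5)
# Build a copy with one field changed; the original is left untouched.
updated = dataclasses.replace(original, keep_first=2)

assert updated.keep_first == 2
assert original.keep_first == 5
assert updated.items == original.items
```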