kd.metrics.CollectFirstState
- class kauldron.metrics.CollectFirstState(
- *,
- parent: kauldron.metrics.base_state._MetricT = _EMPTY_TYPE.EMPTY,
- keep_first: int,
- )
Bases: kauldron.metrics.base_state.State[kauldron.metrics.base_state._MetricT]
Get the first outputs (possibly) across multiple steps (no reducing).
Example:
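import flax.struct
import jax.numpy as jnp
from kauldron import kd
from kauldron.typing import Float

@flax.struct.dataclass
class FirstNImages(kd.metrics.CollectFirstState):
  images: Float['N h w 3']

# Two batches of 4 images each; only the first 5 images overall are kept.
state0 = FirstNImages(images=jnp.zeros((4, 16, 16, 3)), keep_first=5)
state1 = FirstNImages(images=jnp.ones((4, 16, 16, 3)), keep_first=5)
final_state = state0.merge(state1)
assert final_state.compute().images.shape == (5, 16, 16, 3)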
- keep_first: int
- classmethod empty() → kauldron.metrics.base_state._SelfT
Returns an empty instance (i.e. .merge(State.empty()) is a no-op).
- merge(
- other: kauldron.metrics.base_state._SelfT,
- )
Returns a new state that is the accumulation of self and other.
- compute() → kauldron.metrics.base_state._SelfT
Returns the concatenated values.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
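The collect-first behaviour described above (merge keeps values in arrival order, compute returns them without reducing, and at most keep_first entries survive) can be illustrated with a small standard-library sketch. This is not Kauldron's implementation: CollectFirstSketch, its values field, and the keep_first argument to empty() are hypothetical names chosen for illustration, and plain tuples stand in for arrays.

```python
from __future__ import annotations

import dataclasses


@dataclasses.dataclass(frozen=True)
class CollectFirstSketch:
    """Hypothetical stand-in for a collect-first state (not Kauldron's code)."""

    values: tuple = ()   # Collected per-example outputs, in arrival order.
    keep_first: int = 0  # Maximum number of examples to keep.

    @classmethod
    def empty(cls, keep_first: int) -> CollectFirstSketch:
        # Assumption: this sketch takes keep_first explicitly, unlike the
        # documented empty(), which takes no arguments.
        return cls(values=(), keep_first=keep_first)

    def merge(self, other: CollectFirstSketch) -> CollectFirstSketch:
        # Concatenate in order, then truncate, so later batches only fill
        # whatever slots are still free.
        combined = (self.values + other.values)[: self.keep_first]
        # dataclasses.replace returns a new object with the given fields
        # updated, which mirrors the replace(**updates) pattern.
        return dataclasses.replace(self, values=combined)

    def compute(self) -> tuple:
        # No reduction: return the collected values as they are.
        return self.values[: self.keep_first]


state0 = CollectFirstSketch(values=("a0", "a1", "a2", "a3"), keep_first=5)
state1 = CollectFirstSketch(values=("b0", "b1", "b2", "b3"), keep_first=5)
final_state = state0.merge(state1)

# Merging with an empty state leaves the result unchanged.
merged_with_empty = final_state.merge(CollectFirstSketch.empty(keep_first=5))
```

In this sketch final_state.compute() holds the four entries of state0 followed by the first entry of state1, and state0 itself is left unmodified because every operation returns a new object.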