kd.metrics.CollectingState


class kauldron.metrics.CollectingState(*, parent: kauldron.metrics.base_state._MetricT = _EMPTY_TYPE.EMPTY)

Bases: kauldron.metrics.base_state.State[kauldron.metrics.base_state._MetricT]

Accumulate outputs across multiple steps (without reducing).

Example:

```python
import flax.struct
import sklearn.metrics
from kauldron import kd
from kauldron.typing import Float

@flax.struct.dataclass
class AveragePrecision(kd.metrics.CollectingState):
  labels: Float['n_samples']
  logits: Float['n_samples n_classes']

  def compute(self):
    values = super().compute()  # Concatenate all accumulated values
    return sklearn.metrics.average_precision_score(  # Reduce
        values.labels,
        values.logits,
    )

# labels0/logits0 and labels1/logits1 are the arrays produced by two steps.
state0 = AveragePrecision(labels=labels0, logits=logits0)
state1 = AveragePrecision(labels=labels1, logits=logits1)

final_state = state0.merge(state1)  # Accumulate the values

out = final_state.compute()  # Concatenate and reduce
```
  • Internally, the states are normalized and stored as tuples:
      * state0.labels = (labels0,)
      * final_state.labels = (labels0, labels1)

  • state.merge(other_state) only accumulates the values (appends them to a tuple).

  • Reduction is only applied in .compute().

  • To support masking, a subclass can accumulate the mask values as an additional field and use them in the final computation.

  • Because merge() keeps all values, these metrics use much more memory and can be slow to compute.
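The behavior listed above can be mimicked with a small, framework-free sketch. `MaskedAccuracy` is a hypothetical stand-in for a `CollectingState` subclass, not the kauldron implementation: it uses a frozen standard-library dataclass instead of `flax.struct.dataclass` and plain lists instead of arrays. The `from_batch` helper is also hypothetical; it makes explicit the tuple normalization that kauldron performs internally. Each field holds one entry per step, `merge()` only concatenates tuples, and the mask is collected like any other field and applied once in `compute()`.

```python
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class MaskedAccuracy:
    """Hypothetical stand-in for a CollectingState subclass (not the kauldron API)."""

    # Each field holds a tuple with one list per accumulated step.
    labels: tuple
    preds: tuple
    mask: tuple

    @classmethod
    def from_batch(cls, labels, preds, mask):
        # Normalize a single step into 1-element tuples, e.g. labels -> (labels,).
        return cls(labels=(labels,), preds=(preds,), mask=(mask,))

    def merge(self, other):
        # Accumulate only: append the other state's per-step lists, no reduction.
        return type(self)(
            **{f.name: getattr(self, f.name) + getattr(other, f.name) for f in fields(self)}
        )

    def compute(self):
        # Concatenate the per-step lists of each field...
        labels, preds, mask = (
            [x for step in getattr(self, name) for x in step]
            for name in ("labels", "preds", "mask")
        )
        # ...then reduce once, ignoring masked-out samples.
        hits = [y == p for y, p, m in zip(labels, preds, mask) if m]
        return sum(hits) / len(hits)


state0 = MaskedAccuracy.from_batch(labels=[1, 0, 1], preds=[1, 1, 1], mask=[True, True, False])
state1 = MaskedAccuracy.from_batch(labels=[0, 1], preds=[0, 1], mask=[True, True])

final_state = state0.merge(state1)
print(final_state.labels)     # ([1, 0, 1], [0, 1])
print(final_state.compute())  # 0.75
```

Every sample of every step stays in memory until `compute()` runs, which is the memory cost mentioned in the last bullet.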

classmethod empty() → kauldron.metrics.base_state._SelfT

Returns an empty instance (i.e. .merge(State.empty()) is a no-op).
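Because the empty state is the identity element of `merge()`, it is a convenient starting value when folding per-step states together, for example with `functools.reduce`. The sketch below assumes a hypothetical single-field `Collected` dataclass (not part of kauldron) whose field holds a tuple of per-step values, so its empty state simply holds an empty tuple.

```python
import functools
from dataclasses import dataclass


@dataclass(frozen=True)
class Collected:
    """Hypothetical single-field collecting state (not the kauldron API)."""

    values: tuple = ()

    @classmethod
    def empty(cls):
        # No accumulated steps yet.
        return cls(values=())

    def merge(self, other):
        # Tuple concatenation: merging with an empty tuple changes nothing.
        return Collected(values=self.values + other.values)


steps = [Collected(values=([1, 2],)), Collected(values=([3],))]

# Starting from the empty state means the fold also works when `steps` is empty.
total = functools.reduce(Collected.merge, steps, Collected.empty())
print(total.values)  # ([1, 2], [3])

# merge(empty()) is a no-op.
assert total.merge(Collected.empty()) == total
```

Starting from the empty state also avoids special-casing the first step in an evaluation loop.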

merge(other: kauldron.metrics.base_state._SelfT) → kauldron.metrics.base_state._SelfT

Returns a new state that is the accumulation of self and other.

Parameters:

other – A State whose intermediate values should be accumulated onto the values of self.

Returns:

A new State that accumulates the values from both self and other.
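For a collecting state, "accumulating onto self" amounts to tuple concatenation: the result lists the steps of `self` first, followed by those of `other`, consistent with `final_state.labels = (labels0, labels1)` above, and a new object is returned. The `Batches` class below is a hypothetical minimal stand-in used only to illustrate this ordering and immutability; it is not the kauldron implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Batches:
    """Hypothetical minimal collecting state (not the kauldron API)."""

    labels: tuple  # one list per accumulated step

    def merge(self, other):
        # Returns a new object; `self` and `other` are left untouched.
        return Batches(labels=self.labels + other.labels)


state0 = Batches(labels=(["cat", "dog"],))
state1 = Batches(labels=(["bird"],))

merged = state0.merge(state1)
print(merged.labels)                # (['cat', 'dog'], ['bird'])
print(state1.merge(state0).labels)  # (['bird'], ['cat', 'dog'])
print(state0.labels)                # (['cat', 'dog'],)
```

Since the step order is preserved, values collected in separate fields (for example labels and logits) stay aligned as long as every field is merged in the same order.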

compute() → kauldron.metrics.base_state._SelfT

Returns the concatenated values.
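The base `compute()` only performs the concatenation step and returns an instance of the same class, which is why the `AveragePrecision` example calls `super().compute()` before reducing. The sketch below reproduces that contract on a hypothetical `Collected` dataclass (not the kauldron API), assuming that concatenation means joining the per-step batches along the sample dimension; it uses plain lists, whereas the fields of an actual `CollectingState` are typically arrays.

```python
from dataclasses import dataclass, fields, replace
from itertools import chain


@dataclass(frozen=True)
class Collected:
    """Hypothetical collecting state (not the kauldron API)."""

    labels: tuple  # one list per accumulated step
    scores: tuple

    def compute(self):
        # Join the per-step lists of each field into a single flat list and
        # return an instance of the same class (no reduction happens here).
        return replace(
            self,
            **{f.name: list(chain.from_iterable(getattr(self, f.name))) for f in fields(self)},
        )


state = Collected(labels=([1, 0], [1]), scores=([0.9, 0.2], [0.7]))
values = state.compute()
print(values.labels)  # [1, 0, 1]
print(values.scores)  # [0.9, 0.2, 0.7]
```

A subclass can then reduce `values.labels` and `values.scores` however it needs, as `AveragePrecision` does with `sklearn`.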

replace(**updates)

Returns a new object replacing the specified fields with new values.
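`replace()` follows the usual immutable-dataclass pattern: the original state is left unchanged and a copy with the given fields swapped in is returned. The standard-library analogue is `dataclasses.replace`; the sketch below applies it to a hypothetical frozen `Collected` dataclass rather than to a real kauldron state.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Collected:
    """Hypothetical frozen state (not the kauldron API)."""

    labels: tuple
    mask: tuple


state = Collected(labels=([1, 0],), mask=([True, False],))

# Build a copy with an all-True mask; `state` itself is not modified.
unmasked = dataclasses.replace(state, mask=([True, True],))

print(unmasked.mask)  # ([True, True],)
print(state.mask)     # ([True, False],)

# Frozen dataclasses reject in-place assignment, so replace() is the way to update.
try:
    state.mask = ()
except dataclasses.FrozenInstanceError:
    print("state is immutable")
```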