kd.metrics.CollectingState
class kauldron.metrics.CollectingState(*, parent: kauldron.metrics.base_state._MetricT = _EMPTY_TYPE.EMPTY)
Bases: kauldron.metrics.base_state.State[kauldron.metrics.base_state._MetricT]

Accumulate outputs across multiple steps (without reducing).
Example:

    import flax
    import sklearn.metrics
    from kauldron import kd
    from kauldron.typing import Float

    @flax.struct.dataclass
    class AveragePrecision(kd.metrics.CollectingState):
      labels: Float['n_samples']
      logits: Float['n_samples n_classes']

      def compute(self):
        values = super().compute()  # Concatenate all accumulated values
        return sklearn.metrics.average_precision_score(  # Reduce
            values.labels,
            values.logits,
        )

    state0 = AveragePrecision(labels=labels0, logits=logits0)
    state1 = AveragePrecision(labels=labels1, logits=logits1)
    final_state = state0.merge(state1)  # Accumulate the values
    out = final_state.compute()  # Concatenate and reduce

Internally, each field is normalized and stored as a tuple:
* state0.labels = (labels0,)
* final_state.labels = (labels0, labels1)
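The tuple normalization above can be illustrated with a minimal, self-contained sketch. `ToyCollectingState` is a hypothetical stand-in written with the standard `dataclasses` module, not kauldron's implementation, and Python lists stand in for arrays. A single per-step value is wrapped into a 1-tuple, and merging concatenates the tuples.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class ToyCollectingState:
  """Hypothetical stand-in for a CollectingState subclass with one field."""

  labels: Any

  def __post_init__(self):
    # Wrap a single per-step value (a list here) into a 1-tuple, mirroring
    # the "stored as a tuple" normalization. Values that are already tuples
    # are assumed to be accumulated chunks and are kept as-is.
    if not isinstance(self.labels, tuple):
      object.__setattr__(self, "labels", (self.labels,))

  def merge(self, other: "ToyCollectingState") -> "ToyCollectingState":
    # Accumulate by tuple concatenation; nothing is reduced here.
    return ToyCollectingState(labels=self.labels + other.labels)


labels0 = [0, 1, 1]
labels1 = [1, 0]
state0 = ToyCollectingState(labels=labels0)
final_state = state0.merge(ToyCollectingState(labels=labels1))
print(state0.labels)       # ([0, 1, 1],)
print(final_state.labels)  # ([0, 1, 1], [1, 0])
```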
state.merge(other_state) only accumulates the values (appends them to a tuple).
Reduction is only applied in .compute().
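Deferring the reduction matters for metrics that cannot be combined from per-step results. The sketch below uses a median, chosen here only for illustration, in a hypothetical `ToyMedian` class (stdlib only, not the kauldron API): merging appends per-step chunks, and `compute()` concatenates them and reduces once.

```python
import dataclasses
import statistics
from typing import Tuple


@dataclasses.dataclass(frozen=True)
class ToyMedian:
  """Hypothetical collecting metric: median over all accumulated values."""

  values: Tuple[Tuple[float, ...], ...]  # one inner tuple per merged step

  @classmethod
  def from_batch(cls, batch):
    return cls(values=(tuple(batch),))

  def merge(self, other: "ToyMedian") -> "ToyMedian":
    # Only accumulate: append the other state's per-step chunks.
    return ToyMedian(values=self.values + other.values)

  def compute(self) -> float:
    # Concatenate all chunks first, then reduce once.
    flat = [v for chunk in self.values for v in chunk]
    return statistics.median(flat)


batch0 = [1.0, 2.0, 100.0]
batch1 = [3.0, 4.0]
state = ToyMedian.from_batch(batch0).merge(ToyMedian.from_batch(batch1))
print(state.compute())  # 3.0, the median of all five values
# Reducing each step first and averaging gives a different number:
print(statistics.mean([statistics.median(batch0), statistics.median(batch1)]))  # 2.75
```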
To support masking, a subclass can accumulate the mask values as an additional field and apply them in the final computation.
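A minimal sketch of the masking pattern, assuming a hypothetical `ToyMaskedMean` class built on stdlib dataclasses (not kauldron code): the mask is collected as its own field, step by step, and is only applied inside `compute()`.

```python
import dataclasses
from typing import Tuple


@dataclasses.dataclass(frozen=True)
class ToyMaskedMean:
  """Hypothetical collecting metric that also accumulates a mask field."""

  values: Tuple[Tuple[float, ...], ...]
  mask: Tuple[Tuple[bool, ...], ...]  # accumulated alongside `values`

  @classmethod
  def from_batch(cls, values, mask):
    return cls(values=(tuple(values),), mask=(tuple(mask),))

  def merge(self, other: "ToyMaskedMean") -> "ToyMaskedMean":
    # Both fields are accumulated the same way, so they stay aligned.
    return ToyMaskedMean(
        values=self.values + other.values, mask=self.mask + other.mask
    )

  def compute(self) -> float:
    flat_values = [v for chunk in self.values for v in chunk]
    flat_mask = [m for chunk in self.mask for m in chunk]
    # The mask is applied only here, in the final computation.
    kept = [v for v, m in zip(flat_values, flat_mask) if m]
    return sum(kept) / len(kept)


s0 = ToyMaskedMean.from_batch([1.0, 2.0, 99.0], [True, True, False])
s1 = ToyMaskedMean.from_batch([3.0, -50.0], [True, False])
print(s0.merge(s1).compute())  # 2.0
```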
Because merge() keeps all values, these metrics use much more memory than reducing metrics (the stored data grows with every merged step) and can be slow to compute.
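To put the memory cost in numbers, here is a back-of-the-envelope helper. `collected_bytes` is hypothetical, and the sizes (1000 steps, batch of 256, 1000 classes, 4-byte float32 logits) are illustrative assumptions, not measurements. Only the logits field is counted.

```python
def collected_bytes(num_steps: int, batch_size: int, num_classes: int,
                    bytes_per_value: int = 4) -> int:
  """Rough size of the logits kept by a collecting metric (hypothetical helper).

  Every merged step contributes a [batch_size, num_classes] array and none of
  them are reduced away, so the total grows linearly with num_steps.
  """
  return num_steps * batch_size * num_classes * bytes_per_value


# e.g. 1000 eval steps, batch of 256, 1000 classes, float32 logits:
print(collected_bytes(1000, 256, 1000))  # 1024000000 (about 1 GB)
```

A reducing metric such as a running average would instead keep a fixed-size state regardless of the number of steps.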
classmethod empty() → kauldron.metrics.base_state._SelfT
Returns an empty instance (i.e. .merge(State.empty()) is a no-op).
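In a toy model where each field is a tuple of chunks, the empty state is simply empty tuples, which makes merging with it a no-op and gives a natural starting point for an accumulation loop. `ToyCollectingState` below is a hypothetical stdlib stand-in, not kauldron's class.

```python
import dataclasses
from typing import Tuple


@dataclasses.dataclass(frozen=True)
class ToyCollectingState:
  """Hypothetical stand-in: one field accumulated as a tuple of chunks."""

  labels: Tuple[list, ...] = ()

  @classmethod
  def empty(cls) -> "ToyCollectingState":
    # Nothing collected yet: an empty tuple of chunks.
    return cls(labels=())

  def merge(self, other: "ToyCollectingState") -> "ToyCollectingState":
    return ToyCollectingState(labels=self.labels + other.labels)


state = ToyCollectingState(labels=([0, 1],))
print(state.merge(ToyCollectingState.empty()) == state)  # True

# Starting from the empty state, each step's values are appended in order.
acc = ToyCollectingState.empty()
for batch_labels in ([0, 1], [1], [0, 0]):
  acc = acc.merge(ToyCollectingState(labels=(batch_labels,)))
print(acc.labels)  # ([0, 1], [1], [0, 0])
```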
merge(other: kauldron.metrics.base_state._SelfT)
Returns a new state that is the accumulation of self and other.
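Since merging returns a new state, the inputs are left untouched. In the toy model used here, where accumulation is tuple concatenation, merging also preserves step order and is associative. `ToyCollectingState` is a hypothetical stdlib stand-in, so treat these properties as features of the sketch rather than documented guarantees.

```python
import dataclasses
from typing import Tuple


@dataclasses.dataclass(frozen=True)
class ToyCollectingState:
  """Hypothetical stand-in: accumulates chunks in a tuple."""

  labels: Tuple[list, ...]

  def merge(self, other: "ToyCollectingState") -> "ToyCollectingState":
    # Build a new instance; neither `self` nor `other` is modified.
    return ToyCollectingState(labels=self.labels + other.labels)


a = ToyCollectingState(labels=([1],))
b = ToyCollectingState(labels=([2],))
c = ToyCollectingState(labels=([3],))
merged = a.merge(b).merge(c)
print(merged.labels)  # ([1], [2], [3])
print(a.labels)       # ([1],)  (unchanged)
# Tuple concatenation is associative, so grouping does not matter here:
print(a.merge(b.merge(c)) == merged)  # True
```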
compute() → kauldron.metrics.base_state._SelfT
Returns a state of the same type whose fields hold the concatenated values (the accumulated tuples joined together).
replace(**updates)
Returns a new object replacing the specified fields with new values.
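A short sketch of `replace`, assuming it follows `dataclasses.replace` semantics (as it does for classes decorated with `flax.struct.dataclass`). `ToyState` is a hypothetical stdlib stand-in: only the named fields change, and the original object is not modified.

```python
import dataclasses
from typing import Tuple


@dataclasses.dataclass(frozen=True)
class ToyState:
  """Hypothetical frozen state with two accumulated fields."""

  labels: Tuple[list, ...]
  logits: Tuple[list, ...]

  def replace(self, **updates) -> "ToyState":
    # Copy the object, overriding only the given fields.
    return dataclasses.replace(self, **updates)


s = ToyState(labels=([0, 1],), logits=([0.2, 0.8],))
s2 = s.replace(labels=([1, 1],))
print(s2.labels, s2.logits)  # ([1, 1],) ([0.2, 0.8],)
print(s.labels)              # ([0, 1],)  (original untouched)
```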