kd.metrics.AutoState
- class kauldron.metrics.AutoState(*, parent: kauldron.metrics.base_state._MetricT = _EMPTY_TYPE.EMPTY)
Bases: kauldron.metrics.base_state.State[kauldron.metrics.auto_state._MetricT]

Flexible base class for conveniently defining custom states.
Subclasses of AutoState have to use the @flax.struct.dataclass decorator and can define two kinds of fields:
- Data fields are defined with the sum_field, concat_field or truncate_field functions, e.g. d: Float['n'] = sum_field(). Data fields are pytrees of JAX arrays and are merged by summing, concatenating or truncating, respectively.
- All other fields are static fields. They are not merged; instead, they are checked for equality during merge. They are also not pytree nodes, so JAX transforms do not touch them (although changing them can trigger recompilation). Static fields can store parameters of the metric, e.g. the number of elements to keep. They are rarely needed, however: it is usually better to define such parameters on the corresponding metric and access them through the parent field.
The compute method by default returns a namespace with the final data field values (as np.ndarrays). It can be overridden to return any other object, but should still call super().compute() to finalize the data fields.
Example:

```python
import flax.struct
import mediapy
from kauldron import kd
from kauldron.typing import Float


@flax.struct.dataclass(kw_only=True)
class CustomErrorSummaryState(kd.metrics.AutoState):
  # static-fields
  cmap: str = "coolwarm"
  num_to_keep: int = 5
  num_buckets: int = 30

  # data-fields
  error: Float['n h w 1'] = kd.metrics.truncate_field(num_field="num_to_keep")
  summed_error: Float[''] = kd.metrics.sum_field()
  total_error: Float[''] = kd.metrics.sum_field()
  error_hist: Float['n'] = kd.metrics.concat_field()

  def compute(self):
    # NOTE: You should access the data-fields through super().compute()
    # rather than through self, to ensure that they are properly finalized
    # (e.g. converted to np.ndarrays).
    data = super().compute()
    # Drop the trailing channel axis first: mediapy.to_rgb appends an RGB
    # axis, so (n, h, w) maps to (n, h, w, 3).
    error_img = mediapy.to_rgb(data.error[..., 0], cmap=self.cmap)
    return {
        "avg_error": data.summed_error / data.total_error,
        "error_images": error_img,
        "error_hist": kd.summaries.Histogram(
            tensor=data.error_hist, num_buckets=self.num_buckets
        ),
    }
```
- classmethod empty() → Self
Returns an empty instance, i.e. one for which state.merge(State.empty()) is a no-op.
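Because merging with an empty state changes nothing, `empty()` works as the starting value when folding per-batch states together. The sketch below uses a hypothetical `CountState` (not a Kauldron class) with `functools.reduce`; the per-batch states are built by hand purely for illustration.

```python
import dataclasses
import functools


@dataclasses.dataclass(frozen=True)
class CountState:
  """Hypothetical state with two summed data fields (not a Kauldron class)."""

  total: float = 0.0
  count: int = 0

  @classmethod
  def empty(cls) -> "CountState":
    # Identity element for merge: every summed field starts at zero.
    return cls()

  def merge(self, other: "CountState") -> "CountState":
    return CountState(self.total + other.total, self.count + other.count)

  def compute(self) -> float:
    return self.total / self.count


batch_states = [CountState(3.0, 2), CountState(5.0, 2), CountState(4.0, 4)]
# Starting from empty() means the fold needs no special case for the first batch.
final = functools.reduce(CountState.merge, batch_states, CountState.empty())
print(final.compute())  # 1.5
```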
- merge(other: kauldron.metrics.auto_state._SelfT)
Checks static fields for equality and merges data-fields.
- compute() → kauldron.metrics.auto_state._SelfT
Computes final metrics from intermediate values.