kd.metrics.AutoState
- class kauldron.metrics.AutoState(*, parent: kauldron.metrics.base_state._MetricT = _EMPTY_TYPE.EMPTY)
Bases: kauldron.metrics.base_state.State[kauldron.metrics.auto_state._MetricT]

Flexible base class for conveniently defining custom states.
Subclasses of AutoState have to use the @flax.struct.dataclass decorator and can define two kinds of fields:
- Data fields are defined by the sum_field, min_field, max_field, concat_field or truncate_field functions, e.g. d: Float['n'] = sum_field(). Data fields are PyTrees of JAX arrays. They are merged by summing, taking the minimum, taking the maximum, concatenating, or truncating, respectively.
- All other fields are static fields. They are not merged; instead, they are checked for equality during merge. They are also not PyTree nodes, so JAX transforms do not touch them (but changing them can trigger recompilation). They can be useful for storing parameters of the metric, e.g. the number of elements to keep. Note that static fields are rarely needed, since it is usually better to define static params in the corresponding metric and access them through the parent field.
The compute method by default returns a namespace with the final data field values (as np.ndarrays). It can be overridden to return any other object.
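As a rough illustration of that default, the sketch below builds a namespace of final values from a hypothetical `ToyState`. It assumes the namespace behaves like `types.SimpleNamespace` (the concrete type is not specified here) and, unlike AutoState, treats every field as a data field.

```python
import dataclasses
import types

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical state whose fields are all data fields."""

    summed_error: float
    count: int

    def compute(self) -> types.SimpleNamespace:
        # Expose each field's final value as an np.ndarray inside a namespace.
        return types.SimpleNamespace(
            **{f.name: np.asarray(getattr(self, f.name)) for f in dataclasses.fields(self)}
        )


result = ToyState(summed_error=6.0, count=3).compute()
```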
Example:

```python
import flax.struct
import mediapy
from kauldron import kd
from kauldron.typing import Float


@flax.struct.dataclass(kw_only=True)
class CustomErrorSummaryState(kd.metrics.AutoState):
    # Static fields
    cmap: str = "coolwarm"
    num_to_keep: int = 5
    num_buckets: int = 30

    # Data fields
    error: Float['n h w 1'] = kd.metrics.truncate_field(num_field="num_to_keep")
    summed_error: Float[''] = kd.metrics.sum_field()
    total_error: Float[''] = kd.metrics.sum_field()
    min_error: Float[''] = kd.metrics.min_field()
    max_error: Float[''] = kd.metrics.max_field()
    error_hist: Float['n'] = kd.metrics.concat_field()

    def compute(self):
        # Drop the trailing channel axis: mediapy.to_rgb appends its own RGB axis.
        error_img = mediapy.to_rgb(self.error[..., 0], cmap=self.cmap)
        return {
            "avg_error": self.summed_error / self.total_error,
            "min_error": self.min_error,
            "max_error": self.max_error,
            "error_images": error_img,
            "error_hist": kd.summaries.Histogram(
                tensor=self.error_hist, num_buckets=self.num_buckets
            ),
        }
```
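The example computes the average as `summed_error / total_error` from two sum fields rather than storing a per-batch mean. Assuming `total_error` holds the number of error values seen (the name alone does not say so), the plain-Python arithmetic below, with hypothetical batch values, shows why this works: summing numerators and denominators across merges yields the exact overall mean, while averaging per-batch means overweights the smaller batch.

```python
# Hypothetical per-batch errors; the batches have different sizes.
batches = [[1.0, 3.0], [5.0]]

# Per-batch state: (summed_error, total_error), both merged by summing.
states = [(sum(b), len(b)) for b in batches]
summed_error = sum(s for s, _ in states)  # 4.0 + 5.0 = 9.0
total_error = sum(n for _, n in states)   # 2 + 1 = 3
avg_error = summed_error / total_error    # 3.0, the true mean of all values

# Averaging per-batch means instead gives a biased result.
mean_of_means = sum(s / n for s, n in states) / len(states)  # (2.0 + 5.0) / 2 = 3.5
```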
- classmethod empty() -> Self
Returns an empty instance (i.e. for any state s, s.merge(State.empty()) is a no-op).
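For merging with an empty instance to be a no-op, each data field needs an identity value for its reduction. The sketch below shows one way to realize this with plain Python values (0 for sums, +inf for minima, -inf for maxima, an empty list for concatenation). These identities are illustrative assumptions; the library may represent empty fields differently (the class signature suggests a sentinel, `_EMPTY_TYPE.EMPTY`, at least for `parent`).

```python
import math

# Identity element per reduction: merging with it leaves the other value unchanged.
IDENTITY = {"sum": 0.0, "min": math.inf, "max": -math.inf, "concat": []}
REDUCE = {
    "sum": lambda a, b: a + b,
    "min": min,
    "max": max,
    "concat": lambda a, b: a + b,
}


def check_identity(kind, value):
    """Returns True if merging `value` with the empty value is a no-op in both orders."""
    empty = IDENTITY[kind]
    return REDUCE[kind](value, empty) == value and REDUCE[kind](empty, value) == value
```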
- merge(other: kauldron.metrics.auto_state._SelfT)
Checks static fields for equality and merges the data fields.
- finalize() -> kauldron.metrics.auto_state._SelfT
Finalizes the state (e.g. concatenating data fields and converting them to np.ndarrays).
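A sketch of what finalizing a concatenation field can involve, assuming the field accumulates a tuple of per-batch chunks during merges and is joined into a single `np.ndarray` only at the end. The chunked representation and the `finalize_concat` helper are hypothetical, chosen for illustration.

```python
import numpy as np


def finalize_concat(chunks):
    """Joins per-batch chunks along the first axis into one np.ndarray."""
    return np.concatenate([np.asarray(c) for c in chunks], axis=0)


# Two merged batches holding 2 and 3 values, respectively.
chunks = (np.array([0.1, 0.2]), np.array([0.3, 0.4, 0.5]))
final = finalize_concat(chunks)
```

One reason to defer the join like this is that it avoids copying an ever-growing array on every merge.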
- compute() -> kauldron.metrics.auto_state._SelfT
Computes final metrics from intermediate values.