kd.metrics.AutoState

class kauldron.metrics.AutoState(*, parent: kauldron.metrics.base_state._MetricT = _EMPTY_TYPE.EMPTY)

Bases: kauldron.metrics.base_state.State[kauldron.metrics.auto_state._MetricT]

Flexible base class for conveniently defining custom states.

Subclasses of AutoState must be decorated with @flax.struct.dataclass and can define two kinds of fields:

  1. Data fields are defined with the sum_field, min_field, max_field, concat_field, or truncate_field functions, e.g. d: Float['n'] = sum_field(). Data fields are PyTrees of JAX arrays. They are merged by summing, taking the minimum, taking the maximum, concatenating, or truncating, respectively.

  2. All other fields are static fields. They are not merged; instead, merge checks them for equality. They are also not PyTree nodes, so JAX transforms do not touch them (though changing them can trigger recompilation). Static fields can store parameters of the metric, e.g. the number of elements to keep. They are rarely needed, however: it is usually better to define static parameters on the corresponding metric and access them through the parent field.
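The merge rule of each data-field kind can be illustrated with plain NumPy. This is a minimal sketch, not Kauldron's implementation: the helper names (merge_sum, merge_truncate, and so on) are hypothetical, and it assumes that a truncated field concatenates along the leading axis and then keeps the first num elements.

```python
import numpy as np

# Hypothetical helpers mirroring how each data-field kind combines two partial values.
def merge_sum(a, b):
    return a + b

def merge_min(a, b):
    return np.minimum(a, b)

def merge_max(a, b):
    return np.maximum(a, b)

def merge_concat(a, b):
    # Concatenate along the leading (batch) axis.
    return np.concatenate([a, b], axis=0)

def merge_truncate(a, b, num):
    # Assumption: concatenate, then keep only the first `num` rows.
    return np.concatenate([a, b], axis=0)[:num]

batch1 = np.array([0.5, 2.0, 1.0])
batch2 = np.array([3.0, 0.1])

total = merge_sum(batch1.sum(), batch2.sum())
lowest = merge_min(batch1.min(), batch2.min())
highest = merge_max(batch1.max(), batch2.max())
all_errors = merge_concat(batch1, batch2)
first_four = merge_truncate(batch1, batch2, num=4)
```

All five rules are associative, which is what allows partial states from different batches or devices to be combined in any grouping.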

The compute method by default returns a namespace with the final data field values (as np.ndarrays). It can be overridden to return any other object.

Example:

```python
import flax
import mediapy

from kauldron import kd
from kauldron.typing import Float


@flax.struct.dataclass(kw_only=True)
class CustomErrorSummaryState(kd.metrics.AutoState):
  # Static fields: checked for equality on merge, never combined.
  cmap: str = "coolwarm"
  num_to_keep: int = 5
  num_buckets: int = 30

  # Data fields: combined according to the field function used.
  error: Float['n h w 1'] = kd.metrics.truncate_field(num_field="num_to_keep")
  summed_error: Float[''] = kd.metrics.sum_field()
  total_error: Float[''] = kd.metrics.sum_field()  # denominator of avg_error
  min_error: Float[''] = kd.metrics.min_field()
  max_error: Float[''] = kd.metrics.max_field()
  error_hist: Float['n'] = kd.metrics.concat_field()

  def compute(self):
    # mediapy.to_rgb appends an RGB axis, so drop the trailing channel first.
    error_img = mediapy.to_rgb(self.error[..., 0], cmap=self.cmap)
    return {
        "avg_error": self.summed_error / self.total_error,
        "min_error": self.min_error,
        "max_error": self.max_error,
        "error_images": error_img,
        "error_hist": kd.summaries.Histogram(
            tensor=self.error_hist, num_buckets=self.num_buckets
        ),
    }
```

classmethod empty() -> Self

Returns an empty instance, i.e. one for which state.merge(State.empty()) is a no-op.
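One way to make merging with an empty instance a no-op is a sentinel value that merge skips, similar in spirit to the _EMPTY_TYPE.EMPTY default visible in the class signature. This is a hypothetical sketch; _Empty, EMPTY, and merge_concat are illustrative names, not Kauldron internals.

```python
import enum

import numpy as np


class _Empty(enum.Enum):
    # Hypothetical sentinel marking "no data collected yet".
    EMPTY = enum.auto()


EMPTY = _Empty.EMPTY


def merge_concat(a, b):
    """Concatenate two partial values, treating EMPTY as the identity."""
    if a is EMPTY:
        return b
    if b is EMPTY:
        return a
    return np.concatenate([a, b], axis=0)


values = np.array([1.0, 2.0])
left = merge_concat(values, EMPTY)   # like state.merge(empty)
right = merge_concat(EMPTY, values)  # like empty.merge(state)
both_empty = merge_concat(EMPTY, EMPTY)
```

A sentinel avoids having to know the shape and dtype of a concatenated field before the first batch arrives, which a zero-length array would require.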

merge(other: kauldron.metrics.auto_state._SelfT) -> kauldron.metrics.auto_state._SelfT

Checks that the static fields are equal and merges the data fields.

finalize() -> kauldron.metrics.auto_state._SelfT

Finalizes the state (e.g. concatenating accumulated values and converting them to np.ndarrays).
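A separate finalize step makes sense if concatenated fields are accumulated lazily. The sketch below assumes that per-batch chunks are collected in a tuple during merge and joined into a single array only in finalize; this is an illustrative assumption, and merge_concat_lazy and finalize_concat are hypothetical names.

```python
import numpy as np


def merge_concat_lazy(a, b):
    """Hypothetical lazy concat: collect chunks in a tuple without copying data."""
    return a + b  # tuple concatenation


def finalize_concat(chunks):
    """Join the accumulated chunks once and return a NumPy array."""
    return np.concatenate([np.asarray(c) for c in chunks], axis=0)


acc = (np.array([1.0, 2.0]),)
acc = merge_concat_lazy(acc, (np.array([3.0]),))
acc = merge_concat_lazy(acc, (np.array([4.0, 5.0]),))
final = finalize_concat(acc)
```

Joining once at the end avoids re-copying every earlier chunk on each merge, which repeated eager concatenation would do.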

compute() -> kauldron.metrics.auto_state._SelfT

Computes final metrics from intermediate values.
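The overall lifecycle (one state per batch, a fold with merge starting from empty(), then compute) can be sketched in plain Python. ErrorState and its from_batch constructor are hypothetical, loosely modelled on the avg/min/max fields of the example above; a real AutoState would declare these with sum_field, min_field, and max_field instead.

```python
import dataclasses
import functools

import numpy as np


@dataclasses.dataclass(frozen=True)
class ErrorState:
    """Hypothetical state with sum/min/max data fields (not Kauldron code)."""

    summed_error: float = 0.0
    count: float = 0.0
    min_error: float = np.inf
    max_error: float = -np.inf

    @classmethod
    def empty(cls):
        # Identity element: merging with it changes nothing.
        return cls()

    @classmethod
    def from_batch(cls, errors):
        errors = np.asarray(errors, dtype=float)
        return cls(errors.sum(), errors.size, errors.min(), errors.max())

    def merge(self, other):
        return ErrorState(
            self.summed_error + other.summed_error,
            self.count + other.count,
            min(self.min_error, other.min_error),
            max(self.max_error, other.max_error),
        )

    def compute(self):
        # Final metrics derived from the merged intermediate values.
        return {
            "avg_error": self.summed_error / self.count,
            "min_error": self.min_error,
            "max_error": self.max_error,
        }


batches = [[0.5, 1.5], [3.0], [0.1, 0.4, 0.5]]
states = [ErrorState.from_batch(b) for b in batches]
merged = functools.reduce(ErrorState.merge, states, ErrorState.empty())
metrics = merged.compute()
```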