kd.metrics.Metric


class kauldron.metrics.Metric

Bases: abc.ABC

Base class for metrics.

Usage:

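```python
import jax.numpy as jnp
from kauldron import kd

x = jnp.ones((4, 3))  # Example input batch
metric = kd.metrics.Norm()  # Initialize the metric
state = metric.get_state(tensor=x)
other_state = metric.get_state(tensor=2 * x)  # State from another batch
state = state.merge(other_state)  # States can be accumulated
loss = state.compute()  # Get final value
```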

All metric implementations should be dataclasses that inherit from this class and:

  1. Override the Metric.State class by inheriting from an appropriate kd.metrics.State that collects and aggregates the required information. In most cases this will either be:

  2. Define a set of kd.kontext.Key-annotated fields that set the paths used to gather information from the train/eval context.

  3. Override the get_state(…) method, which should take arguments with the same names as the keys defined in step 2. This method usually runs on device within a pmap, and it should return an instance of the State class from step 1.

  4. Optionally override the State.compute(…) method, which returns the final value of the metric. This method runs outside of jit/pmap and can therefore use external libraries to perform its computation.

class State(*, parent: '_MetricT' = <_EMPTY_TYPE.EMPTY: 1>)

Bases: kauldron.metrics.base_state.State

abstractmethod merge(other: kauldron.metrics.base_state._SelfT) → kauldron.metrics.base_state._SelfT

Returns a new state that is the accumulation of self and other.

Parameters:

other – A State whose intermediate values should be accumulated onto the values of self.

Returns:

A new State that accumulates the values from both self and other.
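Because states from several batches are combined through merge, a convenient design is to store sufficient statistics rather than a finished value. The plain-Python sketch below (the class `AverageSketch` is hypothetical and is not a Kauldron state) keeps a sum and a count, shows that the grouping of merges does not change the result, and contrasts this with naively averaging per-batch means.

```python
import dataclasses
from functools import reduce


@dataclasses.dataclass(frozen=True)
class AverageSketch:
  """Hypothetical state: keeps (sum, count) rather than a running mean."""

  total: float
  count: int

  def merge(self, other: "AverageSketch") -> "AverageSketch":
    # Return a new state; neither self nor other is modified.
    return AverageSketch(self.total + other.total, self.count + other.count)

  def compute(self) -> float:
    return self.total / self.count


a = AverageSketch(total=10.0, count=4)  # e.g. a batch of 4 examples
b = AverageSketch(total=2.0, count=1)   # a smaller final batch
c = AverageSketch(total=3.0, count=2)

# Grouping does not matter: (a merge b) merge c equals a merge (b merge c).
left = a.merge(b).merge(c)
right = a.merge(b.merge(c))
folded = reduce(lambda x, y: x.merge(y), [a, b, c])

# Averaging the per-batch means would over-weight the small batches.
naive = (2.5 + 2.0 + 1.5) / 3  # 2.0
exact = left.compute()         # 15 / 7, about 2.143
```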

replace(**updates)

Returns a new object replacing the specified fields with new values.
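For a frozen dataclass, this behaviour corresponds to the standard library's `dataclasses.replace`. The sketch below assumes that analogy and uses a hypothetical `CountState`; it is an illustration, not Kauldron's implementation.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class CountState:
  """Hypothetical frozen state used only to illustrate replace()."""

  count: int = 0
  total: float = 0.0

  def replace(self, **updates) -> "CountState":
    # Build a copy with the given fields swapped in; self stays unchanged.
    return dataclasses.replace(self, **updates)


s = CountState(count=3, total=1.5)
t = s.replace(total=0.0)
# t keeps count=3 but has total=0.0; s still has total=1.5.
```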

abstractmethod get_state(**kwargs) → kauldron.metrics.base.Metric.State

empty() → kauldron.metrics.base.Metric.State

get_state_from_context(context: Any) → kauldron.metrics.base.Metric.State