kd.metrics.Metric
- class kauldron.metrics.Metric
Bases: abc.ABC
Base class for metrics.
Usage:
    metric = kd.metrics.Norm()  # Initialize the metric
    state = metric.get_state(tensor=x)
    state = state.merge(other_state)  # States can be accumulated
    loss = state.compute()  # Get final value
All metric implementations should be dataclasses that inherit from this class and:
1. Override the Metric.State class by inheriting from an appropriate kd.metrics.State that collects and aggregates the required information. In most cases this will be either:
   - kd.metrics.AverageState (for simple averaging of a value), or
   - kd.metrics.CollectingState (for metrics that need to collect and concatenate model outputs over many batches).
2. Define a set of kd.kontext.Key annotated fields that are used to set the paths for gathering information from the train/eval context.
3. Override the get_state(…) method, which should take arguments with the same names as the keys defined in (2). This method will usually be executed on device within a pmap. It should return an instance of the State from (1).
4. Optionally override the State.compute(…) method, which returns the final value of the metric. This method is executed outside of jit/pmap and can thus make use of external libraries to perform its computation.
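The steps above can be sketched without kauldron itself. The following is a minimal, standard-library-only illustration of the get_state → merge → compute flow. ToyAverageState and ToyMeanAbsError are hypothetical stand-ins, not kauldron classes, and the key path strings are made up for the example. In a real metric the fields would be kd.kontext.Key annotations and the state would inherit from a kd.metrics.State subclass.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyAverageState:
    """Hypothetical stand-in for an averaging state (not kauldron's AverageState)."""

    total: float = 0.0
    count: int = 0

    def merge(self, other: "ToyAverageState") -> "ToyAverageState":
        # Return a new state that accumulates both; neither input is mutated.
        return ToyAverageState(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        # Final value: runs once, after all batches have been merged.
        return self.total / self.count


@dataclasses.dataclass(frozen=True)
class ToyMeanAbsError:
    """Hypothetical metric that follows steps 1-4 with plain Python values."""

    # Assumption: plain strings stand in for kd.kontext.Key annotated fields.
    pred: str = "preds.output"
    target: str = "batch.label"

    def get_state(self, pred, target) -> ToyAverageState:
        # Argument names match the field names declared above (step 3).
        errors = [abs(p - t) for p, t in zip(pred, target)]
        return ToyAverageState(total=sum(errors), count=len(errors))


metric = ToyMeanAbsError()
s1 = metric.get_state(pred=[1.0, 2.0], target=[0.0, 2.0])  # per-batch state
s2 = metric.get_state(pred=[5.0], target=[3.0])            # another batch
merged = s1.merge(s2)                                      # accumulate states
value = merged.compute()                                   # final metric value
```

Keeping the per-batch work in get_state and the final reduction in compute is what allows the first to run on device and the second to run outside jit/pmap.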
- class State(*, parent: '_MetricT' = <_EMPTY_TYPE.EMPTY: 1>)
Bases: kauldron.metrics.base_state.State
- abstractmethod merge(other: kauldron.metrics.base_state._SelfT)
Returns a new state that is the accumulation of self and other.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
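The description of replace matches the usual immutable-update pattern for dataclasses. As an illustration only, the sketch below uses the standard library's dataclasses.replace on a hypothetical ToyState. It does not call kauldron's own State.replace, and the assumption is simply that the two behave analogously.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical state; not part of kauldron."""

    total: float
    count: float


before = ToyState(total=6.0, count=3.0)
# dataclasses.replace builds a new instance with the named fields swapped out;
# the original object is left unchanged.
after = dataclasses.replace(before, count=4.0)
```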
- abstractmethod get_state(**kwargs)
- empty() → kauldron.metrics.base.Metric.State
- get_state_from_context(
- context: Any,