kd.metrics.RocAuc

class kauldron.metrics.RocAuc(
    *,
    logits: typing.Annotated[typing.Any, <object object at 0x76412092fb90>] = '__KEY_REQUIRED__',
    labels: typing.Annotated[typing.Any, <object object at 0x76412092fb90>] = '__KEY_REQUIRED__',
    mask: typing.Annotated[typing.Any, <object object at 0x76412092fb90>] | None = None,
    unique_labels: typing.List[int] | None = None,
    multi_class_mode: str = 'ovr',
)

Bases: kauldron.metrics.base.Metric

Area Under the Receiver Operating Characteristic Curve (ROC AUC).
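For a binary problem, ROC AUC equals the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counting one half. The sketch below computes that pairwise definition directly in NumPy to illustrate the quantity. It is not the kauldron implementation, and `binary_roc_auc` is a hypothetical name.

```python
import numpy as np


def binary_roc_auc(labels, scores) -> float:
    """Pairwise (Mann-Whitney) ROC AUC for binary labels in {0, 1}."""
    labels = np.asarray(labels).astype(bool)
    scores = np.asarray(scores, dtype=np.float64)
    pos = scores[labels]   # scores of positive examples
    neg = scores[~labels]  # scores of negative examples
    if pos.size == 0 or neg.size == 0:
        raise ValueError("ROC AUC needs at least one positive and one negative.")
    # Compare every positive score with every negative score.
    diff = pos[:, None] - neg[None, :]
    wins = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return float(wins / (pos.size * neg.size))


print(binary_roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

Comparing all pairs costs O(P * N) memory for P positives and N negatives. That is fine for illustration, but practical implementations usually sort the scores instead.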

logits

The predicted logits to evaluate, with shape '*b n' (one score per class).

Type:

Any

labels

The ground-truth integer class labels, with shape '*b 1'.

Type:

Any

mask

Optional sample weights: a boolean mask or float weights, with shape '*b 1'.

Type:

Any | None
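Because the mask enters as sample weights, a masked-out sample (weight 0) drops out of the computation entirely, and float weights scale each sample's influence. Below is a minimal sketch of weighted binary ROC AUC, assuming each positive/negative pair is weighted by the product of the two sample weights. This gives the same value as the area under a ROC curve whose true- and false-positive counts are accumulated with sample weights. The function name is hypothetical.

```python
import numpy as np


def weighted_binary_roc_auc(labels, scores, weights) -> float:
    """Binary ROC AUC where each sample carries a non-negative weight."""
    labels = np.asarray(labels).astype(bool)
    scores = np.asarray(scores, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)  # bool masks become 0.0 / 1.0
    pos_s, pos_w = scores[labels], weights[labels]
    neg_s, neg_w = scores[~labels], weights[~labels]
    # Weight of each (positive, negative) pair is the product of both weights.
    pair_w = pos_w[:, None] * neg_w[None, :]
    diff = pos_s[:, None] - neg_s[None, :]
    wins = (pair_w * ((diff > 0) + 0.5 * (diff == 0))).sum()
    total = pos_w.sum() * neg_w.sum()
    if total == 0:
        raise ValueError("Both classes need non-zero total weight.")
    return float(wins / total)


labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(weighted_binary_roc_auc(labels, scores, [1, 1, 1, 1]))  # 0.75
# Masking out the negative scored 0.4 leaves a perfectly separated set.
print(weighted_binary_roc_auc(labels, scores, [True, False, True, True]))  # 1.0
```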

unique_labels

The complete list of class labels. Provide it when evaluating on a small subset of the data that, by chance, may not contain every class. If None, unique_labels is inferred from the labels seen by the metric.

Type:

List[int] | None
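The sketch below shows why a fixed unique_labels matters. It assumes (this is not confirmed by the reference above) that the metric keeps only the probability columns of the classes under evaluation and renormalizes them so each row sums to one. When unique_labels is None, the classes are inferred from whatever labels happen to be present, so a small batch that lacks a class silently changes which columns are used; passing the full list keeps that set fixed. The helper `select_classes` is hypothetical.

```python
import numpy as np


def select_classes(probs, labels, unique_labels=None):
    """Restrict class probabilities to the evaluated classes and renormalize."""
    probs = np.asarray(probs, dtype=np.float64)
    if unique_labels is None:
        # Infer the classes from the labels that are actually present.
        unique_labels = np.unique(labels).tolist()
    sub = probs[..., unique_labels]
    sub = sub / sub.sum(axis=-1, keepdims=True)  # rows sum to 1 again
    return sub, unique_labels


probs = np.array([[0.6, 0.3, 0.1],
                  [0.2, 0.2, 0.6]])
labels = np.array([0, 2])  # class 1 never appears in this small batch

sub, classes = select_classes(probs, labels)
print(classes)    # [0, 2]
full, _ = select_classes(probs, labels, unique_labels=[0, 1, 2])
print(full.shape)  # (2, 3): all three classes are kept
```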

multi_class_mode

Multi-class strategy: one-vs-rest (“ovr”, the default) or one-vs-one (“ovo”).

Type:

str
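The two modes differ in how per-class binary AUCs are combined. The sketch below follows the unweighted ("macro") definitions used, for example, by scikit-learn's roc_auc_score. One-vs-rest scores each class against all others using that class's probability column. One-vs-one takes every pair of classes, keeps only samples from those two classes, and averages the AUC of each class's column at separating it from the other. Whether kauldron averages in exactly this way is an assumption, the function names are hypothetical, and labels are assumed to be valid column indices into probs.

```python
import itertools

import numpy as np


def _auc(is_pos, scores):
    """Pairwise binary ROC AUC; ties count one half."""
    pos, neg = scores[is_pos], scores[~is_pos]
    diff = pos[:, None] - neg[None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / (pos.size * neg.size)


def roc_auc_ovr(labels, probs):
    """Macro one-vs-rest: each class against all remaining classes."""
    labels, probs = np.asarray(labels), np.asarray(probs, dtype=np.float64)
    classes = np.unique(labels)
    return float(np.mean([_auc(labels == c, probs[:, c]) for c in classes]))


def roc_auc_ovo(labels, probs):
    """Macro one-vs-one: average over all unordered class pairs."""
    labels, probs = np.asarray(labels), np.asarray(probs, dtype=np.float64)
    pair_scores = []
    for a, b in itertools.combinations(np.unique(labels), 2):
        keep = (labels == a) | (labels == b)  # only samples of these two classes
        is_a = labels[keep] == a
        auc_a = _auc(is_a, probs[keep, a])   # a vs b, scored by column a
        auc_b = _auc(~is_a, probs[keep, b])  # b vs a, scored by column b
        pair_scores.append((auc_a + auc_b) / 2)
    return float(np.mean(pair_scores))


labels = [0, 1, 2, 2]
probs = [[0.7, 0.2, 0.1],
         [0.3, 0.4, 0.3],
         [0.2, 0.5, 0.3],
         [0.1, 0.1, 0.8]]
print(round(roc_auc_ovr(labels, probs), 4))  # 0.8472
print(roc_auc_ovo(labels, probs))            # 0.875
```

On the same predictions the two modes can disagree, because one-vs-one ignores samples from classes outside each pair.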

logits: Annotated[Any, <object object at 0x76412092fb90>] = '__KEY_REQUIRED__'
labels: Annotated[Any, <object object at 0x76412092fb90>] = '__KEY_REQUIRED__'
mask: Annotated[Any, <object object at 0x76412092fb90>] | None = None
unique_labels: List[int] | None = None
multi_class_mode: str = 'ovr'
class State(
    labels: Int['*b 1'],
    probs: Float['*b n'],
    mask: Bool['*b 1'] | Float['*b 1'],
    *,
    parent: _MetricT = _EMPTY_TYPE.EMPTY,
)

Bases: kauldron.metrics.auto_state.AutoState[RocAuc]

State of the RocAuc metric: the labels, predicted class probabilities and mask from which the score is computed.

labels: Int['*b 1']
probs: Float['*b n']
mask: Bool['*b 1'] | Float['*b 1']
compute() -> float

Computes the final ROC AUC score from the intermediate values.

merge(
    other: kauldron.metrics.auto_state._SelfT,
) -> kauldron.metrics.auto_state._SelfT

Checks static fields for equality and merges data-fields.
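ROC AUC depends on how all samples rank against each other, so per-batch scores cannot simply be averaged. A state therefore has to keep the per-sample labels, probabilities and mask, and merging two states plausibly amounts to concatenating them along the batch axis. The sketch below illustrates that idea with a hypothetical frozen dataclass. It assumes that concatenation is what "merges data-fields" means here, it is not kauldron's AutoState, and its multi_class_mode field is a stand-in for a static field that must match. It uses dataclasses.replace, which returns a new object with the given fields replaced, in the same spirit as the replace method below.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyRocAucState:
    """Hypothetical stand-in for RocAuc.State (not the kauldron class)."""
    labels: np.ndarray  # shape (b, 1)
    probs: np.ndarray   # shape (b, n)
    mask: np.ndarray    # shape (b, 1)
    multi_class_mode: str = "ovr"  # "static" field: must match to merge

    def merge(self, other: "ToyRocAucState") -> "ToyRocAucState":
        if self.multi_class_mode != other.multi_class_mode:
            raise ValueError("Cannot merge states with different static fields.")
        # ROC AUC is not decomposable over batches, so keep every sample.
        return dataclasses.replace(
            self,
            labels=np.concatenate([self.labels, other.labels], axis=0),
            probs=np.concatenate([self.probs, other.probs], axis=0),
            mask=np.concatenate([self.mask, other.mask], axis=0),
        )


a = ToyRocAucState(
    labels=np.array([[0], [1]]),
    probs=np.array([[0.9, 0.1], [0.2, 0.8]]),
    mask=np.ones((2, 1), dtype=bool),
)
b = ToyRocAucState(
    labels=np.array([[1]]),
    probs=np.array([[0.4, 0.6]]),
    mask=np.ones((1, 1), dtype=bool),
)
merged = a.merge(b)
print(merged.probs.shape)  # (3, 2)
```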

replace(**updates)

Returns a new object replacing the specified fields with new values.

get_state(
    logits: jaxtyping.Float[Array, '*b n'] | jaxtyping.Float[ndarray, '*b n'],
    labels: jaxtyping.Int[Array, '*b 1'] | jaxtyping.Int[ndarray, '*b 1'],
    mask: jaxtyping.Bool[Array, '*b 1'] | jaxtyping.Bool[ndarray, '*b 1'] | jaxtyping.Float[Array, '*b 1'] | jaxtyping.Float[ndarray, '*b 1'] | None = None,
) -> kauldron.metrics.classification.RocAuc.State
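get_state takes logits, but the state stores probs, so the logits must be turned into class probabilities along the last axis. Below is a minimal NumPy sketch, assuming a softmax conversion and assuming that a missing mask means every sample counts with weight 1. The function name is hypothetical, and kauldron's own code may differ (for example, by operating on JAX arrays).

```python
import numpy as np


def get_state_sketch(logits, labels, mask=None):
    """Convert logits of shape (*b, n) into a (labels, probs, mask) triple."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    # Numerically stable softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=-1, keepdims=True)
    if mask is None:
        # Assumption: no mask means every sample gets weight 1.
        mask = np.ones_like(labels, dtype=np.float64)
    return labels, probs, mask


labels, probs, mask = get_state_sketch(
    logits=[[2.0, 0.0, -1.0], [0.0, 0.0, 0.0]],
    labels=[[0], [2]],
)
print(mask.shape)  # (2, 1)
```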
empty() -> kauldron.metrics.base.Metric.State