kd.metrics.Ari


class kauldron.metrics.Ari(
*,
num_instances_true: int,
num_instances_pred: int,
ignored_ids: typing.Sequence[int] | int | None = None,
predictions: kontext.Key = '__KEY_REQUIRED__',
labels: kontext.Key = '__KEY_REQUIRED__',
mask: kontext.Key | None = None,
)

Bases: kauldron.metrics.base.Metric

Adjusted Rand Index (ARI) computed from predictions and labels.

ARI is a similarity score for comparing two clusterings. It takes values in the range [-1, 1]: a value of 1 means the two clusterings are identical up to a permutation of the cluster labels, i.e. the predicted clustering perfectly matches the ground-truth clustering. Values close to 0 correspond to chance-level agreement, and negative values indicate less agreement than would be expected from a random assignment.
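For reference, the standard definition of ARI is built from the contingency table of the two clusterings over N items: n_{ij} counts the items assigned to ground-truth cluster i and predicted cluster j, and a_i and b_j are the row and column sums of that table. ARI is the pair-counting index minus its expected value under random assignment, divided by the maximum index minus the same expectation. This is the textbook formula; this page does not state whether Ari evaluates it in exactly this form.

```latex
\mathrm{ARI} =
\frac{\sum_{ij}\binom{n_{ij}}{2} - \left[\sum_i\binom{a_i}{2}\sum_j\binom{b_j}{2}\right]\Big/\binom{N}{2}}
     {\tfrac{1}{2}\left[\sum_i\binom{a_i}{2} + \sum_j\binom{b_j}{2}\right] - \left[\sum_i\binom{a_i}{2}\sum_j\binom{b_j}{2}\right]\Big/\binom{N}{2}},
\qquad a_i = \sum_j n_{ij}, \quad b_j = \sum_i n_{ij}
```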

In this implementation, we use ARI to compare predicted instance segmentation masks (including the background prediction) with ground-truth segmentation annotations.
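A minimal pure-Python sketch of the formula applied to flattened segmentation maps, where every pixel is one item and every id (including a background id such as 0) is one cluster. The function name `adjusted_rand_index` is hypothetical and not part of kauldron. Returning 1.0 for degenerate partitions follows scikit-learn's convention and is an assumption here, and `ignored_ids` is not modeled because its handling is not described on this page.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_true, labels_pred):
    """ARI between two flat label sequences, e.g. flattened segmentation maps."""
    if len(labels_true) != len(labels_pred):
        raise ValueError("labels_true and labels_pred must have the same length")
    n = len(labels_true)
    if n < 2:
        return 1.0  # no pairs to compare; treated as a perfect match
    # Sum of C(n_ij, 2), where n_ij counts pixels with true id i and predicted id j.
    index = sum(comb(c, 2) for c in Counter(zip(labels_true, labels_pred)).values())
    # Sums of C(a_i, 2) and C(b_j, 2) over the row and column totals.
    sum_a = sum(comb(c, 2) for c in Counter(labels_true).values())
    sum_b = sum(comb(c, 2) for c in Counter(labels_pred).values())
    expected = sum_a * sum_b / comb(n, 2)  # index expected under random assignment
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:
        # Both partitions are trivial (one cluster, or all singletons); scikit-learn returns 1.0.
        return 1.0
    return (index - expected) / (max_index - expected)


# 2x3 masks flattened row-major; id 0 is background in the ground truth.
true_mask = [0, 0, 1,
             0, 2, 2]
# Same regions with different ids: ARI ignores how the labels are permuted.
pred_mask = [3, 3, 7,
             3, 5, 5]
print(adjusted_rand_index(true_mask, pred_mask))  # 1.0
```

Because only co-membership of pixel pairs matters, a prediction that uses arbitrary instance ids for the right regions still scores a perfect 1.0.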

num_instances_true: int
num_instances_pred: int
ignored_ids: Sequence[int] | int | None = None
predictions: kontext.Key = '__KEY_REQUIRED__'
labels: kontext.Key = '__KEY_REQUIRED__'
mask: kontext.Key | None = None
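The fields whose default is '__KEY_REQUIRED__' (`predictions` and `labels`) have no usable default and must be set when the metric is constructed. A minimal construction sketch, assuming the usual `from kauldron import kd` import and that these fields take string key paths pointing into the model outputs and the batch. The key paths and instance counts shown are hypothetical placeholders, not values taken from kauldron.

```python
from kauldron import kd

# Hypothetical placeholders: adapt the key paths and instance counts to
# your own model outputs and dataset.
ari = kd.metrics.Ari(
    num_instances_true=11,
    num_instances_pred=11,
    predictions="preds.segmentations",
    labels="batch.segmentations",
)
```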
class State(
total: "Float['']",
count: "Float['']",
*,
parent: '_MetricT' = <_EMPTY_TYPE.EMPTY: 1>,
)

Bases: kauldron.metrics.base_state.AverageState

merge(
other: kauldron.metrics.base_state.AverageState,
) → kauldron.metrics.base_state.AverageState

Returns a new state that is the accumulation of self and other.

Parameters:

other – A State whose intermediate values should be accumulated onto the values of self.

Returns:

A new State that accumulates the values from both self and other.
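The `total` and `count` fields of `State` suggest why merging is cheap: partial sums from different batches or devices can simply be added, and the average is formed only at the end. A minimal sketch under that assumption, not kauldron's `AverageState` source; `AverageStateSketch` and its `compute` method are hypothetical names.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class AverageStateSketch:
    """Hypothetical stand-in for an average state holding `total` and `count`."""

    total: float
    count: float

    def merge(self, other: "AverageStateSketch") -> "AverageStateSketch":
        # Sum both accumulators into a new state; neither input is modified.
        return AverageStateSketch(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        # Divide only at the end, so the result does not depend on how data was split.
        return self.total / self.count


# Shard A saw ARI scores 1.0, 0.5 and 0.0; shard B saw a single score of 0.75.
a = AverageStateSketch(total=1.5, count=3.0)
b = AverageStateSketch(total=0.75, count=1.0)
print(a.merge(b).compute())  # 0.5625, the mean over all four scores
```

Averaging the two shard means instead would give 0.625, which is why the state carries sums and counts rather than means.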

replace(**updates)

Returns a new object replacing the specified fields with new values.
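The description matches the semantics of the standard-library `dataclasses.replace`; assuming `replace` behaves the same way, this sketch shows the effect on a frozen dataclass. `StateSketch` is a hypothetical class used only for illustration.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class StateSketch:
    """Hypothetical frozen state used only to illustrate `replace`."""

    total: float
    count: float


s = StateSketch(total=3.0, count=4.0)
# dataclasses.replace builds a new instance; the frozen original is left untouched.
s2 = dataclasses.replace(s, count=6.0)
print(s, s2)
```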

get_state(
predictions: jaxtyping.Integer[Array, '*b t h w 1'] | jaxtyping.Integer[ndarray, '*b t h w 1'],
labels: jaxtyping.Integer[Array, '*b t h w 1'] | jaxtyping.Integer[ndarray, '*b t h w 1'],
mask: jaxtyping.Bool[Array, '*b 1'] | jaxtyping.Bool[ndarray, '*b 1'] | jaxtyping.Float[Array, '*b 1'] | jaxtyping.Float[ndarray, '*b 1'] | None = None,
) → kauldron.metrics.clustering.Ari.State
empty() → kauldron.metrics.base.Metric.State