kd.metrics.Norm


class kauldron.metrics.Norm(
    *,
    tensor: kontext.Key = '__KEY_REQUIRED__',
    mask: kontext.Key | None = None,
    axis: None | int | tuple[int, int] = -1,
    ord: float | int | None = None,
    aggregation_type: typing.Literal['average', 'concat'] | None = None,
)

Bases: kauldron.metrics.base.Metric

Wraps jnp.linalg.norm to compute the average norm for given tensors.

Computes jnp.linalg.norm for the array corresponding to the “tensor” key, and averages the value over remaining dimensions (taking masking into account).

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
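The computation described above can be sketched in plain NumPy, since `jnp.linalg.norm` mirrors `np.linalg.norm`. This is a minimal, hypothetical re-implementation, not Kauldron's code: the function name `masked_mean_norm` is invented for illustration, and it assumes the mask is broadcastable against the norm array computed with `keepdims=True`.

```python
import numpy as np


def masked_mean_norm(tensor, mask=None, axis=-1, ord=None):
    """Hypothetical sketch: norm along `axis`, then a masked mean of the norms."""
    # keepdims=True keeps the reduced axis with size 1, so a mask shaped like the
    # remaining (e.g. batch) dimensions plus a trailing 1 broadcasts cleanly.
    norms = np.linalg.norm(tensor, ord=ord, axis=axis, keepdims=True)
    if mask is None:
        weights = np.ones_like(norms)
    else:
        # Assumption: the mask broadcasts to the shape of `norms`.
        weights = np.broadcast_to(np.asarray(mask, dtype=norms.dtype), norms.shape)
    # Masked-out entries contribute neither to the sum nor to the count.
    return np.sum(norms * weights) / np.sum(weights)


x = np.array([[3.0, 4.0], [6.0, 8.0], [5.0, 12.0]])  # row norms: 5, 10, 13
print(masked_mean_norm(x))  # (5 + 10 + 13) / 3
print(masked_mean_norm(x, mask=np.array([[True], [True], [False]])))  # (5 + 10) / 2
```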

tensor

kontext.Key for the tensor to compute the norm over.

Type:

Any

mask

Optional key for a mask that excludes some entries from the result (masked-out norms are ignored in the averaging).

Type:

Any | None

axis

Axis over which to compute the norm, where x is the array referenced by tensor. If axis is an integer, it specifies the axis of x along which to compute the vector norms. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is None, then either a vector norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned; if ord is also None, the 2-norm of the flattened x is returned for any number of dimensions.

Type:

None | int | tuple[int, int]
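To see how axis changes the shape of the result, here is a small NumPy check (`jnp.linalg.norm` follows the same axis semantics). The array `x` is an arbitrary example.

```python
import numpy as np

x = np.arange(24, dtype=np.float64).reshape(2, 3, 4)

# axis=-1 (the default): one vector norm per length-4 row -> shape (2, 3).
per_row = np.linalg.norm(x, axis=-1)

# axis=(-2, -1): one matrix norm per 3x4 matrix (Frobenius, since ord=None) -> shape (2,).
per_matrix = np.linalg.norm(x, axis=(-2, -1))

# axis=None and ord=None: 2-norm of the flattened array -> a scalar.
whole = np.linalg.norm(x)

print(per_row.shape, per_matrix.shape, np.ndim(whole))  # (2, 3) (2,) 0
```

With the default axis=-1, Norm would then average the 2 x 3 entries of `per_row` into a single value.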

ord

Order of the norm. Possible values: None, “fro”, “nuc”, np.inf, -np.inf, -2, -1, 0, or any integer or float. See np.linalg.norm.

Type:

float | int | None
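A few of these orders on a small vector and matrix, using NumPy; the values follow directly from the definitions. The string orders "fro" and "nuc" are matrix norms, so they need a 2-tuple axis.

```python
import numpy as np

v = np.array([3.0, -4.0, 0.0])
l2 = np.linalg.norm(v)                # ord=None: sqrt(3**2 + 4**2) = 5
l1 = np.linalg.norm(v, ord=1)         # |3| + |-4| + |0| = 7
linf = np.linalg.norm(v, ord=np.inf)  # max |v_i| = 4
l0 = np.linalg.norm(v, ord=0)         # number of non-zero entries = 2

m = np.array([[1.0, 2.0], [3.0, 4.0]])
fro = np.linalg.norm(m, ord="fro", axis=(-2, -1))  # sqrt(1 + 4 + 9 + 16) = sqrt(30)
nuc = np.linalg.norm(m, ord="nuc", axis=(-2, -1))  # sum of the singular values
```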

aggregation_type

How to aggregate the norms in TreeReduce. 'average' computes the mean of the per-node norms; 'concat' computes the norm as if all nodes of the tree were concatenated into a single vector. Defaults to 'average' (used when the value is None).

Type:

Literal[‘average’, ‘concat’] | None
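The difference between the two modes can be sketched on a toy tree with two leaves. `tree_norm` is a hypothetical helper written for illustration (it is not part of Kauldron's TreeReduce API), and it assumes each leaf is flattened into a single vector before its norm is taken.

```python
import numpy as np


def tree_norm(tree, ord=None, aggregation_type="average"):
    """Hypothetical helper contrasting the 'average' and 'concat' modes.

    `tree` is a flat dict of arrays; each leaf is flattened into one vector.
    """
    leaves = [np.ravel(np.asarray(leaf, dtype=float)) for leaf in tree.values()]
    if aggregation_type in (None, "average"):
        # Mean of the per-leaf norms (None falls back to 'average').
        return float(np.mean([np.linalg.norm(leaf, ord=ord) for leaf in leaves]))
    if aggregation_type == "concat":
        # Norm of all leaves concatenated into one long vector.
        return float(np.linalg.norm(np.concatenate(leaves), ord=ord))
    raise ValueError(f"Unknown aggregation_type: {aggregation_type!r}")


tree = {"w": np.array([3.0, 4.0]), "b": np.array([5.0, 12.0])}  # leaf norms: 5 and 13
print(tree_norm(tree))                             # (5 + 13) / 2
print(tree_norm(tree, aggregation_type="concat"))  # sqrt(9 + 16 + 25 + 144) = sqrt(194)
```

For the 2-norm, 'concat' equals the 2-norm of the vector of per-leaf norms, so large leaves dominate; 'average' weights every leaf equally regardless of its size.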

class State(
    total: Float[''],
    count: Float[''],
    *,
    parent: _MetricT = _EMPTY_TYPE.EMPTY,
)

Bases: kauldron.metrics.base_state.AverageState[Norm]

Wrapper around AverageState for Norm.

merge(
    other: kauldron.metrics.base_state.AverageState,
) -> kauldron.metrics.base_state.AverageState

Returns a new state that is the accumulation of self and other.

Parameters:

other – A State whose intermediate values should be accumulated onto the values of self.

Returns:

A new State that accumulates the value from both self and other.

compute() -> jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']

Computes final metrics from intermediate values.
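The accumulate-then-divide pattern behind merge() and compute() can be sketched with a small frozen dataclass. `AverageStateSketch` is a hypothetical stand-in, not Kauldron's AverageState; it only assumes the state keeps a running total and count, matching the total and count fields of Norm.State above.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class AverageStateSketch:
    """Hypothetical stand-in for an average-style metric state."""

    total: float
    count: float

    @classmethod
    def empty(cls) -> "AverageStateSketch":
        # Identity element: merging with it leaves a state unchanged.
        return cls(total=0.0, count=0.0)

    def merge(self, other: "AverageStateSketch") -> "AverageStateSketch":
        # Accumulate sums and counts; the division is deferred to compute().
        return AverageStateSketch(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        return self.total / self.count


batch_1 = AverageStateSketch(total=5.0 + 10.0, count=2.0)  # norms 5 and 10
batch_2 = AverageStateSketch(total=13.0, count=1.0)        # norm 13
merged = AverageStateSketch.empty().merge(batch_1).merge(batch_2)
print(merged.compute())  # 28 / 3
```

Summing totals and counts first weights every entry equally; averaging the two per-batch means instead would give (7.5 + 13) / 2 = 10.25, which over-weights the smaller batch.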

replace(**updates)

Returns a new object replacing the specified fields with new values.

get_state(
    tensor: jaxtyping.Float[Array, '*any'] | jaxtyping.Float[ndarray, '*any'],
    mask: jaxtyping.Bool[Array, '*#any'] | jaxtyping.Bool[ndarray, '*#any'] | jaxtyping.Float[Array, '*#any'] | jaxtyping.Float[ndarray, '*#any'] | None = None,
) -> kauldron.metrics.stats.Norm.State
empty() -> kauldron.metrics.base.Metric.State