kd.metrics.Norm
- class kauldron.metrics.Norm(
- *,
- tensor: typing.Annotated[typing.Any, <object object at 0x76412092fb90>] = '__KEY_REQUIRED__',
- mask: typing.Annotated[typing.Any, <object object at 0x76412092fb90>] | None = None,
- axis: None | int | tuple[int, int] = -1,
- ord: float | int | None = None,
- aggregation_type: typing.Literal['average', 'concat'] | None = None,
- )
Bases: kauldron.metrics.base.Metric
Wraps jnp.linalg.norm to compute the average norm for given tensors.
Computes jnp.linalg.norm for the array corresponding to the “tensor” key, and averages the value over remaining dimensions (taking masking into account).
See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
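The behavior described above (norm along an axis, then a mask-aware average over the remaining dimensions) can be sketched in plain NumPy. This is an illustration only, not Kauldron's implementation: masked_average_norm is a hypothetical helper, and the mask is assumed to broadcast against the array of norms. np.linalg.norm is used in place of jnp.linalg.norm because both take the same ord and axis arguments for these cases.

```python
import numpy as np

def masked_average_norm(tensor, mask=None, axis=-1, ord=None):
    """Hypothetical helper: norm along `axis`, then a mask-weighted mean."""
    norms = np.linalg.norm(tensor, ord=ord, axis=axis)
    if mask is None:
        weights = np.ones_like(norms)
    else:
        # Assumption: the mask broadcasts against the shape of the norms.
        weights = np.broadcast_to(np.asarray(mask, dtype=norms.dtype), norms.shape)
    # An all-zero mask (division by zero) is not handled in this sketch.
    return (norms * weights).sum() / weights.sum()

tensor = np.array([[3.0, 4.0], [6.0, 8.0], [0.0, 1.0]])  # row norms: 5, 10, 1
mask = np.array([1.0, 1.0, 0.0])                         # ignore the last row

avg_all = masked_average_norm(tensor)           # (5 + 10 + 1) / 3
avg_masked = masked_average_norm(tensor, mask)  # (5 + 10) / 2 = 7.5
```

With the mask applied, the third row contributes neither to the sum of norms nor to the count, so the average is taken over two rows only.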
- tensor
kontext.Key for the tensor to compute the norm over.
- Type:
Any
- mask
Optional key for masking out some of the tensors (i.e. ignore them in the averaging).
- Type:
Any | None
- axis
Axis over which to compute the norm. If axis is an integer, it specifies the axis of x along which to compute the vector norms. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is None then either a vector norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned.
- Type:
None | int | tuple[int, int]
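The three axis modes can be checked directly against np.linalg.norm, which jnp.linalg.norm mirrors. The array shapes below are chosen only for illustration.

```python
import numpy as np

x = np.arange(1.0, 25.0).reshape(2, 3, 4)

# Integer axis: one vector norm per slice along the last axis.
vec_norms = np.linalg.norm(x, axis=-1)        # shape (2, 3)

# 2-tuple axis: one matrix (Frobenius) norm per 3x4 matrix.
mat_norms = np.linalg.norm(x, axis=(1, 2))    # shape (2,)

# axis=None: vector norm for 1-D input, matrix norm for 2-D input.
vec_norm = np.linalg.norm(np.array([3.0, 4.0]), axis=None)                # 5.0
mat_norm = np.linalg.norm(np.array([[1.0, 2.0], [2.0, 4.0]]), axis=None)  # 5.0
```

Note that the metric's default is axis=-1, so each trailing vector gets its own norm before averaging.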
- ord
Order of the norm. Possible values: None, “fro”, “nuc”, np.inf, -np.inf, -2, -1, 0, or any integer or float. See np.linalg.norm.
- Type:
float | int | None
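A few of the ord values listed above, evaluated with np.linalg.norm on small inputs. The string orders "fro" and "nuc" apply to matrices only; the example values are illustrative.

```python
import numpy as np

v = np.array([3.0, -4.0, 0.0])
l2 = np.linalg.norm(v)                  # default (ord=None): 2-norm, 5.0
l1 = np.linalg.norm(v, ord=1)           # sum of absolute values, 7.0
linf = np.linalg.norm(v, ord=np.inf)    # largest absolute value, 4.0
l0 = np.linalg.norm(v, ord=0)           # number of non-zero entries, 2.0

m = np.array([[1.0, 2.0], [3.0, 4.0]])
fro = np.linalg.norm(m, ord="fro")      # sqrt of the sum of squares, sqrt(30)
nuc = np.linalg.norm(m, ord="nuc")      # sum of singular values
col = np.linalg.norm(m, ord=1)          # for matrices: max absolute column sum, 6.0
```

The same ord value can mean different things for vectors and matrices (compare l1 and col), so which axis setting is in effect matters when choosing ord.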
- aggregation_type
How to aggregate the norms in TreeReduce. Average will compute the average of the norms. Concat will compute the norm as if all nodes of a tree were concatenated into a single vector. Average by default.
- Type:
Literal['average', 'concat'] | None
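The difference between the two aggregation types can be illustrated on a small tree of arrays. This is a sketch of the arithmetic described above for the 2-norm, not of how TreeReduce is implemented; the dictionary tree and its leaf names are hypothetical.

```python
import numpy as np

# Hypothetical parameter tree with two leaves.
tree = {"w": np.array([3.0, 4.0]), "b": np.array([12.0])}
leaves = [leaf.ravel() for leaf in tree.values()]

# 'average': compute one norm per leaf, then average the norms.
leaf_norms = [np.linalg.norm(leaf) for leaf in leaves]   # 5.0 and 12.0
average = float(np.mean(leaf_norms))                     # 8.5

# 'concat': treat all leaves as one long vector and take a single norm.
concat = float(np.linalg.norm(np.concatenate(leaves)))   # sqrt(9 + 16 + 144) = 13.0
```

The two results generally differ: 'average' weighs every leaf equally regardless of size, while 'concat' gives the norm of the whole tree as a single vector.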
- tensor: Annotated[Any, <object object at 0x76412092fb90>] = '__KEY_REQUIRED__'
- mask: Annotated[Any, <object object at 0x76412092fb90>] | None = None
- axis: None | int | tuple[int, int] = -1
- ord: float | int | None = None
- aggregation_type: Literal['average', 'concat'] | None = None
- class State(
- total: Float[''],
- count: Float[''],
- *,
- parent: _MetricT = _EMPTY_TYPE.EMPTY,
- )
Bases: kauldron.metrics.base_state.AverageState[Norm]
Wrapper around AverageState for Norm.
- merge(
- other: kauldron.metrics.base_state.AverageState,
- )
Returns a new state that is the accumulation of self and other.
- compute() → jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']
Computes final metrics from intermediate values.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
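The merge, compute and replace methods follow the usual running-average pattern. The stand-in class below is a minimal sketch of that pattern, assuming that merge adds the total and count fields and that compute returns total / count; RunningAverage is a hypothetical name and not part of Kauldron.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class RunningAverage:
    """Hypothetical stand-in for an average state with `total` and `count`."""
    total: float
    count: float

    def merge(self, other):
        # Accumulate both fields; neither input state is modified.
        return RunningAverage(self.total + other.total, self.count + other.count)

    def compute(self):
        # Assumption: the final metric is the accumulated total over the count.
        return self.total / self.count

    def replace(self, **updates):
        # New object with the given fields replaced.
        return dataclasses.replace(self, **updates)

batch_1 = RunningAverage(total=15.0, count=2.0)  # e.g. norms 5 and 10
batch_2 = RunningAverage(total=1.0, count=1.0)   # e.g. norm 1
merged = batch_1.merge(batch_2)
result = merged.compute()                        # 16 / 3
reset = merged.replace(total=0.0)
```

Keeping total and count separate until compute is what makes merging across batches exact: averaging per-batch averages would weigh batches of different sizes incorrectly.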
- get_state(
- tensor: jaxtyping.Float[Array, '*any'] | jaxtyping.Float[ndarray, '*any'],
- mask: jaxtyping.Bool[Array, '*#any'] | jaxtyping.Bool[ndarray, '*#any'] | jaxtyping.Float[Array, '*#any'] | jaxtyping.Float[ndarray, '*#any'] | None = None,
- )
- empty() → kauldron.metrics.base.Metric.State