kd.losses.Loss


class kauldron.losses.Loss(*, step: kontext.Key = 'step', mask: kontext.Key | None = None, weight: int | float | typing.Callable[[int], float] = 1.0, normalize_by: typing.Literal['mask', 'values'] = 'mask')

Bases: kauldron.metrics.base.Metric, abc.ABC

Base class for losses which handles masks, averaging, and loss-weight.

Subclasses should implement get_values which should compute the loss value and return self.State.from_values(values=values, mask=mask)

Example

# Instantiate Loss with parameters and required keys:
loss = SoftmaxCrossEntropy(logits="preds.logits", labels="batch.labels")

# Shorthand computation given a context object ctx which contains
# the logits and labels in the previously specified paths:
value = loss(context=ctx)

value = loss(logits=..., labels=...)  # directly passing logits and labels

# The above shorthand is only recommended for interactive usage.
# In training code use get_state and compute:
loss_state = loss.get_state_from_context(ctx)        # from context
loss_state = loss.get_state(logits=..., labels=...)  # directly

value = loss_state.compute()

step

The key for determining the current step (for weight schedules).

Type:

Any

weight

Determines the weight of this loss term for the total loss. Can be either a float constant or a schedule (a function from step-number to float). Defaults to 1.0.

Type:

int | float | Callable[[int], float]
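
As a sketch of the schedule form, the function below maps a step number to a weight that ramps linearly from 0 to 1. The name linear_warmup and the 1000-step ramp are illustrative assumptions, not part of Kauldron; any callable from int to float matches the type above.

```python
def linear_warmup(step: int, warmup_steps: int = 1000) -> float:
    """Hypothetical schedule: ramp the loss weight linearly from 0.0 to 1.0."""
    # Clamp so the weight stays at 1.0 once the warmup period is over.
    return min(step / warmup_steps, 1.0)


# A constant and a schedule are both valid values for `weight`.
constant_weight = 0.5
scheduled_weight = linear_warmup

print(scheduled_weight(250))   # 0.25
print(scheduled_weight(5000))  # 1.0
```

A function like this could be passed as weight=linear_warmup when constructing a loss, in which case the step value decides the weight applied to that loss term.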

mask

Optional key for a mask of values in [0, 1] which can be used to ignore parts of the batch. The shape of the mask should be broadcastable to the output shape of the compute function. Defaults to None.

Losses can be computed in two ways:

  1. Directly, by passing the required arguments to the __call__ method.

  2. Using apply_to_context to automatically gather the arguments from a given context. This takes into account the weight of the loss.

Type:

Any | None

normalize_by

Whether to divide the total loss by the number of mask elements (normalize_by = "mask"), or by the total number of values (normalize_by = "values"). Defaults to "mask".

Type:

Literal['mask', 'values']
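
The difference between the two normalization modes can be illustrated with a small sketch. It uses plain Python lists so that it runs without dependencies, and the helper masked_mean is hypothetical: it reflects one reading of the description above, not Kauldron's actual implementation.

```python
def masked_mean(values, mask, normalize_by="mask"):
    """Illustrative masked average over flat lists of floats."""
    # Masked-out entries (mask == 0) contribute nothing to the total.
    total = sum(v * m for v, m in zip(values, mask))
    if normalize_by == "mask":
        denom = sum(mask)      # only the unmasked elements are counted
    elif normalize_by == "values":
        denom = len(values)    # every element is counted, masked or not
    else:
        raise ValueError(f"Unknown normalize_by: {normalize_by!r}")
    return total / denom


values = [2.0, 4.0, 6.0, 8.0]
mask = [1.0, 1.0, 0.0, 0.0]

print(masked_mean(values, mask, "mask"))    # 3.0  (6.0 / 2)
print(masked_mean(values, mask, "values"))  # 1.5  (6.0 / 4)
```

With "mask", the result is the average over the kept elements only. With "values", heavily masked batches yield proportionally smaller losses. The sketch does not handle an all-zero mask.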

step: kontext.Key = 'step'
mask: kontext.Key | None = None
weight: int | float | Callable[[int], float] = 1.0
normalize_by: Literal['mask', 'values'] = 'mask'

State

alias of kauldron.losses.base.AllReduceMean

abstractmethod get_values(*args, **kwargs) -> jaxtyping.Shaped[Array, '...'] | jaxtyping.Shaped[ndarray, '...']

Compute the loss values (before masking, averaging and weighting).

Subclasses need to implement this method.

Parameters:
  • *args – Any required arguments (names should match kontext.Key annotations).

  • **kwargs – Any arguments (names should match kontext.Key annotations).

Returns:

A jax.Array of loss values whose shape is compatible with any desired masking.

get_state(
    *args,
    mask: jaxtyping.Shaped[Array, '...'] | jaxtyping.Shaped[ndarray, '...'] | None = None,
    step: int | None = None,
    **kwargs,
) -> kauldron.losses.base.AllReduceMean

Compute the loss state, taking care of masking and loss weighting.

The Loss.State is AllReduceMean by default, which keeps track of a single scalar loss value and ensures correct averaging even when masks are used.

Parameters:
  • *args – Positional arguments to be passed on to get_values.

  • mask – An optional mask to exclude some of the loss values from the total. The shape of this mask needs to be broadcastable to the shape of the values returned from get_values. A mask value of 1 includes the corresponding loss value; a mask value of 0 excludes it.

  • step – The current step to be used to compute the loss-weight if self.weight is set to a schedule. Otherwise step is ignored.

  • **kwargs – Keyword arguments to be passed on to get_values.

Returns:

An instance of Loss.State (AllReduceMean by default), which keeps track of a single scalar loss value and ensures correct averaging even when masks are used. The final loss value can be computed from this state by calling state.compute(). Optionally, the state can first be reduced (to remove the device dimension after pmap) or merged with other (previous) loss states.
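
To see why a state object is used instead of a plain scalar, consider the following sketch. MeanState is a hypothetical stand-in written for illustration only; the real fields and methods of AllReduceMean are not shown on this page. It assumes that the state stores a running total and a running count, so that merged states still yield the correct overall mean.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MeanState:
    """Hypothetical mean-tracking state (not AllReduceMean itself)."""

    total: float = 0.0  # sum of mask-weighted loss values
    count: float = 0.0  # sum of mask entries

    @classmethod
    def from_values(cls, values, mask):
        total = sum(v * m for v, m in zip(values, mask))
        return cls(total=total, count=float(sum(mask)))

    def merge(self, other):
        # Totals and counts add up, so no information about batch size is lost.
        return MeanState(self.total + other.total, self.count + other.count)

    def compute(self):
        return self.total / self.count


a = MeanState.from_values([1.0, 3.0], [1.0, 1.0])       # 2 kept values, mean 2.0
b = MeanState.from_values([8.0, 5.0, 5.0], [1.0, 0.0, 0.0])  # 1 kept value, mean 8.0

merged = a.merge(b)
print(merged.compute())                 # 4.0  (12.0 / 3 kept values)
print((a.compute() + b.compute()) / 2)  # 5.0  (naive mean of means, biased)
```

Averaging the per-batch means weights each batch equally regardless of how many values survived the mask, whereas merging totals and counts weights each kept value equally.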

get_state_from_context(context: Any) -> kauldron.losses.base.AllReduceMean

Compute the loss-state by auto-filling args from given context.

This is a wrapper around get_state that gathers the required arguments from the given context, using the kontext.Keys of the loss. For example, if the loss has a target: kontext.Key set to "batch.label", then context.batch["label"] will be passed to the get_state function of the loss.

Parameters:

context – A context object that holds the information (e.g. the current batch and the model outputs) against which the kontext.Keys of the loss are resolved.

Returns:

An instance of Loss.State. See get_state for details.
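
The key-resolution idea can be sketched in a few lines of plain Python. The helper resolve_path and the context layout below are assumptions made for illustration; the actual kontext path handling may support more path forms and is not reproduced here.

```python
from types import SimpleNamespace


def resolve_path(context, path: str):
    """Hypothetical helper: follow a dotted path through attributes or dict keys."""
    obj = context
    for part in path.split("."):
        if isinstance(obj, dict):
            obj = obj[part]           # e.g. batch["label"]
        else:
            obj = getattr(obj, part)  # e.g. context.batch
    return obj


# A toy context holding a batch and model predictions.
ctx = SimpleNamespace(
    batch={"label": [0, 1, 1]},
    preds={"logits": [[2.0, 0.1], [0.3, 1.5], [0.2, 0.9]]},
)

# Argument name -> path, mirroring how a loss declares its keys.
keys = {"target": "batch.label", "logits": "preds.logits"}

# Gather keyword arguments the way a context-based wrapper might.
kwargs = {name: resolve_path(ctx, path) for name, path in keys.items()}
print(kwargs["target"])  # [0, 1, 1]
```

The resulting keyword arguments correspond to what would then be forwarded to get_state.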

get_weight(step: int | None = None) -> jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']

Return the weight of this loss at the given step number.

Parameters:

step – If the weight of the loss is a schedule, this is the step used for computing the weight. Otherwise it is unused and may be omitted.

Returns:

The weight of this loss term for the total loss (for the given step).

empty() -> kauldron.metrics.base.Metric.State