kd.train.forward_with_loss


kauldron.train.forward_with_loss(
context: kauldron.train.context.Context,
*,
model: flax.linen.module.Module,
losses: Mapping[str, kauldron.losses.base.Loss],
rngs: dict[str, kauldron.random.random.PRNGKey],
is_training: bool,
) → tuple[float, kauldron.train.context.Context]

Forward pass of the model, including losses.

Parameters:
  • context – Context to use for the forward pass. Should contain params, batch, step, and collections (and optionally opt_state).

  • model – Model to use for the forward pass.

  • losses – Losses to compute.

  • rngs – Random number generator keys (PRNGKey) to use for the forward pass, keyed by name.

  • is_training – Whether to run the model in training or eval mode.

Returns:
  • loss_total – Total loss.

  • context – Context with the updated loss_total, loss_states, preds, interms, and collections.

Return type:
  tuple[float, kauldron.train.context.Context]
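
The input/output contract described above (a context goes in, and the total loss plus an updated context come out) can be illustrated with a small, library-free sketch. This is not Kauldron's implementation: `ToyContext`, `toy_forward_with_loss`, `linear_model`, and `mse` are hypothetical stand-ins, the model and losses are plain Python callables rather than flax modules and `kauldron.losses` objects, summing the per-loss values is an assumption about how the total is formed, and `rngs` is accepted but ignored.

```python
import dataclasses
from typing import Any, Callable, Mapping, Optional


@dataclasses.dataclass(frozen=True)
class ToyContext:
    """Hypothetical stand-in for Context; field names follow the docs above."""
    step: int
    batch: Mapping[str, Any]
    params: Mapping[str, float]
    collections: Mapping[str, Any] = dataclasses.field(default_factory=dict)
    preds: Any = None
    interms: Any = None
    loss_states: Optional[Mapping[str, float]] = None
    loss_total: Optional[float] = None


def toy_forward_with_loss(
    context: ToyContext,
    *,
    model: Callable[..., Any],
    losses: Mapping[str, Callable[..., float]],
    rngs: Mapping[str, Any],
    is_training: bool,
) -> tuple:
    """Runs the toy model, evaluates every loss, returns (total, new context)."""
    del rngs  # Unused: the toy model is deterministic.
    preds = model(context.params, context.batch, is_training)
    # One scalar per named loss (assumption: the total is their plain sum).
    loss_states = {name: fn(preds, context.batch) for name, fn in losses.items()}
    loss_total = sum(loss_states.values())
    # The input context is left untouched; an updated copy is returned.
    new_context = dataclasses.replace(
        context, preds=preds, loss_states=loss_states, loss_total=loss_total
    )
    return loss_total, new_context


def linear_model(params, batch, is_training):
    """Toy model: y = w * x + b, identical in training and eval mode."""
    return [params["w"] * x + params["b"] for x in batch["x"]]


def mse(preds, batch):
    """Mean squared error between predictions and batch targets."""
    return sum((p - y) ** 2 for p, y in zip(preds, batch["y"])) / len(preds)


ctx = ToyContext(
    step=0,
    batch={"x": [1.0, 2.0], "y": [2.0, 4.0]},
    params={"w": 1.0, "b": 0.0},
)
loss_total, new_ctx = toy_forward_with_loss(
    ctx, model=linear_model, losses={"mse": mse}, rngs={}, is_training=True
)
```

Returning the scalar loss first and the context second is the shape expected by gradient transforms that take an auxiliary output, such as `jax.value_and_grad(..., has_aux=True)`. Whether Kauldron uses the function that way is not stated on this page.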