kd.train.forward_with_loss
kauldron.train.forward_with_loss(
    context: kauldron.train.context.Context,
    *,
    model: flax.linen.module.Module,
    losses: Mapping[str, kauldron.losses.base.Loss],
    rngs: dict[str, kauldron.random.random.PRNGKey],
    is_training: bool,
)
Forward pass of the model, including losses.
- Parameters:
context – Context to use for the forward pass. Should contain params, batch, step, and collections (and optionally opt_state).
model – Model to use for the forward pass.
losses – Losses to compute, keyed by name.
rngs – Random number generator keys for the forward pass, keyed by RNG stream name.
is_training – Whether to run the model in training or eval mode.
- Returns:
loss_total – Total loss.
context – Context with the updated loss_total, loss_states, preds, interms, and collections.
- Return type:
A pair (loss_total, context).
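The contract described above (run the model on the context's batch, evaluate each named loss, sum them, and hand back the total together with an updated context) can be illustrated with a toy, pure-Python sketch. Everything here is hypothetical: `ToyContext`, `forward_with_loss_sketch`, `linear_model`, `mse`, and `mae` are illustrative stand-ins, not Kauldron's classes or implementation, and loss states are plain floats rather than Kauldron loss-state objects. Returning the scalar first matches the `(value, aux)` pair that `jax.value_and_grad(..., has_aux=True)` expects; that this is the intended use is an assumption.

```python
import dataclasses
from typing import Any, Callable, Mapping, Optional


# Hypothetical stand-in for kauldron.train.context.Context (not the real class).
@dataclasses.dataclass(frozen=True)
class ToyContext:
    step: int
    params: Any
    batch: Mapping[str, Any]
    collections: Mapping[str, Any] = dataclasses.field(default_factory=dict)
    preds: Any = None
    interms: Any = None
    loss_states: Optional[Mapping[str, float]] = None
    loss_total: Optional[float] = None


def forward_with_loss_sketch(
    context: ToyContext,
    *,
    model: Callable[..., Any],
    losses: Mapping[str, Callable[[Any, Mapping[str, Any]], float]],
    rngs: dict,
    is_training: bool,
) -> tuple:
    """Toy illustration only; not Kauldron's implementation."""
    # Run the model on the batch stored in the context.
    preds = model(context.params, context.batch, rngs=rngs, is_training=is_training)
    # Evaluate every named loss on the predictions.
    loss_states = {name: fn(preds, context.batch) for name, fn in losses.items()}
    loss_total = sum(loss_states.values())
    # Return a new (immutable) context with the results; this toy model
    # produces no intermediates, so interms and collections are left unchanged.
    new_context = dataclasses.replace(
        context, preds=preds, loss_states=loss_states, loss_total=loss_total
    )
    return loss_total, new_context


def linear_model(params, batch, *, rngs, is_training):
    del rngs, is_training  # The toy model is deterministic.
    return [params["w"] * x + params["b"] for x in batch["x"]]


def mse(preds, batch):
    return sum((p - y) ** 2 for p, y in zip(preds, batch["y"])) / len(preds)


def mae(preds, batch):
    return sum(abs(p - y) for p, y in zip(preds, batch["y"])) / len(preds)


ctx = ToyContext(step=0, params={"w": 2.0, "b": 1.0},
                 batch={"x": [0.0, 1.0], "y": [1.0, 5.0]})
loss_total, ctx = forward_with_loss_sketch(
    ctx, model=linear_model, losses={"mse": mse, "mae": mae},
    rngs={}, is_training=True,
)
print(loss_total)       # 3.0
print(ctx.loss_states)  # {'mse': 2.0, 'mae': 1.0}
```

Returning a fresh context instead of mutating the input keeps the function pure, which is what JAX transformations such as `jax.jit` and `jax.grad` require of traced code.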