kd.train.forward
kauldron.train.forward(
    context: kauldron.train.context.Context,
    *,
    model: flax.linen.module.Module,
    rngs: dict[str, kauldron.random.random.PRNGKey],
    is_training: bool,
    method: str | None = None,
)
Forward pass of the model.
- Parameters:
context – Context to use for the forward pass. Should contain params, batch, step, and collections (and optionally opt_state).
model – Model to use for the forward pass.
rngs – Random keys for the forward pass, as a dict mapping each RNG stream name to a PRNG key.
is_training – Whether to run the model in training or eval mode.
method – Name of the flax model method to call (defaults to __call__).
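The method parameter picks which model method runs, by name, falling back to __call__ when it is None. The following is a minimal stdlib sketch of that dispatch, assuming a plain Python object in place of a flax module. ToyModel and call_model are hypothetical names for illustration only, not part of kauldron; a real implementation would more likely hand the method name to flax's own apply machinery.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class ToyModel:
    """Hypothetical stand-in for a flax module with two entry points."""

    scale: float = 2.0

    def __call__(self, x: float, *, is_training: bool) -> float:
        # Mimic train/eval differences (e.g. dropout): only shift outputs in training.
        return self.scale * x + (0.5 if is_training else 0.0)

    def encode(self, x: float, *, is_training: bool) -> float:
        # A second, non-default entry point selectable via `method="encode"`.
        return x - 1.0


def call_model(
    model: Any,
    x: float,
    *,
    is_training: bool,
    method: str | None = None,
) -> float:
    """Hypothetical helper: resolve `method` by name, defaulting to `__call__`."""
    fn = getattr(model, method if method is not None else "__call__")
    return fn(x, is_training=is_training)
```

For example, call_model(ToyModel(), 3.0, is_training=False) runs __call__ and gives 6.0, while passing method="encode" runs the alternative entry point and gives 2.0.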
- Returns:
Context with the updated preds, interms, and collections.
- Return type:
context
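The return contract (a context carrying the new preds, interms, and collections, with everything else unchanged) can be sketched with an immutable dataclass. ToyContext and toy_forward are hypothetical simplifications, not kauldron's Context class or its actual forward implementation, and the "model" here is a single multiply.

```python
from __future__ import annotations

import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class ToyContext:
    """Hypothetical, simplified stand-in for kauldron.train.context.Context."""

    step: int
    params: dict[str, float]
    batch: dict[str, float]
    collections: dict[str, Any] = dataclasses.field(default_factory=dict)
    opt_state: Any = None
    preds: Any = None
    interms: Any = None


def toy_forward(context: ToyContext, *, is_training: bool) -> ToyContext:
    """Hypothetical sketch: read params and batch from the context, then return
    a NEW context with preds, interms, and collections filled in."""
    w = context.params["w"]
    x = context.batch["x"]
    hidden = w * x  # intermediate value, recorded in `interms`
    preds = hidden + (1.0 if is_training else 0.0)
    # Update a mutable-state collection (here: a simple call counter).
    calls = context.collections.get("calls", 0) + 1
    return dataclasses.replace(
        context,
        preds=preds,
        interms={"hidden": hidden},
        collections={**context.collections, "calls": calls},
    )
```

Returning a fresh context instead of mutating the input keeps the function pure, which suits JAX-style training loops: the caller's original context (with preds still None) stays intact.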