kd.train.forward

kauldron.train.forward(
    context: kauldron.train.context.Context,
    *,
    model: flax.linen.module.Module,
    rngs: dict[str, kauldron.random.random.PRNGKey],
    is_training: bool,
    method: str | None = None,
) -> kauldron.train.context.Context

Forward pass of the model.

Parameters:
  • context – Context to use for the forward pass. Should contain params, batch, step, and collections (and optionally opt_state).

  • model – Model to use for the forward pass.

  • rngs – Random keys for the forward pass, given as a mapping from RNG stream name (for example, dropout) to PRNG key.

  • is_training – Whether to run the model in training or eval mode.

  • method – Name of the flax model method to call (defaults to __call__).

Returns:

Context with the updated preds, interms, and collections.

Return type:

kauldron.train.context.Context