kd.train.Context
- class kauldron.train.Context(
      step: int,
      batch: Any,
      params: Any = None,
      collections: Any = None,
      preds: Any = None,
      interms: Any = None,
      loss_states: Any = None,
      loss_total: Any = None,
      grads: Any = None,
      updates: Any = None,
      opt_state: Any = None,
      metric_states: Any = None,
      summary_states: Any = None,
  )
Bases: object

Namespace for retrieving information with path-based keys.
The context is progressively filled during the training/eval step.
For example:

    # The initial context contains the params, batch, ...
    ctx = kd.train.Context.from_state_and_batch(state=state, batch=batch)
    # The forward pass adds preds, interms, loss_states, ...
    loss, ctx = model_with_aux.forward(ctx, ...)
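To illustrate what "path-based keys" means, here is a minimal stdlib-only sketch. It assumes that a key such as "batch.label" names a context field ("batch") followed by nested dict entries ("label"). MiniContext and get_by_path are hypothetical stand-ins, not part of the Kauldron API, and the keys used are illustrative only.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class MiniContext:
    """Hypothetical stand-in with a subset of the Context fields."""

    step: int
    batch: Any
    preds: Any = None


def get_by_path(ctx: Any, path: str) -> Any:
    """Hypothetical helper: resolves a dotted key such as "batch.label".

    The first component is read as an attribute of the context, and the
    remaining components index into nested dicts.
    """
    head, *rest = path.split(".")
    value = getattr(ctx, head)
    for part in rest:
        value = value[part]
    return value


ctx = MiniContext(step=0, batch={"image": [1.0, 2.0], "label": 3})
# Simulate the forward pass adding predictions to the context.
ctx = dataclasses.replace(ctx, preds={"logits": [0.1, 0.9]})

print(get_by_path(ctx, "batch.label"))   # 3
print(get_by_path(ctx, "preds.logits"))  # [0.1, 0.9]
```

Resolving keys lazily like this lets consumers such as losses or metrics declare which entries they read as plain strings, while the context only needs to hold whatever has been computed so far.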
- step
The global step number, used e.g. for evaluating schedules.
- Type: int
- batch
The input batch as returned from the data iterator.
- Type: Any
- params
The parameters of the model (available after initialization).
- Type: Any
- collections
Other variable collections (such as batch norm statistics).
- Type: Any
- preds
The output of the model (available after the model has been applied, e.g. for losses and metrics).
- Type: Any
- interms
The intermediate outputs of the model, as returned by model.apply(…, capture_intermediates=True) (available after the model has been applied, e.g. for losses and metrics).
- Type: Any
- loss_states
The states of all losses, as returned by kd.losses.compute_losses (available after the forward pass).
- Type: Any
- loss_total
The total loss value for the current step (float).
- Type: Any
- grads
The gradients of the total loss (loss_total) with respect to params (available after the backward pass, e.g. for metrics).
- Type: Any
- updates
The transformed gradients as returned by the optimizer (available after the backward pass, e.g. for metrics).
- Type: Any
- opt_state
The state of the optimizer before the update (available after the backward pass, e.g. for metrics). The pre-update state is stored so that it is consistent with params, which also hold the pre-update values.
- Type: Any
- metric_states
The states of the metrics (available after the backward pass).
- Type: Any
- summary_states
The states of the summaries (available after the backward pass).
- Type: Any
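The fields above become available in stages: params and batch at the start, preds and loss_total after the forward pass, and grads, updates and opt_state after the backward pass. The sketch below walks through one step of a toy model to show that ordering, and in particular why opt_state holds the pre-update optimizer state alongside the pre-update params. Everything here is hypothetical: MiniContext stands in for Context, the model is a single scalar weight, and the hand-written momentum optimizer (with assumed LR and MOMENTUM values) replaces whatever optimizer a real setup would use.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class MiniContext:
    """Hypothetical stand-in mirroring a subset of the Context fields."""

    step: int
    batch: Any
    params: Any = None
    preds: Any = None
    loss_total: Any = None
    grads: Any = None
    updates: Any = None
    opt_state: Any = None


LR, MOMENTUM = 0.1, 0.9  # assumed hyperparameters for the toy optimizer


def train_step(step, params, opt_state, batch):
    """Toy step: the model is preds = w * x, the loss is (preds - y) ** 2."""
    ctx = MiniContext(step=step, batch=batch, params=params)

    # Forward pass: fills preds and loss_total.
    x, y = batch["x"], batch["y"]
    preds = params["w"] * x
    loss = (preds - y) ** 2
    ctx = dataclasses.replace(ctx, preds=preds, loss_total=loss)

    # Backward pass: d(loss)/dw = 2 * (preds - y) * x.
    grads = {"w": 2.0 * (preds - y) * x}

    # Momentum optimizer: computes updates and a new velocity.
    new_velocity = MOMENTUM * opt_state["velocity"] + grads["w"]
    updates = {"w": -LR * new_velocity}
    ctx = dataclasses.replace(
        ctx,
        grads=grads,
        updates=updates,
        opt_state=opt_state,  # pre-update state, matching the pre-update params
    )

    new_params = {"w": params["w"] + updates["w"]}
    new_opt_state = {"velocity": new_velocity}
    return new_params, new_opt_state, ctx


params, opt_state = {"w": 0.0}, {"velocity": 0.0}
new_params, new_opt_state, ctx = train_step(0, params, opt_state, {"x": 1.0, "y": 2.0})
print(ctx.params, ctx.opt_state)  # {'w': 0.0} {'velocity': 0.0}
print(new_opt_state)              # {'velocity': -4.0}
```

Storing the old optimizer state means that any metric reading both params and opt_state from the context sees a matching pair, rather than old parameters next to an already-advanced optimizer state.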
- classmethod from_state_and_batch(*, state: train_step.TrainState, batch: Any)
Creates the initial context from a train state and a batch.
- flatten() → dict[str, Any]
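The signature only states that flatten() returns a dict[str, Any]. Assuming it maps the same dotted path keys used to address the context (for example "batch.label") to their values, a stdlib-only sketch of that kind of flattening could look like this. MiniContext and flatten_paths are hypothetical; whether the real method keeps None fields or uses "." as the separator is not specified here.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class MiniContext:
    """Hypothetical stand-in with a subset of the Context fields."""

    step: int
    batch: Any
    preds: Any = None


def flatten_paths(tree: Any, prefix: str = "") -> dict[str, Any]:
    """Hypothetical: maps every leaf of nested dataclasses/dicts to a dotted key."""
    if dataclasses.is_dataclass(tree) and not isinstance(tree, type):
        children = {f.name: getattr(tree, f.name) for f in dataclasses.fields(tree)}
    elif isinstance(tree, dict):
        children = tree
    else:
        # A leaf: record it under the path accumulated so far.
        return {prefix: tree}
    flat: dict[str, Any] = {}
    for key, child in children.items():
        path = f"{prefix}.{key}" if prefix else str(key)
        flat.update(flatten_paths(child, path))
    return flat


ctx = MiniContext(step=3, batch={"image": "img", "label": 1})
print(flatten_paths(ctx))
# {'step': 3, 'batch.image': 'img', 'batch.label': 1, 'preds': None}
```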
- get_aux_state(*, return_losses: bool = False, return_metrics: bool = False, return_summaries: bool = False)
Returns the auxiliaries for the step.
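The return type of get_aux_state is not documented here. Purely to illustrate how the keyword-only flags select which groups of auxiliary states to return, the following sketch uses a hypothetical stand-in that returns a plain dict containing only the requested groups. The real method may return a different container.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class MiniContext:
    """Hypothetical stand-in holding only the auxiliary-state fields."""

    loss_states: Any = None
    metric_states: Any = None
    summary_states: Any = None

    def get_aux_state(
        self,
        *,
        return_losses: bool = False,
        return_metrics: bool = False,
        return_summaries: bool = False,
    ) -> dict[str, Any]:
        """Hypothetical: collects only the requested groups of states."""
        aux: dict[str, Any] = {}
        if return_losses:
            aux["loss_states"] = self.loss_states
        if return_metrics:
            aux["metric_states"] = self.metric_states
        if return_summaries:
            aux["summary_states"] = self.summary_states
        return aux


ctx = MiniContext(loss_states={"l2": 0.5}, metric_states={"accuracy": 0.9})
print(ctx.get_aux_state(return_metrics=True))  # {'metric_states': {'accuracy': 0.9}}
print(ctx.get_aux_state())                     # {}
```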
- replace(**updates)
Returns a new object replacing the specified fields with new values.
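replace(**updates) returns a new object instead of modifying the existing one. Assuming it behaves like dataclasses.replace from the standard library, the sketch below shows the consequence with a hypothetical MiniContext stand-in: the original stays unchanged, and fields that are not replaced are shared by reference rather than copied.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class MiniContext:
    """Hypothetical stand-in; the real Context has more fields."""

    step: int
    batch: Any
    preds: Any = None


ctx = MiniContext(step=0, batch={"x": 1.0})
new_ctx = dataclasses.replace(ctx, preds={"y": 2.0})

print(ctx.preds)                   # None: the original object is left untouched
print(new_ctx.preds)               # {'y': 2.0}
print(new_ctx.batch is ctx.batch)  # True: unchanged fields are shared, not copied
```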