kd.train.Context

class kauldron.train.Context(
step: int,
batch: Any,
params: Any = None,
collections: Any = None,
preds: Any = None,
interms: Any = None,
loss_states: Any = None,
loss_total: Any = None,
grads: Any = None,
updates: Any = None,
opt_state: Any = None,
metric_states: Any = None,
summary_states: Any = None,
)

Bases: object

Namespace for retrieving information with path-based keys.

The context is progressively filled during the training/eval step.

# The initial context contains the params, batch, ...
ctx = kd.train.Context.from_state_and_batch(state=state, batch=batch)

# Add preds, interms, loss_states, ...
loss, ctx = model_with_aux.forward(ctx, ...)
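As an illustration of what a path-based key means here, the sketch below resolves dotted strings such as "batch.label" against a small stand-in object. `MiniContext` and `get_by_path` are hypothetical names introduced for this example. They are not part of Kauldron, and the real key resolution may differ.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class MiniContext:
    """Hypothetical stand-in for a context object (not the Kauldron class)."""
    step: int
    batch: Any
    preds: Any = None


def get_by_path(root: Any, path: str) -> Any:
    """Walks a dotted path: item lookup for dicts, attribute lookup otherwise."""
    node = root
    for part in path.split("."):
        if isinstance(node, dict):
            node = node[part]
        else:
            node = getattr(node, part)
    return node


ctx = MiniContext(step=0, batch={"image": [1.0, 2.0], "label": 3})
label = get_by_path(ctx, "batch.label")  # looks up ctx.batch["label"]
step = get_by_path(ctx, "step")          # looks up ctx.step
```

A string key of this kind lets a loss or metric name its inputs without holding a reference to the object that produces them.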
step

The global step number, used e.g. for evaluating schedules.

Type:

int

batch

The input batch as returned from the data iterator.

Type:

Any

params

The parameters of the model. Available after initialization.

Type:

Any

collections

Other variable collections (such as batch norm statistics).

Type:

Any

preds

The output of the model. Available after the model has been applied, e.g. for losses and metrics.

Type:

Any

interms

The intermediate outputs of the model, as returned by model.apply(…, capture_intermediates=True). Available after the model has been applied, e.g. for losses and metrics.

Type:

Any

loss_states

All the states of the losses, as returned by kd.losses.compute_losses. Available after the forward pass.

Type:

Any

loss_total

The total loss value for that step (float).

Type:

Any

grads

The gradients of loss_values['total'] with respect to params. Available after the backward pass, e.g. for metrics.

Type:

Any

updates

The transformed gradients, as returned by the optimizer. Available after the backward pass, e.g. for metrics.

Type:

Any

opt_state

The state of the optimizer prior to the update. Available after the backward pass, e.g. for metrics. The pre-update state is kept so that it is consistent with the parameters, which are also pre-update.

Type:

Any
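The pre-update convention can be pictured with a toy momentum optimizer written in plain Python. Everything here (`momentum_update`, the scalar parameter, the dict used as a record) is a hypothetical illustration rather than Kauldron or optimizer-library code. The point is only that the recorded `params` and `opt_state` describe the same moment in time, before the update is applied.

```python
def momentum_update(grad, velocity, lr=0.1, beta=0.9):
    """Toy optimizer step: returns (update, new_velocity)."""
    new_velocity = beta * velocity + grad
    update = -lr * new_velocity
    return update, new_velocity


params = 1.0     # parameter value before the step
opt_state = 0.0  # optimizer state (velocity) before the step
grads = 2.0      # gradient of the loss with respect to params

updates, new_opt_state = momentum_update(grads, opt_state)
new_params = params + updates

# What a context-like record would hold: pre-update params and pre-update
# optimizer state, next to the grads and updates computed from them.
record = {
    "params": params,
    "opt_state": opt_state,
    "grads": grads,
    "updates": updates,
}
```

Because `updates` was computed from the old state, a metric that reads `params`, `opt_state`, `grads` and `updates` together sees a consistent snapshot; mixing in `new_opt_state` would pair values from two different steps.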

metric_states

The states of the metrics. Available after the backward pass.

Type:

Any

summary_states

The states of the summaries. Available after the backward pass.

Type:

Any

classmethod from_state_and_batch(
*,
state: train_step.TrainState,
batch: Any,
) -> Self
flatten() -> dict[str, Any]
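No description accompanies `flatten()` here beyond its return type, `dict[str, Any]`. One plausible reading, shown as a sketch only, is a nested structure collapsed into a single dictionary whose keys are dotted paths. The helper `flatten_tree` and the key format are assumptions made for illustration, not the documented behaviour.

```python
from typing import Any


def flatten_tree(tree: Any, prefix: str = "") -> dict[str, Any]:
    """Collapses nested dicts into {dotted_path: leaf}. Hypothetical helper."""
    if not isinstance(tree, dict):
        return {prefix: tree}
    flat: dict[str, Any] = {}
    for key, value in tree.items():
        path = f"{prefix}.{key}" if prefix else str(key)
        flat.update(flatten_tree(value, path))
    return flat


nested = {
    "step": 10,
    "batch": {"image": "img", "label": 3},
    "preds": {"logits": [0.1, 0.9]},
}
flat = flatten_tree(nested)
```

In this sketch, lists and other non-dict values are treated as leaves and kept whole under their path.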
get_aux_state(
*,
return_losses: bool = False,
return_metrics: bool = False,
return_summaries: bool = False,
) -> kauldron.train.auxiliaries.AuxiliariesState

Returns the auxiliaries for the step.

replace(**updates)

Returns a new object replacing the specified fields with new values.
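The behaviour described for `replace` matches the usual pattern for immutable records: the original object is left untouched and a copy with the new field values is returned. The sketch below reproduces that pattern with a frozen dataclass and `dataclasses.replace`. `MiniContext` is a hypothetical stand-in, and whether `Context` is implemented this way is an assumption.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class MiniContext:
    """Hypothetical stand-in with a few of the fields listed above."""
    step: int
    batch: Any
    preds: Any = None
    loss_total: Any = None

    def replace(self, **updates: Any) -> "MiniContext":
        # Returns a copy with the given fields changed; self is not modified.
        return dataclasses.replace(self, **updates)


ctx0 = MiniContext(step=0, batch={"x": 2.0})
ctx1 = ctx0.replace(preds={"y": 4.0})  # after a (pretend) forward pass
ctx2 = ctx1.replace(loss_total=0.5)    # after a (pretend) loss computation
```

Each call yields a new object, so earlier stages such as `ctx0` remain valid snapshots while later ones carry the additional fields.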