kd.train

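Training components: the trainer, train step and train state, wrappers for losses, metrics and summaries, RNG stream management, and summary writing.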

Symbols

Classes

kd.train.Auxiliaries

Wrapper around the losses, summaries and metrics.

kd.train.AuxiliariesOutput

Final values of the auxiliaries (after merging and computing).

kd.train.AuxiliariesState

Intermediate states of the auxiliaries, to be accumulated.
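The split between intermediate states and final outputs follows an accumulate, merge, then compute pattern. The sketch below illustrates that pattern for a running mean in plain Python. `AverageState` and its `from_batch`, `merge` and `compute` methods are hypothetical names chosen for illustration, not Kauldron's API.

```python
from __future__ import annotations

import dataclasses


@dataclasses.dataclass(frozen=True)
class AverageState:
    """Hypothetical accumulable state for a running mean (not Kauldron's API)."""

    total: float = 0.0
    count: int = 0

    @classmethod
    def from_batch(cls, values: list[float]) -> AverageState:
        # Build the intermediate state for a single batch.
        return cls(total=float(sum(values)), count=len(values))

    def merge(self, other: AverageState) -> AverageState:
        # Combine two intermediate states without computing the final value yet.
        return AverageState(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        # Turn the accumulated state into the final value.
        return self.total / self.count if self.count else 0.0


batches = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
state = AverageState()
for batch in batches:
    state = state.merge(AverageState.from_batch(batch))
final_value = state.compute()  # 21 / 6 = 3.5
```

Keeping the state separate from the final value lets partial results from different batches (or devices) be merged cheaply, with the possibly expensive `compute` step run only once at the end.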

kd.train.Context

Namespace for retrieving information with path-based keys.
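Path-based keys address nested values with a dotted string. The following is a minimal sketch of how such a lookup can be resolved, assuming keys of the form `"batch.label"`; `get_by_path` and the example context are hypothetical and do not reproduce `kd.train.Context`'s actual fields or methods.

```python
from collections.abc import Mapping
from typing import Any


def get_by_path(root: Any, path: str) -> Any:
    """Hypothetical helper: resolve a dotted key such as 'batch.label'."""
    node = root
    for part in path.split("."):
        if isinstance(node, Mapping):
            node = node[part]  # Dict-like levels are indexed by key.
        else:
            node = getattr(node, part)  # Other levels are read as attributes.
    return node


context = {
    "batch": {"image": [[0, 1], [2, 3]], "label": 1},
    "preds": {"logits": [0.2, 0.8]},
}
label = get_by_path(context, "batch.label")
logits = get_by_path(context, "preds.logits")
```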

kd.train.KDMetricWriter

Writes summaries to logs, tf_summaries and datatables.

kd.train.RngStream

Information about a single RNG stream.

kd.train.RngStreams

Manager of RNG streams.
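A common reason to manage named RNG streams is reproducibility: each stream (for example a dropout stream) should get its own deterministic randomness per step, independent of the other streams. The sketch below shows that idea in plain Python by hashing the base seed, stream name and step; in JAX code the same role is usually played by `jax.random.fold_in` on a `PRNGKey`. `StreamInfo` and `StreamManager` are hypothetical and are not the fields or methods of `kd.train.RngStream` / `kd.train.RngStreams`.

```python
from __future__ import annotations

import dataclasses
import hashlib
import random


@dataclasses.dataclass(frozen=True)
class StreamInfo:
    """Hypothetical per-stream options (not Kauldron's RngStream)."""

    name: str
    train: bool = True  # Whether the stream is used during training.


class StreamManager:
    """Hypothetical manager deriving reproducible seeds per named stream."""

    def __init__(self, seed: int, streams: list[StreamInfo]):
        self.seed = seed
        self.streams = {s.name: s for s in streams}

    def training_streams(self) -> list[str]:
        # Names of the streams flagged for use during training.
        return [name for name, s in self.streams.items() if s.train]

    def seed_for(self, name: str, step: int) -> int:
        # Hash (base seed, stream name, step) so each stream/step pair gets
        # its own deterministic seed, independent of the other streams.
        if name not in self.streams:
            raise KeyError(f"Unknown rng stream: {name!r}")
        digest = hashlib.sha256(f"{self.seed}/{name}/{step}".encode()).digest()
        return int.from_bytes(digest[:8], "little")

    def rng(self, name: str, step: int) -> random.Random:
        return random.Random(self.seed_for(name, step))


manager = StreamManager(
    seed=0, streams=[StreamInfo("dropout"), StreamInfo("params", train=False)]
)
a = manager.rng("dropout", step=10).random()
b = manager.rng("dropout", step=10).random()  # Same stream and step: same value.
c = manager.rng("dropout", step=11).random()  # Next step: a fresh value.
```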

kd.train.Setup

Setup/environment options.

kd.train.TqdmInfo

TqdmInfo(*, desc: 'str' = 'train', log_xm: 'bool' = True)

kd.train.TrainState

Data structure for checkpointing the model.
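A checkpointable training state is typically an immutable record of everything needed to resume training. The sketch below illustrates that idea with a hypothetical `MiniTrainState` whose `step`, `params` and `opt_state` fields are assumptions for illustration, using JSON as a stand-in for a real checkpoint format. It is not the definition of `kd.train.TrainState`.

```python
from __future__ import annotations

import dataclasses
import json


@dataclasses.dataclass(frozen=True)
class MiniTrainState:
    """Hypothetical checkpointable state (fields are illustrative only)."""

    step: int
    params: dict[str, float]
    opt_state: dict[str, float]

    def next(self, params: dict[str, float], opt_state: dict[str, float]) -> MiniTrainState:
        # Return a new state instead of mutating, so saved snapshots stay valid.
        return dataclasses.replace(
            self, step=self.step + 1, params=params, opt_state=opt_state
        )

    def to_json(self) -> str:
        # Serialize all fields (a stand-in for writing a checkpoint).
        return json.dumps(dataclasses.asdict(self), sort_keys=True)

    @classmethod
    def from_json(cls, text: str) -> MiniTrainState:
        # Rebuild the state from its serialized form (restoring a checkpoint).
        return cls(**json.loads(text))


state = MiniTrainState(step=0, params={"w": 0.0}, opt_state={"momentum_w": 0.0})
state = state.next({"w": 0.1}, {"momentum_w": 0.01})
restored = MiniTrainState.from_json(state.to_json())
```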

kd.train.TrainStep

Base training step.

kd.train.Trainer

Base trainer class.

Functions

kd.train.forward

Forward pass of the model.

kd.train.forward_with_loss

Forward pass of the model, including losses.
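The difference between a plain forward pass and a forward pass including losses can be sketched as follows: the second runs the model, then evaluates each configured loss on the predictions and combines them into a weighted total. Everything below (the linear model, `mse`, `mae`, and the `(fn, weight)` loss mapping) is a hypothetical pure-Python stand-in and does not reproduce the signatures of `kd.train.forward` or `kd.train.forward_with_loss`.

```python
from __future__ import annotations


def forward(params: dict[str, float], batch: dict[str, list[float]]) -> list[float]:
    """Hypothetical model: y = w * x + b applied element-wise."""
    return [params["w"] * x + params["b"] for x in batch["x"]]


def mse(preds: list[float], targets: list[float]) -> float:
    # Mean squared error.
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def mae(preds: list[float], targets: list[float]) -> float:
    # Mean absolute error.
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


def forward_with_loss(params, batch, losses):
    """Hypothetical: run the forward pass, then combine weighted losses."""
    preds = forward(params, batch)
    values = {name: fn(preds, batch["y"]) for name, (fn, _) in losses.items()}
    total = sum(weight * values[name] for name, (_, weight) in losses.items())
    return total, preds, values


params = {"w": 2.0, "b": 0.0}
batch = {"x": [1.0, 2.0], "y": [2.0, 5.0]}
losses = {"mse": (mse, 1.0), "mae": (mae, 0.5)}
total, preds, values = forward_with_loss(params, batch, losses)
```

Here the predictions are `[2.0, 4.0]`, both losses evaluate to `0.5`, and the weighted total is `0.75`. In a JAX training step, a function returning the total loss in this way is what would be passed to `jax.grad` (or `jax.value_and_grad`) to obtain parameter gradients.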