Train, eval, randomness#
Evaluation#
Use eval#
Eval can be defined on the evals attribute of kd.train.Trainer:
trainer = kd.train.Trainer(
evals={
'eval': kd.evals.Evaluator(
run=kd.evals.EveryNSteps(100),
num_batches=None,
ds=_make_ds(training=False),
metrics={},
)
}
)
If kd.evals.Evaluator does not define losses, metrics, summaries, those are
reused from train.
Where/how the kd.evals.Evaluator is run can be specified through the run=
kwargs. Evaluators can run:
Within the
trainjob:EveryNSteps: Run evaluation everyXstepsOnce: Run a single evaluation afterXsteps
Evaluators run in a standalone job can be grouped together through the
job_group='group_name' attribute. This allow to save resources by sharing the
same job for multiple evaluators.
The StandaloneXxx supports all kxm.Job parameters, if you need to run
evaluator on a different platform,…
See mnist_standalone_eval.py for an example.
Start an eval-only job#
Sometimes, you only want to run evaluation on a trainer from a previous
Kauldron experiment.
This can be achieved through kd.train.Trainer.eval_only():
def config():
return kd.train.Trainer.eval_only(
evals = {
'my_eval': kd.evals.Evaluator(
run=kd.evals.StandaloneLastCheckpoint(),
...,
),
}
)
See mnist_eval_only.py for an example.
Note: kd.train.Trainer.eval_only() only works when used inside konfig.
Train / eval in Module#
Model can detect if they are in training / eval mode by using the
kd.nn.train_property.
class MyModel(nn.Module):
# Create a `@property` that will look-up the global `is_training` value
# when called
is_training = kd.nn.train_property() # No annotations here !!
@nn.compact
def __call__(self, x):
# Inside the methods, `self.is_training` can be called
if self.is_training:
rng = self.make_rng('default')
x = jax.random.choice(rng, x)
# `kd.nn.Dropout` supports `is_training` by default (no need to
# propagate `deterministic=`)
x = kd.nn.Dropout(0.5)(x)
return x
The self.is_training value is set globally in model.apply / model.init for
all submodules. No more deterministic kwargs to propagate through your modules
!!
model = MyModel()
params = model.init(..., is_training_property=True)
y = model.apply(..., is_training_property=False) # Eval
Inside a module, you can overwrite the is_training value with the
kd.nn.set_train_property contextmanager:
class MyModule(nn.Module):
@nn.compact
def __call__(self, x):
with kd.nn.set_train_property(False):
x = self.pretrained_encoder(x)
kd.nn.set_train_property can also be used to call a Kauldron model inside a
non-Kauldron model (to propagate the train / deterministic kwarg to the
model).
Training#
Create the trainer#
The root trainer object is kd.train.Trainer which defines the model, datasets,
metrics, losses,…
See mnist_autoencoder.py for an example.
High level API#
The Config can be run by calling the .train() method. It will take care of
everything (checkpoint, eval, summaries,…).
trainer.train()
Mid level API#
If you only need to run the training loop:
state = trainer.init_state()
for batch in trainer.train_ds.device_put(trainer.sharding.ds):
state, aux = trainer.trainstep.step(state, batch)
The .device_put() is chained with the dataset to put examples on devices (
default to kd.sharding.FIRST_DIM).
Randomness#
Determinism#
Kauldron uses a global seed (trainer.seed = 42) that is then split into the
various sub-components (dataset, model,…). For more control, the seed can also
be explicitly set inside the submodules (e.g. trainer.train_ds.seed = 42)
Rng streams#
By default, the following rng streams are created:
params: Only during.init()dropout: Fornn.Dropout, only available in training (not eval).default: Defaultrngstream, only available in training (not eval).
If you need custom streams, or need to overwrite the default values. You can set
the rng_streams attribute of kd.train.Trainer to kd.train.RngStreams. Note
that the kd.train.RngStreams will be merged with the default streams (so
you don’t need to re-specify params,…):
cfg = kd.train.Trainer()
cfg.rng_streams = kd.train.RngStreams([
# Overwrite `dropout` stream to only be activated in `eval`
kd.train.RngStream('dropout', train=False, eval=True),
# Add a custom stream (by default only on `train`)
kd.train.RngStream('my_custom_stream'),
])
To get the {'dropout': rng, ...} values, call the rng_streams.train_rngs(),
.eval_rngs() or .init_rngs().
params = model.init(rng_streams.init_rngs(), ...)
@jax.jit
def forward(step, params, batch):
rngs = rng_streams.train_rngs(step) # Create the rng for current `step`
return model.apply(params, batch, rngs=rngs)
Use cases#
GAN & Multi optimizers#
Training on multi optimizer can be done using kd.contrib.train.multi_optimizer
and trainstep=kd.contrib.train.MultiTrainStep().
trainer = kd.train.Trainer(
...,
model=MyGan(
generator=MyGenerator(),
discriminator=MyDiscriminator(),
),
# Define the loss for the generator and the discriminator
losses={
'discriminator': kd.losses.L2(...),
'generator': kd.losses.L2(...),
},
optimizer=kd.contrib.train.multi_optimizer(
# Using `kd.optim.partial_update`, you can mask out which weights
# each of the optimizer will be applied too.
discriminator=kd.optim.partial_update(
optimizer=optax.adam(1e-4),
mask=kd.optim.select('discriminator'),
),
generator=kd.optim.partial_update(
optimizer=optax.adam(1e-4),
mask=kd.optim.select('generator'),
),
),
# Using `kd.contrib.train.multi_optimizer` require to use `MultiTrainStep`
trainstep=kd.contrib.train.MultiTrainStep(),
)