kd.testing.assert_step_specs


kauldron.testing.assert_step_specs(trainer: kauldron.train.trainer_lib.Trainer) -> None

Checks that the train step runs correctly (fast).

This function runs a single `trainer.trainstep.step`. It uses `jax.eval_shape`, so no computation is actually executed: only the shapes (and dtypes) of the inputs and outputs are checked.
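To illustrate the `jax.eval_shape` mechanism described above, here is a minimal pure-JAX sketch. It is not Kauldron's implementation: `train_step`, the linear model, and the hand-written specs are hypothetical stand-ins for `trainer.trainstep.step` and `trainer.train_ds.element_spec`. It shows that tracing with abstract `jax.ShapeDtypeStruct` inputs yields output shapes, and that a shape mismatch is reported without running the step.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy train step (not Kauldron's API): one SGD update of a
# linear model under a mean-squared-error loss.
def train_step(params, batch):
  def loss_fn(p):
    preds = batch["x"] @ p["w"] + p["b"]
    return jnp.mean((preds - batch["y"]) ** 2)

  loss, grads = jax.value_and_grad(loss_fn)(params)
  new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  return new_params, loss


# Abstract inputs: only shapes and dtypes, no data. In Kauldron the batch
# spec would be derived from the dataset's element_spec; here it is
# written by hand (assumption for illustration).
params_spec = {
    "w": jax.ShapeDtypeStruct((3, 1), jnp.float32),
    "b": jax.ShapeDtypeStruct((1,), jnp.float32),
}
batch_spec = {
    "x": jax.ShapeDtypeStruct((8, 3), jnp.float32),
    "y": jax.ShapeDtypeStruct((8, 1), jnp.float32),
}

# Traces the step without executing it; outputs are ShapeDtypeStructs.
out_params, out_loss = jax.eval_shape(train_step, params_spec, batch_spec)
print(out_params["w"].shape, out_loss.shape)  # (3, 1) ()

# A batch with 4 features instead of 3 fails during tracing.
bad_batch_spec = {
    "x": jax.ShapeDtypeStruct((8, 4), jnp.float32),
    "y": jax.ShapeDtypeStruct((8, 1), jnp.float32),
}
try:
  jax.eval_shape(train_step, params_spec, bad_batch_spec)
  shape_error = None
except (TypeError, ValueError) as e:
  shape_error = e
```

Because `eval_shape` only traces the function, a bug such as the mismatched feature dimension surfaces as an exception without materializing any arrays or using an accelerator.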

This requires `trainer.train_ds.element_spec` to be available, so you'll likely need to mock the dataset. For example, if using a TFDS dataset, you can use `tfds.testing.mock_data`:

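import tensorflow_datasets as tfds
from kauldron import kd
from kauldron import konfig

cfg = my_config.get_config()  # `my_config` is your own config module

...  # Optionally mutate the `cfg`

with tfds.testing.mock_data():
  trainer = konfig.resolve(cfg)  # Build the Trainer from the config
  kd.testing.assert_step_specs(trainer)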

Parameters:

trainer – The trainer to test.