kd.testing.assert_step_specs
- kauldron.testing.assert_step_specs(trainer: kauldron.train.trainer_lib.Trainer)
Check that the train step runs correctly (fast).
This function runs a single trainer.trainstep.step. It uses jax.eval_shape, so no computation is actually executed (only shapes are checked).
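To illustrate why a check based on jax.eval_shape is cheap, here is a minimal sketch in plain JAX. It is not Kauldron's implementation: `toy_step`, `params_spec`, and `batch_spec` are hypothetical names introduced only for this example, standing in for a train step and its inputs.

```python
import jax
import jax.numpy as jnp


def toy_step(params, batch):
  """Hypothetical stand-in for a train step: linear layer plus MSE loss."""
  preds = batch["x"] @ params["w"] + params["b"]
  loss = jnp.mean((preds - batch["y"]) ** 2)
  return loss, preds


# Abstract inputs: only shapes and dtypes are described, no arrays are allocated.
params_spec = {
    "w": jax.ShapeDtypeStruct((8, 4), jnp.float32),
    "b": jax.ShapeDtypeStruct((4,), jnp.float32),
}
batch_spec = {
    "x": jax.ShapeDtypeStruct((32, 8), jnp.float32),
    "y": jax.ShapeDtypeStruct((32, 4), jnp.float32),
}

# Traces `toy_step` abstractly and returns the output shapes/dtypes
# without running any of the arithmetic.
loss_spec, preds_spec = jax.eval_shape(toy_step, params_spec, batch_spec)
print(loss_spec.shape, preds_spec.shape)
```

Because the function is only traced, a shape mismatch between the batch and the parameters surfaces as an error during tracing, which is the kind of mistake a shape-only check is meant to catch early.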
This requires trainer.train_ds.element_spec to be available, so you'll likely need to mock the dataset. For example, if using a TFDS dataset, you can use tfds.testing.mock_data:
cfg = my_config.get_config()
...  # Optionally mutate the `cfg`

with tfds.testing.mock_data():
  kd.testing.assert_step_specs(trainer)
- Parameters:
trainer – The trainer to test.