kd.train.KDMetricWriter
- class kauldron.train.KDMetricWriter(
- *,
- _fake_refs: type[_FakeRefsUnset] | dict[str, _FakeRootCfg] = <class 'kauldron.utils.config_util._FakeRefsUnset'>,
- workdir: str | epath.Path = _FakeRootCfg('cfg.workdir'),
- collection: str = '$not_set$',
- add_artifacts: bool = True,
Bases: kauldron.train.metric_writer.MetadataWriter
Writes summaries to logs, tf_summaries and datatables.
Differs from the default clu metric writer in a few ways:
- It divides summaries into two datatables, one for scalars and one for arrays, to improve datatable access speed for flatboards.
- It does not write hyperparameters to the datatable, to avoid clutter.
- It does not write to XM-Measurements.
- It offers additional methods to write the config, param_overview and element_spec.
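To make the two-datatable split concrete, here is a minimal sketch of partitioning a summary mapping by rank. The helper `split_scalars_and_arrays` is hypothetical and only illustrates the idea; it is not how Kauldron implements the split internally.

```python
from typing import Any, Mapping

import numpy as np


def split_scalars_and_arrays(
    values: Mapping[str, Any],
) -> tuple[dict[str, float], dict[str, np.ndarray]]:
    """Partitions summaries into a scalar table and an array table.

    Hypothetical helper: illustrates the two-datatable idea only.
    """
    scalars: dict[str, float] = {}
    arrays: dict[str, np.ndarray] = {}
    for name, value in values.items():
        arr = np.asarray(value)
        if arr.ndim == 0:
            # 0-d values go to the scalar table.
            scalars[name] = arr.item()
        else:
            # Everything else (images, histograms, ...) goes to the array table.
            arrays[name] = arr
    return scalars, arrays


scalars, arrays = split_scalars_and_arrays(
    {"loss": 0.25, "accuracy": np.float32(0.5), "image": np.zeros((1, 4, 4, 3))}
)
print(sorted(scalars))        # ['accuracy', 'loss']
print(arrays["image"].shape)  # (1, 4, 4, 3)
```

Keeping small scalar rows apart from bulky arrays is what lets a dashboard read scalar curves without touching the array data.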
- add_artifacts: bool = True
- write_summaries(
- step: int,
- values: Mapping[str, kauldron.typing.array_types.Array],
- metadata: Mapping[str, Any] | None = None,
Write arbitrary tensor summaries for the step.
- write_scalars(
- step: int,
- scalars: Mapping[str, jaxtyping.Shaped[Array, ''] | jaxtyping.Shaped[ndarray, '']],
Write scalar values for the step.
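The `''` shape annotation on `scalars` means each value is a 0-d array. A small pre-flight check along these lines can catch a batched value before it reaches the writer. The helper `check_scalars` is hypothetical, and the writer's own validation may behave differently.

```python
from typing import Any, Mapping

import numpy as np


def check_scalars(scalars: Mapping[str, Any]) -> dict[str, float]:
    """Checks that every value is 0-d (shape '') and converts it to float.

    Hypothetical pre-flight check, not part of Kauldron.
    """
    out: dict[str, float] = {}
    for name, value in scalars.items():
        arr = np.asarray(value)
        if arr.shape != ():
            raise ValueError(f"{name!r} must be a scalar, got shape {arr.shape}")
        out[name] = float(arr)
    return out


print(check_scalars({"loss": np.array(0.5), "lr": 1e-3}))
# {'loss': 0.5, 'lr': 0.001}
```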
- write_images(
- step: int,
- images: Mapping[str, jaxtyping.Shaped[Array, 'n h w c'] | jaxtyping.Shaped[ndarray, 'n h w c']],
Write images for the step.
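`write_images` annotates each value as `'n h w c'`, i.e. a batch of channels-last images. If you have a single `(h, w, c)` image, something like the hypothetical `as_image_batch` below adds the leading batch axis:

```python
import numpy as np


def as_image_batch(image: np.ndarray) -> np.ndarray:
    """Returns `image` with shape (n, h, w, c).

    Hypothetical helper: a single (h, w, c) image gets a leading batch axis.
    """
    image = np.asarray(image)
    if image.ndim == 3:
        image = image[np.newaxis]  # (h, w, c) -> (1, h, w, c)
    if image.ndim != 4:
        raise ValueError(f"expected 3 or 4 dims, got shape {image.shape}")
    return image


batch = as_image_batch(np.zeros((32, 32, 3), dtype=np.uint8))
print(batch.shape)  # (1, 32, 32, 3)
```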
- write_histograms(
- step: int,
- arrays: Mapping[str, kauldron.typing.array_types.Array],
- num_buckets: Mapping[str, int] | None = None,
Write histograms for the step.
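`num_buckets` is an optional per-name mapping, so different summaries can use different resolutions. The sketch below computes histograms that way with `np.histogram`. The helper `bucketize` and its fallback of 30 buckets for names missing from `num_buckets` are assumptions for illustration, not the writer's documented default.

```python
from typing import Mapping, Optional

import numpy as np


def bucketize(
    arrays: Mapping[str, np.ndarray],
    num_buckets: Optional[Mapping[str, int]] = None,
    default_buckets: int = 30,  # Hypothetical fallback, not Kauldron's default.
) -> dict[str, tuple[np.ndarray, np.ndarray]]:
    """Histograms each array, using a per-name bucket count when given."""
    result = {}
    for name, values in arrays.items():
        bins = (num_buckets or {}).get(name, default_buckets)
        # np.histogram returns (counts, bin_edges); edges has bins + 1 entries.
        counts, edges = np.histogram(np.asarray(values).ravel(), bins=bins)
        result[name] = (counts, edges)
    return result


hists = bucketize(
    {"weights": np.random.default_rng(0).normal(size=1000), "bias": np.zeros(8)},
    num_buckets={"weights": 50},
)
print(len(hists["weights"][0]), len(hists["bias"][0]))  # 50 30
```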
- write_videos(
- step: int,
- videos: Mapping[str, jaxtyping.Shaped[Array, 'n t h w c'] | jaxtyping.Shaped[ndarray, 'n t h w c']],
Write videos for the step.
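Videos are batched clips with a time axis right after the batch axis (`'n t h w c'`). A common way to build that layout from per-step frames is to stack along axis 1. The frame contents and sizes here are made up for illustration.

```python
import numpy as np

n, t, h, w, c = 2, 8, 16, 16, 3
# t frames, each a batch of n RGB 16x16 images; frame i is filled with value i.
frames = [np.full((n, h, w, c), fill_value=i, dtype=np.uint8) for i in range(t)]
# Stack along a new time axis at position 1: t x (n, h, w, c) -> (n, t, h, w, c).
videos = {"rollout": np.stack(frames, axis=1)}
print(videos["rollout"].shape)  # (2, 8, 16, 16, 3)
```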
- write_audios(
- step: int,
- audios: Mapping[str, jaxtyping.Float[Array, 'n t c'] | jaxtyping.Float[ndarray, 'n t c']],
- *,
- sample_rate: int,
Write audio samples for the step.
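Audio values are float arrays of shape `'n t c'` (batch, samples, channels), and `sample_rate` is a required keyword-only argument. A minimal sketch that builds a half-second 440 Hz mono tone in that layout; the 16 kHz sample rate is an illustrative choice.

```python
import numpy as np

sample_rate = 16_000  # Hz; illustrative, use whatever rate your data has.
duration_s = 0.5
t = np.arange(int(sample_rate * duration_s)) / sample_rate  # time in seconds
tone = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)     # 440 Hz sine
# Shape (n, t, c): one clip, 8000 samples, one (mono) channel.
audios = {"tone": tone[np.newaxis, :, np.newaxis]}
print(audios["tone"].shape)  # (1, 8000, 1)
```

The mapping would then be passed as `writer.write_audios(step, audios, sample_rate=sample_rate)`, where `writer` is assumed to be a `KDMetricWriter` instance.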
- write_texts(
- step: int,
- texts: Mapping[str, str],
Write text summaries for the step.
- write_pointcloud(
- step: int,
- point_clouds: Mapping[str, jaxtyping.Shaped[Array, 'n 3'] | jaxtyping.Shaped[ndarray, 'n 3']],
- *,
- point_colors: Mapping[str, jaxtyping.Shaped[Array, 'n 3'] | jaxtyping.Shaped[ndarray, 'n 3']] | None = None,
- configs: Mapping[str, str | float | bool | None] | None = None,
Write point cloud summaries for the step.
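Point clouds and their optional colors are both `'n 3'` arrays. The sketch below samples points on a unit sphere and derives per-point RGB colors from the coordinates. Using the same key in `point_clouds` and `point_colors`, and encoding colors as `uint8` values in [0, 255], are assumptions here; check what your summary backend expects.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
# Sample points uniformly on the unit sphere by normalizing Gaussian vectors.
points = rng.normal(size=(n, 3))
points /= np.linalg.norm(points, axis=1, keepdims=True)
# Map each coordinate from [-1, 1] to [0, 255] to use it as an RGB channel.
colors = ((points + 1.0) * 127.5).round().astype(np.uint8)

point_clouds = {"sphere": points.astype(np.float32)}
point_colors = {"sphere": colors}
print(point_clouds["sphere"].shape, point_colors["sphere"].shape)  # (1000, 3) (1000, 3)
```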
- write_hparams(
- hparams: Mapping[str, Any],
Write hyperparameters.
- write_context_structure(
- step: int,
- trainer: trainer_lib.Trainer,
Write the context structure.
- flush() -> None
- close() -> None
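For unit tests of code that only depends on the writer's interface, a small in-memory stand-in can be handy. `InMemoryMetricWriter` below is hypothetical and not part of Kauldron: it copies the documented `write_scalars`, `write_texts`, `flush` and `close` signatures and simply records what it receives.

```python
from typing import Any, Mapping

import numpy as np


class InMemoryMetricWriter:
    """Hypothetical test double mimicking part of KDMetricWriter's interface."""

    def __init__(self) -> None:
        self.scalars: dict[int, dict[str, float]] = {}
        self.texts: dict[int, dict[str, str]] = {}
        self.closed = False

    def write_scalars(self, step: int, scalars: Mapping[str, Any]) -> None:
        self._check_open()
        self.scalars.setdefault(step, {}).update(
            {k: float(np.asarray(v)) for k, v in scalars.items()}
        )

    def write_texts(self, step: int, texts: Mapping[str, str]) -> None:
        self._check_open()
        self.texts.setdefault(step, {}).update(texts)

    def flush(self) -> None:
        pass  # Nothing is buffered, so there is nothing to flush.

    def close(self) -> None:
        self.flush()
        self.closed = True

    def _check_open(self) -> None:
        if self.closed:
            raise RuntimeError("writer is closed")


writer = InMemoryMetricWriter()
writer.write_scalars(step=10, scalars={"loss": np.array(0.5)})
writer.write_texts(step=10, texts={"sample": "hello"})
writer.close()
print(writer.scalars)  # {10: {'loss': 0.5}}
```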