kd.train.KDMetricWriter

kd.train.KDMetricWriter

class kauldron.train.KDMetricWriter(
*,
_fake_refs: type[_FakeRefsUnset] | dict[str, _FakeRootCfg] = <class 'kauldron.utils.config_util._FakeRefsUnset'>,
workdir: str | epath.Path = _FakeRootCfg('cfg.workdir'),
collection: str = '$not_set$',
add_artifacts: bool = True,
)

Bases: kauldron.train.metric_writer.MetadataWriter

Writes summaries to logs, tf_summaries and datatables.

Differs from the default CLU metric writer in the following ways:
  • Splits summaries into two datatables, one for scalars and one for arrays, to speed up datatable access for Flatboards.

  • Does not write hyperparameters to the datatable, to avoid clutter.

  • Does not write to XM-Measurements.

  • Offers additional methods to write the config, param_overview, and element_spec.

add_artifacts: bool = True
write_summaries(
step: int,
values: Mapping[str, kauldron.typing.array_types.Array],
metadata: Mapping[str, Any] | None = None,
) -> None

Write arbitrary tensor summaries for the step.

write_scalars(
step: int,
scalars: Mapping[str, jaxtyping.Shaped[Array, ''] | jaxtyping.Shaped[ndarray, '']],
) -> None

Write scalar values for the step.

write_images(
step: int,
images: Mapping[str, jaxtyping.Shaped[Array, 'n h w c'] | jaxtyping.Shaped[ndarray, 'n h w c']],
) -> None

Write images for the step.

write_histograms(
step: int,
arrays: Mapping[str, kauldron.typing.array_types.Array],
num_buckets: Mapping[str, int] | None = None,
) -> None

Write histograms for the step.

write_videos(
step: int,
videos: Mapping[str, jaxtyping.Shaped[Array, 'n t h w c'] | jaxtyping.Shaped[ndarray, 'n t h w c']],
) -> None

Write videos for the step.

write_audios(
step: int,
audios: Mapping[str, jaxtyping.Float[Array, 'n t c'] | jaxtyping.Float[ndarray, 'n t c']],
*,
sample_rate: int,
) -> None

Write audio samples for the step.

write_texts(
step: int,
texts: Mapping[str, str],
) -> None

Write text summaries for the step.

write_pointcloud(
step: int,
point_clouds: Mapping[str, jaxtyping.Shaped[Array, 'n 3'] | jaxtyping.Shaped[ndarray, 'n 3']],
*,
point_colors: Mapping[str, jaxtyping.Shaped[Array, 'n 3'] | jaxtyping.Shaped[ndarray, 'n 3']] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
) -> None

Write point cloud summaries for the step.

write_hparams(
hparams: Mapping[str, Any],
) -> None

Write hyperparameters.

write_context_structure(
step: int,
trainer: trainer_lib.Trainer,
) -> None

Write the context structure.

flush() -> None
close() -> None