kd.metrics.truncate_field
kauldron.metrics.truncate_field(*, num_field: str, axis: int | None = 0, default: typing.Any = dataclasses.MISSING, **kwargs)
Defines an AutoState data field that is merged by truncation.
During a merge, the data of both states is converted to NumPy arrays and concatenated along the given axis. The result is then truncated to the number of elements given by the state field named by num_field. This is useful for metrics that only need to collect the first few elements of a tensor, e.g., the first few images for plotting.
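The merge rule described above (concatenate, then keep the first num elements along an axis) can be sketched in plain NumPy. The helper name truncate_merge is hypothetical and not part of kauldron; it only illustrates the rule and, as an assumption, handles an integer axis only (the behavior for axis=None is not covered here).

```python
import numpy as np


def truncate_merge(first, second, num, axis=0):
    """Hypothetical helper: concatenate two arrays, keep the first `num` along `axis`."""
    # Convert inputs (e.g. jax.Array) to NumPy before concatenating.
    merged = np.concatenate([np.asarray(first), np.asarray(second)], axis=axis)
    # Build an index that slices only `axis`; slice(0, num) is safe if num > size.
    index = [slice(None)] * merged.ndim
    index[axis] = slice(0, num)
    return merged[tuple(index)]


a = np.arange(6).reshape(3, 2)      # 3 rows
b = np.arange(6, 10).reshape(2, 2)  # 2 rows
print(truncate_merge(a, b, num=4).tolist())  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Because truncation always keeps the leading elements, repeatedly merging new batches into an already full state leaves its contents unchanged; only the first num elements ever seen are retained.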
Usage:
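    import flax
    from kauldron.metrics import AutoState, truncate_field
    from kauldron.typing import Float

    @flax.struct.dataclass
    class CollectFirstKImages(AutoState):
        num_images: int
        images: Float['n h w 3'] = truncate_field(num_field="num_images")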
- Parameters:
num_field – The name of the field (in the state) that determines the number of elements to keep.
axis – The axis along which to concatenate and truncate the two arrays. Defaults to 0.
default – The default value of the field.
**kwargs – Additional keyword arguments passed to dataclasses.field.
- Returns:
A dataclasses.Field instance whose metadata marks the field as a JAX pytree_node and sets the field merger to _Truncate(axis=axis, num_field=num_field).
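The return value described above, a dataclasses.Field that carries its own merger in its metadata, can be illustrated with standard-library dataclasses. This is a minimal sketch under stated assumptions: TruncateMerger, truncate_field_sketch, the "merger" metadata key and the merge method are all hypothetical stand-ins, not kauldron's private _Truncate or AutoState machinery, which may work differently. Only the "pytree_node" metadata key is the one flax.struct actually reads.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class TruncateMerger:
    """Hypothetical stand-in for the private _Truncate merger."""

    axis: int
    num_field: str

    def merge(self, first, second, state):
        num = getattr(state, self.num_field)  # e.g. state.num_images
        merged = np.concatenate([np.asarray(first), np.asarray(second)], axis=self.axis)
        index = [slice(None)] * merged.ndim
        index[self.axis] = slice(0, num)  # keep the first `num` elements along axis
        return merged[tuple(index)]


def truncate_field_sketch(*, num_field, axis=0, default=dataclasses.MISSING, **kwargs):
    """Hypothetical analogue of truncate_field: a Field that carries its merger."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["pytree_node"] = True  # flax.struct.dataclass checks this key
    metadata["merger"] = TruncateMerger(axis=axis, num_field=num_field)  # assumed key
    return dataclasses.field(default=default, metadata=metadata, **kwargs)


@dataclasses.dataclass(frozen=True)
class CollectFirstKImages:
    num_images: int
    images: np.ndarray = truncate_field_sketch(num_field="num_images")

    def merge(self, other):
        # Generic merge: apply each field's merger found in its metadata.
        updates = {}
        for f in dataclasses.fields(self):
            merger = f.metadata.get("merger")
            if merger is not None:
                updates[f.name] = merger.merge(getattr(self, f.name), getattr(other, f.name), self)
        return dataclasses.replace(self, **updates)


rng = np.random.default_rng(0)
state = CollectFirstKImages(num_images=3, images=rng.random((2, 4, 4, 3)))
for _ in range(3):
    state = state.merge(CollectFirstKImages(num_images=3, images=rng.random((2, 4, 4, 3))))
print(state.images.shape)  # (3, 4, 4, 3)
```

Storing the merger in field metadata lets a single generic merge routine handle every field without per-class merge code, which is the design the Returns description points to.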