kd.metrics.concat_field
kauldron.metrics.concat_field(*, axis: int = 0, default: typing.Any = dataclasses.MISSING, **kwargs)
Defines an AutoState data field that is merged by concatenation.
During merging, the data is converted to NumPy and kept in a tuple of arrays, so it does not take up device memory. The final compute() call concatenates the arrays along the given axis.
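A minimal sketch of the merge strategy described above, using plain NumPy. `ConcatSketch` and its `add`/`merge`/`compute` methods are hypothetical illustrations, not Kauldron classes: merging only extends a tuple of host-side arrays, and the single `np.concatenate` call is deferred until `compute()`.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ConcatSketch:
  """Hypothetical stand-in for a concat-merged field (not Kauldron's class)."""

  axis: int = 0
  parts: tuple[np.ndarray, ...] = ()

  def add(self, value) -> "ConcatSketch":
    # np.asarray copies device arrays (e.g. a jax.Array) into host memory.
    return dataclasses.replace(self, parts=self.parts + (np.asarray(value),))

  def merge(self, other: "ConcatSketch") -> "ConcatSketch":
    # Merging only joins the tuples; no concatenation happens yet.
    return dataclasses.replace(self, parts=self.parts + other.parts)

  def compute(self) -> np.ndarray:
    # One concatenation at the very end, along the configured axis.
    return np.concatenate(self.parts, axis=self.axis)


a = ConcatSketch(axis=1).add(np.zeros((2, 3, 4)))
b = ConcatSketch(axis=1).add(np.ones((2, 5, 4)))
merged = a.merge(b)
print(len(merged.parts))  # 2
print(merged.compute().shape)  # (2, 8, 4)
```

Deferring the concatenation means each merge only appends array references to a tuple instead of copying an ever-growing array.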
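Usage:

    import flax.struct
    from kauldron.metrics import AutoState, concat_field
    from kauldron.typing import Float

    @flax.struct.dataclass
    class CollectTokens(AutoState):
      # Merged along the token axis ('n') by concatenation.
      tokens: Float['b n d'] = concat_field(axis=1)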
Parameters:
- axis – The axis along which the arrays are concatenated. Defaults to 0.
- default – The default value of the field.
- **kwargs – Additional keyword arguments passed to dataclasses.field.

Returns:
A dataclasses.Field instance with additional metadata that marks this field as a pytree_node for JAX and sets the field merger to _Concatenate(axis=axis).
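The return value described above follows a common pattern: a field factory that stores its behaviour in `dataclasses.Field.metadata`. The sketch below shows that pattern with the standard library and NumPy only. `concat_field_sketch`, `ConcatMerger` and the `"merger"` metadata key are hypothetical stand-ins (the real merger is the private `_Concatenate`); `"pytree_node"` is the metadata key that `flax.struct` reads.

```python
import dataclasses
import typing

import numpy as np


@dataclasses.dataclass(frozen=True)
class ConcatMerger:
  """Hypothetical stand-in for the private _Concatenate merger."""

  axis: int = 0

  def finalize(self, parts: tuple) -> np.ndarray:
    # Concatenate the accumulated host arrays along the configured axis.
    return np.concatenate(parts, axis=self.axis)


def concat_field_sketch(
    *, axis: int = 0, default: typing.Any = dataclasses.MISSING, **kwargs
):
  """Hypothetical field factory mirroring the documented signature."""
  # Keep any caller-supplied metadata, then add this factory's own keys.
  metadata = dict(kwargs.pop("metadata", None) or {})
  # "pytree_node" is the flax.struct convention; "merger" is a made-up key.
  metadata.update(pytree_node=True, merger=ConcatMerger(axis=axis))
  return dataclasses.field(default=default, metadata=metadata, **kwargs)


@dataclasses.dataclass
class CollectTokensSketch:
  tokens: tuple = concat_field_sketch(axis=1, default=())


tokens_field = {f.name: f for f in dataclasses.fields(CollectTokensSketch)}["tokens"]
print(tokens_field.metadata["pytree_node"])  # True
print(tokens_field.metadata["merger"])  # ConcatMerger(axis=1)
```

Keeping the merge rule in field metadata lets a generic base class look up how each field should be combined without per-class merge code.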