kd.metrics.concat_field

kauldron.metrics.concat_field(
*,
axis: int = 0,
default: typing.Any = dataclasses.MISSING,
**kwargs,
)

Defines an AutoState data field that is merged by concatenation.

During merging, the data is converted to NumPy and kept in a tuple of arrays, so it does not occupy device memory. The final compute() method concatenates the collected arrays along the given axis.
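
The merge behaviour described above can be sketched in plain Python and NumPy. This is a minimal illustration under stated assumptions, not Kauldron's implementation: the class ConcatStateSketch and its from_values/merge/compute methods are hypothetical, and np.asarray stands in for whatever device-to-host transfer the library actually performs.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ConcatStateSketch:
  """Hypothetical stand-in for an AutoState with a single concat field."""

  # Host-side (NumPy) chunks collected so far; empty until the first value.
  parts: tuple[np.ndarray, ...] = ()
  axis: int = 0

  @classmethod
  def from_values(cls, values, axis: int = 0) -> "ConcatStateSketch":
    # Assumption: copy the per-step values to host memory as a NumPy array.
    return cls(parts=(np.asarray(values),), axis=axis)

  def merge(self, other: "ConcatStateSketch") -> "ConcatStateSketch":
    # Merging only extends the tuple; no concatenation happens yet.
    return dataclasses.replace(self, parts=self.parts + other.parts)

  def compute(self) -> np.ndarray:
    # Concatenate all collected chunks once, at the very end.
    return np.concatenate(self.parts, axis=self.axis)


# Three evaluation steps, each producing a batch of 2 scalars.
state = ConcatStateSketch.from_values([1.0, 2.0])
state = state.merge(ConcatStateSketch.from_values([3.0, 4.0]))
state = state.merge(ConcatStateSketch.from_values([5.0, 6.0]))
print(len(state.parts))  # 3
print(state.compute())   # [1. 2. 3. 4. 5. 6.]
```

Deferring the concatenation to compute() means each merge only appends to a tuple instead of repeatedly copying an ever-growing array.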

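Usage:

import flax
from kauldron.metrics import AutoState, concat_field
from kauldron.typing import Float


@flax.struct.dataclass
class CollectTokens(AutoState):
  # Merged along the token axis ('n') by concatenation.
  tokens: Float['b n d'] = concat_field(axis=1)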
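
To see what axis=1 does for the Float['b n d'] field in the usage example, the following NumPy snippet concatenates two hypothetical per-step token arrays. The shapes are made up for illustration; every axis other than the concatenation axis must agree.

```python
import numpy as np

# Two hypothetical steps: batch b=2, embedding d=4, but different token counts n.
step1 = np.zeros((2, 3, 4), dtype=np.float32)
step2 = np.ones((2, 5, 4), dtype=np.float32)

# concat_field(axis=1) merges along the token axis 'n'.
tokens = np.concatenate((step1, step2), axis=1)
print(tokens.shape)  # (2, 8, 4)

# With the default axis=0 the batch dimensions would be stacked instead,
# which requires the token counts to match, so this raises ValueError.
try:
  np.concatenate((step1, step2), axis=0)
except ValueError as err:
  print("axis=0 fails:", err)
```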

Parameters:
  • axis – The axis along which the collected arrays are concatenated. Defaults to 0.

  • default – The default value of the field. Defaults to dataclasses.MISSING, i.e. the field has no default.

  • **kwargs – Additional keyword arguments passed on to dataclasses.field.

Returns:

A dataclasses.Field instance whose metadata marks the field as a JAX pytree node and sets its merger to _Concatenate(axis=axis).
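
The return value can be pictured as an ordinary dataclasses.field carrying extra metadata. The sketch below is a hypothetical re-creation for illustration only: concat_field_sketch, the Concatenate class, and the "merger" metadata key are assumptions, not Kauldron's actual internals. Only the "pytree_node" key corresponds to what flax.struct reads from field metadata.

```python
import dataclasses
from typing import Any

import numpy as np


@dataclasses.dataclass(frozen=True)
class Concatenate:
  """Hypothetical merger: concatenates collected arrays along `axis`."""

  axis: int = 0

  def finalize(self, parts: tuple[np.ndarray, ...]) -> np.ndarray:
    return np.concatenate(parts, axis=self.axis)


def concat_field_sketch(
    *, axis: int = 0, default: Any = dataclasses.MISSING, **kwargs
) -> Any:
  """Hypothetical re-creation of concat_field on top of dataclasses.field."""
  # Combine any caller-provided metadata with the pytree flag and the merger.
  metadata = dict(kwargs.pop("metadata", None) or {})
  metadata.update(pytree_node=True, merger=Concatenate(axis=axis))
  return dataclasses.field(default=default, metadata=metadata, **kwargs)


@dataclasses.dataclass
class CollectTokensSketch:
  tokens: Any = concat_field_sketch(axis=1, default=None)


# Inspect the metadata attached to the single field.
(f,) = dataclasses.fields(CollectTokensSketch)
print(f.metadata["pytree_node"])  # True
print(f.metadata["merger"])       # Concatenate(axis=1)
```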