kd.metrics.state_field

kauldron.metrics.state_field(
  *,
  default: typing.Any = dataclasses.MISSING,
  **kwargs,
) -> Any

Defines an AutoState data field that is merged by calling the merge method of its own value.

This is useful for nesting other States as fields of an AutoState.
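To make "merged by calling its merge method" concrete, here is a minimal, library-independent sketch in plain Python dataclasses (no Flax or JAX). It is not Kauldron's implementation: the helper `nested_state_field`, the metadata key `NESTED_KEY`, and the toy `SumState` and `AggregateState` classes are all hypothetical. The parent state's merge simply delegates to the merge of each nested field.

```python
import dataclasses
from typing import Any

# Hypothetical metadata key; Kauldron's own internal marker is not documented here.
NESTED_KEY = "merge_by_nested_merge"


def nested_state_field(**kwargs: Any) -> Any:
  """Hypothetical helper: marks a field whose value is itself a mergeable state."""
  return dataclasses.field(metadata={NESTED_KEY: True}, **kwargs)


@dataclasses.dataclass(frozen=True)
class SumState:
  """Toy leaf state that accumulates a running total."""

  total: float

  def merge(self, other: "SumState") -> "SumState":
    return SumState(self.total + other.total)

  def compute(self) -> float:
    return self.total


@dataclasses.dataclass(frozen=True)
class AggregateState:
  """Toy parent state whose fields are other states."""

  loss: SumState = nested_state_field()
  weight: SumState = nested_state_field()

  def merge(self, other: "AggregateState") -> "AggregateState":
    updates = {}
    for f in dataclasses.fields(self):
      if f.metadata.get(NESTED_KEY):
        # Delegate to the nested state's own merge method.
        updates[f.name] = getattr(self, f.name).merge(getattr(other, f.name))
    return dataclasses.replace(self, **updates)

  def compute(self) -> dict[str, float]:
    return {"loss": self.loss.compute(), "weight": self.weight.compute()}


a = AggregateState(loss=SumState(1.0), weight=SumState(2.0))
b = AggregateState(loss=SumState(3.0), weight=SumState(4.0))
print(a.merge(b).compute())  # {'loss': 4.0, 'weight': 6.0}
```

Because each nested value knows how to merge itself, the parent does not need to know anything about the internal fields of `SumState`.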

Usage:

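import flax.struct
from kauldron.metrics import AutoState, state_field

# StateA and StateB stand for any other State subclasses.
@flax.struct.dataclass
class AggregateState(AutoState):
  state_a: StateA = state_field()
  state_b: StateB = state_field()

  def compute(self):
    # Nested states are accessed through `self`.
    return {"a": self.state_a.compute(), "b": self.state_b.compute()}

Parameters: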
  • default – The default value of the field.

  • **kwargs – Additional keyword arguments forwarded to dataclasses.field().

Returns:

A dataclasses.Field instance with additional metadata that marks this field as a pytree_node for JAX and sets its field merger to one that merges two values of the field by calling their merge method.
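The return value described above can be pictured with the following hedged sketch, which mirrors the documented signature using only the standard library. The names `sketch_state_field`, `NestedMerger`, and `MERGER_KEY` are hypothetical stand-ins for Kauldron's private merger machinery; only the `"pytree_node"` metadata key is the convention that `flax.struct` actually reads.

```python
import dataclasses
from typing import Any

# Hypothetical names: Kauldron's private merger class and metadata key may differ.
MERGER_KEY = "field_merger"


class NestedMerger:
  """Merges two field values by calling the first value's merge method."""

  def __call__(self, v1: Any, v2: Any) -> Any:
    return v1.merge(v2)


def sketch_state_field(*, default: Any = dataclasses.MISSING, **kwargs: Any) -> Any:
  # "pytree_node" is the metadata key flax.struct reads to decide whether a
  # field is a pytree child (True) or static metadata (False).
  return dataclasses.field(
      default=default,
      metadata={"pytree_node": True, MERGER_KEY: NestedMerger()},
      **kwargs,
  )


@dataclasses.dataclass
class Example:
  # Extra keyword arguments (here compare=False) are forwarded to dataclasses.field.
  inner: Any = sketch_state_field(default=None, compare=False)


(f,) = dataclasses.fields(Example)
print(f.metadata["pytree_node"])  # True
print(f.compare)  # False
```

Storing the merger in field metadata keeps the dataclass itself free of merge logic: a generic merge routine can iterate over `dataclasses.fields()` and dispatch on whatever merger each field carries.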