kd.metrics.state_field
kauldron.metrics.state_field(*, default: typing.Any = dataclasses.MISSING, **kwargs)
Defines an AutoState dataclass field whose values are merged by calling their own merge method.
This is useful for nesting other States as fields of a composite state.
Usage:
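    import flax.struct
    from kauldron.metrics import AutoState, state_field

    # StateA and StateB stand for any existing State subclasses.
    @flax.struct.dataclass
    class AggregateState(AutoState):
      state_a: StateA = state_field()
      state_b: StateB = state_field()

      def compute(self):
        return {"a": self.state_a.compute(), "b": self.state_b.compute()}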
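To make the merge-by-delegation behaviour concrete, here is a minimal pure-Python sketch written without Kauldron, Flax, or JAX so it runs anywhere. The names state_field_sketch, AutoStateSketch, SumState, and the metadata key are hypothetical stand-ins, not Kauldron's actual implementation. The field factory tags the field in its metadata, and the base-class merge calls each tagged field's own merge method.

```python
import dataclasses
from typing import Any

_MERGE_KEY = "merge_by_method"  # hypothetical metadata key, not Kauldron's


def state_field_sketch(*, default: Any = dataclasses.MISSING, **kwargs) -> Any:
    """Hypothetical stand-in for state_field: tags a field for merge-by-method."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata[_MERGE_KEY] = True
    return dataclasses.field(default=default, metadata=metadata, **kwargs)


@dataclasses.dataclass(frozen=True)
class SumState:
    """Toy sub-state whose own merge adds totals and counts."""
    total: float
    count: int

    def merge(self, other: "SumState") -> "SumState":
        return SumState(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        return self.total / self.count


class AutoStateSketch:
    """Hypothetical base class: merges tagged fields by delegating to them."""

    def merge(self, other):
        updates = {}
        for f in dataclasses.fields(self):
            if f.metadata.get(_MERGE_KEY):
                # Delegate to the sub-state's own merge method.
                updates[f.name] = getattr(self, f.name).merge(getattr(other, f.name))
        # Untagged fields are kept from self in this sketch.
        return dataclasses.replace(self, **updates)


@dataclasses.dataclass(frozen=True)
class AggregateState(AutoStateSketch):
    state_a: SumState = state_field_sketch()
    state_b: SumState = state_field_sketch()

    def compute(self):
        return {"a": self.state_a.compute(), "b": self.state_b.compute()}


x = AggregateState(SumState(3.0, 1), SumState(10.0, 2))
y = AggregateState(SumState(5.0, 1), SumState(2.0, 2))
print(x.merge(y).compute())  # {'a': 4.0, 'b': 3.0}
```

Because the composite state only delegates, each nested State keeps full control over how its own values are combined.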
- Parameters:
default – The default value of the field.
**kwargs – Additional keyword arguments forwarded to dataclasses.field.
- Returns:
A dataclasses.Field instance with additional metadata that marks this field as a JAX pytree node and sets its merger so that values are combined by calling their own merge method.
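To illustrate the parameters and return value, this hypothetical state_field_sketch mirrors the documented signature using plain dataclasses. It shows that default is forwarded (dataclasses.MISSING leaves the field required), that extra keyword arguments such as repr=False reach dataclasses.field, and that the returned Field carries flags in its metadata. The metadata keys pytree_node and merger and the "call_merge" value are assumptions for illustration. The pytree_node key follows the convention flax.struct uses, but Kauldron's exact metadata layout may differ.

```python
import dataclasses
from typing import Any


def state_field_sketch(*, default: Any = dataclasses.MISSING, **kwargs) -> Any:
    # Hypothetical: combine user metadata with the flags described above.
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata.update(pytree_node=True, merger="call_merge")
    return dataclasses.field(default=default, metadata=metadata, **kwargs)


@dataclasses.dataclass
class Example:
    required: int = state_field_sketch()  # default is MISSING, so must be passed
    optional: int = state_field_sketch(default=0, repr=False)  # kwarg forwarded


fields = {f.name: f for f in dataclasses.fields(Example)}
print(fields["required"].default is dataclasses.MISSING)  # True
print(fields["optional"].metadata["pytree_node"])  # True
```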