kd.data.Pipeline

class kauldron.data.Pipeline(
    *,
    _fake_refs: type[_FakeRefsUnset] | dict[str, _FakeRootCfg] = <class 'kauldron.utils.config_util._FakeRefsUnset'>,
    batch_size: int | None = None,
    seed: int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = _FakeRootCfg('cfg.seed'),
)

Bases: kauldron.data.data_utils.IterableDataset, kauldron.utils.config_util.UpdateFromRootCfg

Base class for kauldron data pipelines.

Subclasses should implement:

  • __iter__: Yield individual batches

  • (optionally) __len__: Number of iterations

Subclasses are responsible for:

  • batching

  • shuffling

  • sharding: Each host yields different examples
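
The responsibilities listed above can be illustrated with a minimal, self-contained sketch. It deliberately does not import kauldron: `RangePipeline`, `num_examples`, `host_index`, and `host_count` are hypothetical names chosen for illustration, and a real subclass would derive from `kd.data.Pipeline` and would likely obtain host information from JAX rather than from constructor arguments.

```python
import random
from collections.abc import Iterator


class RangePipeline:
    """Hypothetical stand-in for a Pipeline subclass (does not import kauldron)."""

    def __init__(self, *, num_examples, batch_size, seed, host_index, host_count):
        if batch_size % host_count != 0:
            raise ValueError("batch_size must be divisible by host_count")
        self.num_examples = num_examples
        self.batch_size = batch_size  # Global batch size, summed over all hosts.
        self.seed = seed
        self.host_index = host_index
        self.host_count = host_count

    @property
    def host_batch_size(self) -> int:
        # Assumption: the global batch is split evenly between hosts.
        return self.batch_size // self.host_count

    def __len__(self) -> int:
        # Number of full global batches; a trailing partial batch is dropped.
        return self.num_examples // self.batch_size

    def __iter__(self) -> Iterator[list[int]]:
        # Shuffling: every host uses the same seed, so all hosts agree on the order.
        order = list(range(self.num_examples))
        random.Random(self.seed).shuffle(order)
        for step in range(len(self)):
            # Batching: take one global batch worth of examples per step.
            global_batch = order[step * self.batch_size:(step + 1) * self.batch_size]
            # Sharding: each host keeps a disjoint slice of the global batch.
            start = self.host_index * self.host_batch_size
            yield global_batch[start:start + self.host_batch_size]


# Simulate two hosts iterating over the same 10 examples.
hosts = [
    RangePipeline(num_examples=10, batch_size=4, seed=0, host_index=i, host_count=2)
    for i in range(2)
]
batches_per_host = [list(p) for p in hosts]
```

Sharing the seed across hosts while slicing by host index is one simple way to make hosts yield different examples without any communication; other sharding schemes are equally valid.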

batch_size

Global batch size. Must be divisible by the total number of devices across all hosts. The pipeline is responsible for sharding the data between hosts. Setting it to 0 disables batching.

Type:

int | None
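
The divisibility requirement on batch_size can be checked with simple arithmetic. The sketch below uses a hypothetical helper, `per_host_batch_size`, that is not part of kauldron, and it assumes every host owns the same number of devices so that the global batch splits evenly between hosts.

```python
def per_host_batch_size(batch_size: int, num_devices: int, num_hosts: int) -> int:
    """Returns the per-host share of a global batch (hypothetical helper)."""
    if num_devices % num_hosts != 0:
        # Assumption: devices are spread evenly over hosts.
        raise ValueError("num_devices must be divisible by num_hosts")
    if batch_size % num_devices != 0:
        raise ValueError(
            f"batch_size={batch_size} is not divisible by num_devices={num_devices}"
        )
    return batch_size // num_hosts


# 256 examples per step on 4 hosts with 8 devices each: 64 per host, 8 per device.
per_host = per_host_batch_size(batch_size=256, num_devices=32, num_hosts=4)
```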

seed

Random seed to be used for things like shuffling and randomness in preprocessing. Defaults to the seed from the root config.

Type:

int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None

batch_size: int | None = None
seed: int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = _FakeRootCfg('cfg.seed')
property element_spec: etils.enp.array_spec.ArraySpec | Sequence[PyTree[L]] | Mapping[str, PyTree[L]]

Returns the element specs of a single batch.

property host_batch_size: int