kd.data.py.DataSourceBase
- class kauldron.data.py.DataSourceBase(*, _fake_refs: type[_FakeRefsUnset] | dict[str, _FakeRootCfg] = <class 'kauldron.utils.config_util._FakeRefsUnset'>, batch_size: int | None = None, seed: int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = _FakeRootCfg('cfg.seed'), transforms: tr_normalize.Transformations = <factory>, num_epochs: Optional[int] = None, batch_drop_remainder: bool = True, num_workers: int = 16, read_options: grain.ReadOptions | None = None, enable_profiling: bool = False, per_worker_buffer_size: int = 1, shard_by_process: bool = True, worker_init_fn: Callable[[int, int], None] | None = None, shuffle: bool)
Bases: kauldron.data.py.base.PyGrainPipeline

Base class to implement a data source.

Child classes should overwrite the data_source property. See kd.data.py.Tfds for an example.

- shuffle
whether to shuffle
- Type:
bool
- shuffle: bool
- data_source: grain.RandomAccessDataSource
- ds_for_current_process(
- rng: kauldron.random.random.PRNGKey,