kd.data.py.PyGrainPipeline
- class kauldron.data.py.PyGrainPipeline(*, _fake_refs: type[_FakeRefsUnset] | dict[str, _FakeRootCfg] = <class 'kauldron.utils.config_util._FakeRefsUnset'>, batch_size: int | None = None, seed: int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = _FakeRootCfg('cfg.seed'), transforms: tr_normalize.Transformations = <factory>, num_epochs: Optional[int] = None, batch_drop_remainder: bool = True, num_workers: int = 16, read_options: grain.ReadOptions | None = None, enable_profiling: bool = False, per_worker_buffer_size: int = 1, worker_init_fn: Callable[[int, int], None] | None = None)
Bases:
kauldron.data.pipelines.Pipeline
Abstract base class to construct a PyGrain data pipeline.
See doc:
- transforms
A list of transformations to apply to the dataset. Each transformation should be either a grain.MapTransform or a grain.RandomMapTransform.
- Type:
tr_normalize.Transformations
- num_epochs
Number of epochs. If unset (None), iteration continues indefinitely, and the number of iterations is determined by cfg.num_training_steps.
- Type:
Optional[int]
- batch_drop_remainder
Whether to drop the last examples when len(ds) % batch_size != 0.
- Type:
bool
- num_workers
Number of worker processes to use for data loading (0 disables multiprocessing).
- Type:
int
- read_options
Options for reading data from the DataSource.
- Type:
grain.ReadOptions | None
- enable_profiling
If True, data worker process 0 will be profiled.
- Type:
bool
- worker_init_fn
If set, subprocesses are initialized with this function instead of the Kauldron default.
- Type:
Callable[[int, int], None] | None
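As an illustration of how batch_drop_remainder and num_epochs interact, the sketch below computes the number of batches from one pass over a dataset. It is plain Python arithmetic, not Kauldron code: the helper name num_batches is hypothetical, and it ignores per-process sharding, which a real pipeline applies before batching.

```python
def num_batches(num_examples: int, batch_size: int, drop_remainder: bool = True) -> int:
    """Number of batches produced by one pass over `num_examples` examples."""
    full, leftover = divmod(num_examples, batch_size)
    if drop_remainder or leftover == 0:
        # Either the trailing partial batch is discarded, or there is none.
        return full
    # The trailing partial batch is kept as an extra, smaller batch.
    return full + 1


# 1000 examples with batch_size=32: 31 full batches plus 8 leftover examples.
print(num_batches(1000, 32, drop_remainder=True))   # 31
print(num_batches(1000, 32, drop_remainder=False))  # 32

# With a finite number of epochs, the total scales accordingly.
print(3 * num_batches(1000, 32, drop_remainder=True))  # 93
```

Dropping the remainder keeps every batch the same shape, which is usually what a jitted training step expects; keeping it is more typical for evaluation, where every example should be seen.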
- transforms: tr_normalize.Transformations
- num_epochs: int | None = None
- batch_drop_remainder: bool = True
- num_workers: int = 16
- read_options: grain.ReadOptions | None = None
- enable_profiling: bool = False
- per_worker_buffer_size: int = 1
- worker_init_fn: Callable[[int, int], None] | None = None
- ds_for_current_process(
- rng: kauldron.random.random.PRNGKey,
- ds_with_transforms(
- rng: kauldron.random.random.PRNGKey,
Create the tf.data.Dataset and apply all the transforms.
- property element_spec: PyTree[enp.ArraySpec]
Returns the element specs of a single batch.
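The worker_init_fn field listed above takes a Callable[[int, int], None]. A minimal sketch of such a function follows. It assumes the two integers are the worker index and the total worker count, which is how Grain's own worker initialization hook is commonly described; this page does not confirm the argument order. The function name, environment variable, and logger name are hypothetical.

```python
import logging
import os


def my_worker_init_fn(worker_idx: int, worker_count: int) -> None:
    """Hypothetical per-worker setup, run once in each data worker process."""
    # Assumption: arguments are (index of this worker, total number of workers).
    os.environ["MY_DATA_WORKER"] = f"{worker_idx}/{worker_count}"

    # Keep only worker 0 verbose so logs are not repeated once per worker.
    level = logging.INFO if worker_idx == 0 else logging.WARNING
    logging.getLogger("my_data_worker").setLevel(level)


# Local check outside any pipeline: simulate worker 0 of 16.
my_worker_init_fn(0, 16)
print(os.environ["MY_DATA_WORKER"])
```

Such a function would be passed as worker_init_fn=my_worker_init_fn when constructing a pipeline. Since it replaces the Kauldron default initialization rather than extending it, any setup the default performs would have to be reproduced by hand.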