kd.data.py.PyGrainPipeline

class kauldron.data.py.PyGrainPipeline(*, _fake_refs: type[_FakeRefsUnset] | dict[str, _FakeRootCfg] = <class 'kauldron.utils.config_util._FakeRefsUnset'>, batch_size: int | None = None, seed: int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = _FakeRootCfg('cfg.seed'), transforms: tr_normalize.Transformations = <factory>, num_epochs: Optional[int] = None, batch_drop_remainder: bool = True, num_workers: int = 16, read_options: grain.ReadOptions | None = None, enable_profiling: bool = False, per_worker_buffer_size: int = 1, shard_by_process: bool = True, worker_init_fn: Callable[[int, int], None] | None = None)

Bases: kauldron.data.pipelines.Pipeline

Abstract base class for constructing PyGrain data pipelines.

See doc:

transforms

A list of transformations to apply to the dataset. Each transformation should be either a grain.MapTransform or a grain.RandomMapTransform.

Type:

tr_normalize.Transformations
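
As a rough illustration of what a map-style transformation looks like, the sketch below defines a class with a `map` method, which is the method a `grain.MapTransform` subclass implements. To keep the snippet runnable on its own, grain is not imported and the class is a plain stand-in; `ScaleValues`, its fields, and the example dict are hypothetical names chosen for this sketch.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ScaleValues:
    """Stand-in for a `grain.MapTransform` subclass (grain is not imported here)."""

    key: str = "x"      # hypothetical field: which entry of the example to scale
    scale: float = 0.5  # hypothetical field: multiplicative factor

    def map(self, element: dict) -> dict:
        # Copy the example so the input dict is left untouched.
        out = dict(element)
        out[self.key] = [v * self.scale for v in element[self.key]]
        return out


example = {"x": [0, 2, 4], "y": 1}
print(ScaleValues().map(example))  # {'x': [0.0, 1.0, 2.0], 'y': 1}
```

In a real pipeline the class would derive from `grain.MapTransform` and be listed in `transforms`.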

num_epochs

Number of epochs. If unset, iterate indefinitely (the number of iterations is then given by cfg.num_training_steps).

Type:

Optional[int]

batch_drop_remainder

Whether to drop the last examples when len(ds) % batch_size != 0.

Type:

bool
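
The effect of this flag on the number of batches in one pass over the data comes down to simple arithmetic. The sketch below is not Kauldron code: `num_batches` is a hypothetical helper, written only to illustrate the `len(ds) % batch_size != 0` case.

```python
import math


def num_batches(num_examples: int, batch_size: int, drop_remainder: bool) -> int:
    """Batches in one pass over `num_examples` elements (hypothetical helper)."""
    if drop_remainder:
        # The trailing partial batch is discarded.
        return num_examples // batch_size
    # The trailing partial batch is kept as a smaller batch.
    return math.ceil(num_examples / batch_size)


print(num_batches(10, 4, drop_remainder=True))   # 2 (the last 2 examples are dropped)
print(num_batches(10, 4, drop_remainder=False))  # 3 (the last batch has 2 examples)
```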

num_workers

Number of worker processes to use for data loading (0 disables multiprocessing).

Type:

int

read_options

Options for reading data from the DataSource.

Type:

grain.ReadOptions | None

enable_profiling

If True, data worker process 0 is profiled.

Type:

bool

shard_by_process

Whether to shard the dataset by process count and index. Keep the default in most cases; non-default values are intended for experimental use only.

Type:

bool
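
To make "shard by process count and index" concrete, here is a minimal sketch of one common scheme, strided sharding, in which process `i` of `n` reads every `n`-th example starting at `i`. The source does not say which scheme the pipeline actually uses, so treat this as an assumption; `shard_indices` is a hypothetical helper. In a JAX program the index and count would typically come from `jax.process_index()` and `jax.process_count()`.

```python
def shard_indices(num_examples: int, process_index: int, process_count: int) -> list[int]:
    """Indices one process would read under strided sharding (illustrative only)."""
    return list(range(process_index, num_examples, process_count))


# With 10 examples and 4 processes, the shards are disjoint and cover the dataset.
shards = [shard_indices(10, i, 4) for i in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

Note that the shards have unequal sizes when the dataset length is not a multiple of the process count; real implementations often drop or pad the remainder so that every process sees the same number of examples.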

worker_init_fn

If set, worker subprocesses are initialized with this function instead of the Kauldron default.

Type:

Callable[[int, int], None] | None
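
A function matching the `Callable[[int, int], None]` signature could look like the sketch below. The source gives only the type, so the meaning of the two integers (worker index and total worker count) is an assumption here, and `my_worker_init_fn` and the `DATA_WORKER_ID` environment variable are hypothetical names.

```python
import logging
import os


def my_worker_init_fn(worker_index: int, worker_count: int) -> None:
    """Hypothetical init hook; the meaning of the arguments is an assumption."""
    logging.basicConfig(level=logging.INFO)
    # Tag the process so that downstream code can tell workers apart.
    os.environ["DATA_WORKER_ID"] = f"{worker_index}/{worker_count}"
    logging.info("Initialized data worker %d of %d", worker_index, worker_count)


# Called directly here to show the contract; a pipeline would call it in each subprocess.
my_worker_init_fn(0, 16)
```

The function would then be passed as `worker_init_fn=my_worker_init_fn` when constructing a concrete pipeline.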

transforms: tr_normalize.Transformations
num_epochs: int | None = None
batch_drop_remainder: bool = True
num_workers: int = 16
read_options: grain.ReadOptions | None = None
enable_profiling: bool = False
per_worker_buffer_size: int = 1
shard_by_process: bool = True
worker_init_fn: Callable[[int, int], None] | None = None
ds_for_current_process(rng: kauldron.random.random.PRNGKey) → grain._src.python.dataset.dataset.MapDataset
ds_with_transforms(rng: kauldron.random.random.PRNGKey) → grain._src.python.dataset.dataset.MapDataset

Create the dataset (a grain MapDataset, per the return annotation) and apply all the transforms.

property element_spec: PyTree[enp.ArraySpec]

Returns the element specs of a single batch.