kd.train.RngStreams

class kauldron.train.RngStreams(
stream_overwrites: Sequence[RngStream] = <factory>,
*,
_fake_refs: type[_FakeRefsUnset] | dict[str, _FakeRootCfg] = <class 'kauldron.utils.config_util._FakeRefsUnset'>,
seed: int = _FakeRootCfg('cfg.seed'),
)

Bases: kauldron.utils.config_util.UpdateFromRootCfg

Manager of rng streams.

See the documentation at https://kauldron.rtfd.io/en/latest/eval.html#rng-streams

Generates the rngs dict passed to model.init and model.apply.

Three streams are always added (params, dropout, default), but their values can be overwritten with stream_overwrites.
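
The merge rule can be pictured with a small, self-contained model. This is a sketch under assumptions, not kauldron's implementation: Stream is a hypothetical stand-in for kd.train.RngStream, its flag names and the default flag values are illustrative guesses, and merge_streams is a hypothetical helper. It assumes an overwrite replaces the default stream of the same name and that any new name is appended.

```python
import dataclasses
from collections.abc import Sequence


@dataclasses.dataclass(frozen=True)
class Stream:
    """Hypothetical stand-in for kd.train.RngStream; field names are assumptions."""

    name: str
    init: bool = True  # included in the rngs for model.init()
    train: bool = True  # included in the rngs for training steps
    eval: bool = False  # included in the rngs for evaluation steps


# The three always-present streams; flag values here are illustrative only.
DEFAULT_STREAMS = (Stream("params"), Stream("dropout"), Stream("default"))


def merge_streams(
    defaults: Sequence[Stream], overwrites: Sequence[Stream]
) -> dict[str, Stream]:
    """Merge by name: an overwrite replaces the default with the same name."""
    merged = {s.name: s for s in defaults}
    for s in overwrites:
        merged[s.name] = s  # replaces in place, or appends a new stream
    return merged


streams = merge_streams(
    DEFAULT_STREAMS, [Stream("dropout", eval=True), Stream("noise")]
)
print(list(streams))  # ['params', 'dropout', 'default', 'noise']
print(streams["dropout"].eval)  # True
```

Keying the result by stream name is what makes "overwrite" well defined: the defaults can never be removed, only replaced.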

stream_overwrites

Additional streams to add. They are merged with the default ones; a stream with the same name as a default stream overwrites it.

Type:

collections.abc.Sequence[kauldron.train.rngs_lib.RngStream]

seed

Seed used to initialize root_rng. If None, the global seed from kd.train.Trainer is reused.

Type:

int

stream_overwrites: Sequence[RngStream]
seed: int = _FakeRootCfg('cfg.seed')
property streams: dict[str, kauldron.train.rngs_lib.RngStream]

Streams, after the default streams have been merged.

property root_rng: kauldron.random.random.PRNGKey

Base root rng from which all other rngs are derived.
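
One way to picture "derived from the root" is to turn the seed into a root key and then fold each stream name into it. The sketch below is an assumption-labelled toy: keys are plain integers, fold_in is a hypothetical hashlib-based stand-in for jax.random.fold_in that also accepts strings, and make_root_rng / stream_rng are illustrative helpers. It does not reproduce kauldron's actual key values.

```python
import hashlib


def fold_in(key: int, data: object) -> int:
    """Toy stand-in for jax.random.fold_in: deterministically mixes data into key."""
    digest = hashlib.sha256(f"{key}/{data}".encode()).digest()
    return int.from_bytes(digest[:4], "little")  # 32-bit integer "key"


def make_root_rng(seed: int) -> int:
    """Hypothetical analogue of the root_rng property: a pure function of seed."""
    return fold_in(0, seed)


def stream_rng(root: int, name: str) -> int:
    """Each stream gets its own key, derived from the root and the stream name."""
    return fold_in(root, name)


root = make_root_rng(seed=42)
rngs = {name: stream_rng(root, name) for name in ("params", "dropout", "default")}
```

Because every key in this model is a pure function of the seed, rerunning with the same seed reproduces the same rngs dict, while different stream names yield unrelated keys.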

init_rngs() → dict[str, kauldron.random.random.PRNGKey]

Rngs for model.init().
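
A sketch of what init_rngs() could return, assuming each stream carries a hypothetical init flag and that init-time keys do not depend on any step. Stream, fold_in, and init_rngs here are illustrative stand-ins (fold_in is a hash-based toy in place of jax.random.fold_in), not kauldron's code.

```python
import dataclasses
import hashlib


def fold_in(key: int, data: object) -> int:
    """Toy stand-in for jax.random.fold_in (integer keys, hash-based mixing)."""
    digest = hashlib.sha256(f"{key}/{data}".encode()).digest()
    return int.from_bytes(digest[:4], "little")


@dataclasses.dataclass(frozen=True)
class Stream:
    """Hypothetical stand-in for kd.train.RngStream."""

    name: str
    init: bool = True  # assumption: opt-in flag for model.init()


def init_rngs(streams: list[Stream], root: int) -> dict[str, int]:
    """Keys for model.init(): one per stream flagged for init, no step mixed in."""
    return {s.name: fold_in(root, s.name) for s in streams if s.init}


streams = [Stream("params"), Stream("dropout", init=False), Stream("default")]
rngs = init_rngs(streams, root=fold_in(0, 42))
print(sorted(rngs))  # ['default', 'params']
```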

train_rngs(step: int) → dict[str, kauldron.random.random.PRNGKey]

Rngs for model.apply(…, is_training_property=True).

Parameters:

step – Current train/eval step.

Returns:

rngs – The dict[<stream name>, kd.random.PRNGKey].

Return type:

dict[str, kauldron.random.random.PRNGKey]
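
The step argument is what lets training draw fresh randomness at every step. A minimal sketch, assuming a hypothetical per-stream per_step flag and that the step is folded into the stream key after the name; fold_in (a hash-based toy in place of jax.random.fold_in), Stream, and train_rngs are illustrative stand-ins, not kauldron's implementation.

```python
import dataclasses
import hashlib


def fold_in(key: int, data: object) -> int:
    """Toy stand-in for jax.random.fold_in (integer keys, hash-based mixing)."""
    digest = hashlib.sha256(f"{key}/{data}".encode()).digest()
    return int.from_bytes(digest[:4], "little")


@dataclasses.dataclass(frozen=True)
class Stream:
    """Hypothetical stand-in for kd.train.RngStream."""

    name: str
    train: bool = True  # assumption: included in training rngs
    per_step: bool = True  # assumption: a new key at every step


def train_rngs(streams: list[Stream], root: int, step: int) -> dict[str, int]:
    """Keys for a training step; per-step streams also depend on step."""
    rngs = {}
    for s in streams:
        if not s.train:
            continue
        key = fold_in(root, s.name)
        rngs[s.name] = fold_in(key, step) if s.per_step else key
    return rngs


root = fold_in(0, 42)
streams = [Stream("dropout"), Stream("default", per_step=False)]
step0 = train_rngs(streams, root, step=0)
step1 = train_rngs(streams, root, step=1)
```

In this model the same seed and step always reproduce the same dict, so a resumed run could recreate the keys of any step without storing them.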

eval_rngs(step: int) → dict[str, kauldron.random.random.PRNGKey]

Rngs for model.apply(…, is_training_property=False).

Parameters:

step – Current train/eval step.

Returns:

rngs – The dict[<stream name>, kd.random.PRNGKey].

Return type:

dict[str, kauldron.random.random.PRNGKey]
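
The training and evaluation variants can be modelled as differing mainly in which streams they return; dropout, for example, is normally disabled at eval time, so its stream need not be passed. In the sketch below a hypothetical train/eval flag pair selects the streams and the same per-step derivation is used for both modes. The flag names, fold_in (a hash-based toy in place of jax.random.fold_in), and step_rngs are assumptions, not kauldron's implementation.

```python
import dataclasses
import hashlib


def fold_in(key: int, data: object) -> int:
    """Toy stand-in for jax.random.fold_in (integer keys, hash-based mixing)."""
    digest = hashlib.sha256(f"{key}/{data}".encode()).digest()
    return int.from_bytes(digest[:4], "little")


@dataclasses.dataclass(frozen=True)
class Stream:
    """Hypothetical stand-in for kd.train.RngStream."""

    name: str
    train: bool = True  # assumption: included by train_rngs()
    eval: bool = False  # assumption: included by eval_rngs()


def step_rngs(
    streams: list[Stream], root: int, step: int, *, training: bool
) -> dict[str, int]:
    """Select the streams for the mode, then derive a per-step key for each."""
    selected = [s for s in streams if (s.train if training else s.eval)]
    return {s.name: fold_in(fold_in(root, s.name), step) for s in selected}


streams = [Stream("dropout"), Stream("default", eval=True)]
root = fold_in(0, 42)
train = step_rngs(streams, root, step=10, training=True)
evaluation = step_rngs(streams, root, step=10, training=False)
print(sorted(train), sorted(evaluation))  # ['default', 'dropout'] ['default']
```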