kd.train.RngStreams
- class kauldron.train.RngStreams(stream_overwrites: Sequence[RngStream] = <factory>, *, _fake_refs: type[_FakeRefsUnset] | dict[str, _FakeRootCfg] = <class 'kauldron.utils.config_util._FakeRefsUnset'>, seed: int = _FakeRootCfg('cfg.seed'))
Bases: kauldron.utils.config_util.UpdateFromRootCfg
Manager of rng streams.
See doc at https://kauldron.rtfd.io/en/latest/eval.html#rng-streams
Generate the rngs dict to pass to model.init / model.apply.
Three streams are always added (params, dropout, default), but their values can be overwritten with stream_overwrites.
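The merge rule can be sketched in a few lines of plain Python. `StreamSpec` and `merge_streams` are hypothetical stand-ins, not the real `kd.train.RngStream` class or kauldron's code, and the `in_eval` flag is an invented placeholder field. The sketch assumes a stream in stream_overwrites whose name matches a default replaces it, and any other name is added alongside the defaults.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class StreamSpec:
    """Hypothetical stand-in for kd.train.RngStream (not the real class)."""

    name: str
    in_eval: bool = False  # Placeholder option, only to show an overwrite.


# The three streams that are always present.
DEFAULT_STREAMS = (
    StreamSpec("params"),
    StreamSpec("dropout"),
    StreamSpec("default"),
)


def merge_streams(overwrites: Sequence[StreamSpec]) -> dict[str, StreamSpec]:
    """Merges user streams into the defaults, keyed by stream name."""
    streams = {s.name: s for s in DEFAULT_STREAMS}
    for s in overwrites:
        # Same name: replaces the default in place. New name: appended.
        streams[s.name] = s
    return streams
```

For example, `merge_streams([StreamSpec("dropout", in_eval=True), StreamSpec("noise")])` keeps params and default unchanged, replaces dropout, and adds a fourth stream named noise.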
- stream_overwrites
Additional streams to add. They are merged with the default ones.
- Type:
collections.abc.Sequence[kauldron.train.rngs_lib.RngStream]
- seed
Seed to initialize the root_rng. If None, the global seed from kd.train.Trainer is reused.
- Type:
int
- stream_overwrites: Sequence[RngStream]
- seed: int = _FakeRootCfg('cfg.seed')
- property streams: dict[str, kauldron.train.rngs_lib.RngStream]
Streams (after the defaults are merged).
- property root_rng: kauldron.random.random.PRNGKey
Root rng from which all the other rngs are derived.
- init_rngs() → dict[str, kauldron.random.random.PRNGKey]
Rngs for model.init().
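The relationship between root_rng and init_rngs can be illustrated with a minimal sketch. It uses plain `jax.random` keys instead of `kd.random.PRNGKey`, and the derivation scheme (folding a crc32 hash of the stream name into the root key) is an assumption for illustration, not necessarily what kauldron does. The names `root_rng`, `init_rngs` and `STREAM_NAMES` here are hypothetical free functions, not the class members.

```python
import zlib

import jax

# The three default stream names.
STREAM_NAMES = ("params", "dropout", "default")


def root_rng(seed: int) -> jax.Array:
    """Root key built from the integer seed; all other keys derive from it."""
    return jax.random.PRNGKey(seed)


def init_rngs(seed: int) -> dict[str, jax.Array]:
    """One key per stream for model.init, derived from the root by name."""
    root = root_rng(seed)
    # crc32 gives a stable id per name; Python's hash() of a str is salted
    # per process, so it would not be reproducible across runs.
    return {
        name: jax.random.fold_in(root, zlib.crc32(name.encode()) & 0x7FFFFFFF)
        for name in STREAM_NAMES
    }
```

With Flax linen, a dict of this shape can be passed directly as the rngs argument, e.g. `variables = model.init(init_rngs(0), x)`.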
- train_rngs(step: int)
Rngs for model.apply(…, is_training_property=True).
- Parameters:
step – Current train/eval step
- Returns:
rngs: the dict[<stream name>, kd.random.PRNGKey]
- Return type:
dict[str, kd.random.PRNGKey]
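Taking the step as input lets each training step get fresh randomness (e.g. a new dropout mask) while staying reproducible for a given seed and step. The sketch below shows one way to achieve that with plain `jax.random`: fold the step into the root key, then fold in each stream name. The function names and the derivation order are assumptions for illustration, not kauldron's implementation.

```python
import zlib

import jax

STREAM_NAMES = ("params", "dropout", "default")


def _name_id(name: str) -> int:
    # Stable 31-bit id for a stream name (crc32 does not vary between runs).
    return zlib.crc32(name.encode()) & 0x7FFFFFFF


def train_rngs(seed: int, step: int) -> dict[str, jax.Array]:
    """Hypothetical sketch: new keys every step, identical for a repeated step."""
    step_key = jax.random.fold_in(jax.random.PRNGKey(seed), step)
    return {name: jax.random.fold_in(step_key, _name_id(name)) for name in STREAM_NAMES}
```

With Flax linen, such a dict goes through the rngs keyword of `Module.apply`, e.g. `model.apply(variables, x, rngs=train_rngs(0, step))`.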
- eval_rngs(step: int)
Rngs for model.apply(…, is_training_property=False).
- Parameters:
step – Current train/eval step
- Returns:
rngs: the dict[<stream name>, kd.random.PRNGKey]
- Return type:
dict[str, kd.random.PRNGKey]
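Since train_rngs and eval_rngs both take a step, one design question is whether evaluating at step N should reuse the keys of training step N. The sketch below keeps them apart by folding a hypothetical mode tag into the root key before the step. The tags, helper names and derivation order are illustrative assumptions; the documentation above does not say how kauldron separates the two.

```python
import zlib

import jax

STREAM_NAMES = ("params", "dropout", "default")
_TRAIN, _EVAL = 0, 1  # Hypothetical domain tags separating train and eval keys.


def _rngs(seed: int, mode: int, step: int) -> dict[str, jax.Array]:
    # Derivation order: root -> mode tag -> step -> stream name.
    key = jax.random.fold_in(jax.random.PRNGKey(seed), mode)
    key = jax.random.fold_in(key, step)
    return {
        name: jax.random.fold_in(key, zlib.crc32(name.encode()) & 0x7FFFFFFF)
        for name in STREAM_NAMES
    }


def train_rngs(seed: int, step: int) -> dict[str, jax.Array]:
    return _rngs(seed, _TRAIN, step)


def eval_rngs(seed: int, step: int) -> dict[str, jax.Array]:
    return _rngs(seed, _EVAL, step)
```

With this scheme, eval keys at a given step are reproducible, differ from the train keys of the same step, and differ between steps.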