kd.train.Trainer
- class kauldron.train.Trainer(**kwargs)
Bases: kauldron.utils.config_util.BaseConfig
Base trainer class.
This class is the root object containing all the configuration options (datasets, model, optimizer, etc.).
Usage:
```python
trainer = kd.train.Trainer(
    train_ds=...,
    model=...,
    optimizer=...,
    # ... other fields
)
trainer.train()
```
- seed
Seed for all rngs
- Type:
int
- workdir
Root dir of the experiment (usually set by XManager)
- Type:
etils.epath.abstract_path.Path
- train_ds
Dataset used in training
- Type:
kauldron.data.pipelines.Pipeline
- eval_ds
Dataset used in eval (see https://kauldron.rtfd.io/en/latest/eval.html to activate eval)
- Type:
kauldron.data.pipelines.Pipeline | None
- model
Flax linen module
- Type:
flax.linen.module.Module
- rng_streams
Flax RNG streams to use in addition to the defaults (params, dropout, default). If any of params, dropout or default is set here, it overwrites the default value.
- Type:
kauldron.train.rngs_lib.RngStreams
- sharding
Model sharding (by default, use replicated sharding)
- Type:
kauldron.utils.sharding_utils.ShardingStrategy
- num_train_steps
Number of training steps. If None, train on the full dataset for the number of epochs specified in train_ds.
- Type:
int | None
- stop_after_steps
Optionally stop after running this many steps. If set, it overrides num_train_steps as the number of steps to run. This allows debugging on Colab without modifying the learning rate schedules and other values that depend on num_train_steps.
- Type:
int | None
- log_metrics_every
How often to compute and log the metrics (in TensorBoard,…)
- Type:
int
- log_summaries_every
How often to compute and log the summaries (in TensorBoard,…)
- Type:
int
- train_losses
A dict of losses
- Type:
collections.abc.MutableMapping[str, kauldron.losses.base.Loss]
- train_metrics
A dict of metrics
- Type:
collections.abc.MutableMapping[str, kauldron.metrics.base.Metric]
- train_summaries
A dict of summaries (summaries are subclasses of Metric)
- Type:
collections.abc.MutableMapping[str, kauldron.metrics.base.Metric]
- writer
Metric writer used for writing to TB, datatable, etc.
- Type:
kauldron.train.metric_writer.WriterBase
- profiler
Profiler can be customized (see kd.inspect.Profile)
- Type:
kauldron.inspect.profile_utils.Profiler
- checkify_error_categories
Set of error categories for which checkify is enabled.
- Type:
frozenset[Any]
- schedules
optax schedules (to be used in optimizer)
- Type:
collections.abc.MutableMapping[str, collections.abc.Callable[[jax.Array | numpy.ndarray | numpy.bool | numpy.number | float | int], jax.Array | numpy.ndarray | numpy.bool | numpy.number | float | int]]
- optimizer
optax optimizer
- Type:
optax._src.base.GradientTransformation
- checkpointer
Checkpointer used to save/restore the state
- Type:
kauldron.checkpoints.checkpointer.BaseCheckpointer
- init_transform
An initial state transformation. Used for partial checkpoint loading (re-use pre-trained weights).
- Type:
kauldron.checkpoints.partial_loader.AbstractPartialLoader
- trainstep
Training loop step. Do not set this field unless you need a custom training step.
- Type:
kauldron.train.train_step.TrainStep
- evals
Evaluators to use (e.g. {'eval': kd.eval.Evaluator()})
- Type:
collections.abc.MutableMapping[str, kauldron.evals.evaluators.EvaluatorBase]
- aux
A dict of arbitrary additional values (e.g. a value can be set once and referenced elsewhere: cfg.model.num_layer = cfg.ref.aux.num_layers).
- Type:
Any
- setup
Global setup options
- Type:
kauldron.train.setup_utils.Setup
- xm_job
XManager runtime parameters (e.g. which target is the config using)
- Type:
Any
- raw_cfg
Original config from which this Config was created. Automatically set during konfig.resolve()
- Type:
kauldron.konfig.configdict_base.ConfigDict | None
- seed: int = 0
- workdir: etils.epath.abstract_path.Path = PosixGPath('.')
- train_ds: kauldron.data.pipelines.Pipeline
- eval_ds: kauldron.data.pipelines.Pipeline | None = None
- model: flax.linen.module.Module
- rng_streams: kauldron.train.rngs_lib.RngStreams
- sharding: kauldron.utils.sharding_utils.ShardingStrategy
- num_train_steps: int | None = None
- stop_after_steps: int | None = None
- log_metrics_every: int = 100
- log_summaries_every: int = 1000
- train_losses: collections.abc.MutableMapping[str, kauldron.losses.base.Loss]
- train_metrics: collections.abc.MutableMapping[str, kauldron.metrics.base.Metric]
- train_summaries: collections.abc.MutableMapping[str, kauldron.metrics.base.Metric]
- writer: kauldron.train.metric_writer.WriterBase
- profiler: kauldron.inspect.profile_utils.Profiler
- checkify_error_categories: frozenset[Any] = frozenset({})
- schedules: collections.abc.MutableMapping[str, collections.abc.Callable[[jax.Array | numpy.ndarray | numpy.bool | numpy.number | float | int], jax.Array | numpy.ndarray | numpy.bool | numpy.number | float | int]]
- optimizer: optax._src.base.GradientTransformation
- checkpointer: kauldron.checkpoints.checkpointer.BaseCheckpointer
- init_transform: kauldron.checkpoints.partial_loader.AbstractPartialLoader
- trainstep: kauldron.train.train_step.TrainStep
- evals: collections.abc.MutableMapping[str, kauldron.evals.evaluators.EvaluatorBase]
- aux: Any
- setup: kauldron.train.setup_utils.Setup
- xm_job: Any
- raw_cfg: kauldron.konfig.configdict_base.ConfigDict | None = None
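With the defaults above (log_metrics_every=100, log_summaries_every=1000), summaries are logged ten times less often than metrics. A minimal sketch, assuming a simple "step is a multiple of the interval" rule; should_log is a hypothetical helper, and the exact condition Kauldron uses (e.g. whether the first or last step is always logged) is not specified here.

```python
def should_log(step: int, every: int) -> bool:
    """Hypothetical rule: log at every multiple of `every`, step 0 included."""
    return every > 0 and step % every == 0


# Default values from the field list above.
log_metrics_every = 100
log_summaries_every = 1000

num_steps = 3_000
metric_steps = [s for s in range(num_steps) if should_log(s, log_metrics_every)]
summary_steps = [s for s in range(num_steps) if should_log(s, log_summaries_every)]
```

Over 3,000 steps, this rule logs metrics 30 times and summaries 3 times (at steps 0, 1000 and 2000).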
- eval_only(**kwargs) → Self
Returns a Trainer which only does evaluation. Calling this function in a konfig context will pre-populate the returned konfig.ConfigDict object with the values defined in kauldron/konfig/default_values.py.
Usage:
```python
cfg = kd.train.Trainer.eval_only()

# Should be set either here or in the CLI `--cfg.aux.xid=12345` to indicate
# which Kauldron experiment to load the model from.
cfg.aux.xid = 12345
cfg.aux.wid = 1

# WARNING: Do not overwrite the `cfg.aux` field directly. Only set its
# values like above. E.g. `cfg.aux = {"xid": 12345, 'wid': 1}` won't work.

cfg.evals = {
    ...,
}
```
This function should NOT be directly called outside a konfig context.
- Parameters:
**kwargs – Propagated to the kd.train.Trainer constructor.
- init_state(*, skip_transforms: bool = False, skip_optimizer: bool = False, element_spec: PyTree[enp.ArraySpec] | None = None)
Creates the state: cfg.trainstep.init(element_spec).
- train() → tuple[kauldron.train.train_step.TrainState, kauldron.train.auxiliaries.AuxiliariesState]
Main method that trains and evaluates the model.
Similar to:
```python
state = trainer.init_state()
for batch in trainer.train_ds:
    state, aux = trainer.trainstep.step(state, batch)
```
- Returns:
The final model state and the auxiliaries.
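The loop shown under train() can be exercised end to end with toy stand-ins. This is a hypothetical sketch, not Kauldron code: ToyState, ToyTrainStep and the scalar "batches" are made up, init takes no element_spec here, and the sketch assumes a contract in which init builds the state and step returns a (state, aux) pair that is threaded through the loop.

```python
import dataclasses
from collections.abc import Iterable


@dataclasses.dataclass(frozen=True)
class ToyState:
    step: int
    params: float  # a single scalar "weight"


class ToyTrainStep:
    """Hypothetical stand-in mimicking an init/step contract."""

    def __init__(self, learning_rate: float = 0.1):
        self.learning_rate = learning_rate

    def init(self) -> ToyState:
        return ToyState(step=0, params=0.0)

    def step(self, state: ToyState, batch: float) -> tuple[ToyState, dict[str, float]]:
        # Gradient of the squared loss 0.5 * (params - batch) ** 2 w.r.t. params.
        grad = state.params - batch
        new_state = ToyState(
            step=state.step + 1,
            params=state.params - self.learning_rate * grad,
        )
        aux = {"loss": 0.5 * (state.params - batch) ** 2}
        return new_state, aux


def train(
    trainstep: ToyTrainStep, train_ds: Iterable[float]
) -> tuple[ToyState, dict[str, float]]:
    state = trainstep.init()
    aux: dict[str, float] = {}
    for batch in train_ds:
        # The new state is fed back into the next step.
        state, aux = trainstep.step(state, batch)
    return state, aux  # final state and last auxiliaries


final_state, final_aux = train(ToyTrainStep(), [1.0] * 100)
```

After 100 steps the scalar weight has converged close to the target value 1.0, and the returned pair mirrors the "final model state and auxiliaries" return value.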
- continuous_eval(names: str | list[str])
Main method that performs auxiliary tasks (evaluation, rendering, …).
Triggers an evaluation every time a new checkpoint is detected.
See https://kauldron.rtfd.io/en/latest/eval.html for details.
- Parameters:
names – Names of the evaluators to run.
- Returns:
A mapping from eval name to auxiliaries.
- Return type:
Auxiliaries
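Triggering an evaluation each time a new checkpoint appears can be sketched as a polling loop. This is a minimal, hypothetical illustration, not Kauldron's implementation: list_checkpoints, evaluate, the polling interval and the final_step stop condition are assumptions made for the example.

```python
import time
from collections.abc import Callable, Iterable


def continuous_eval(
    list_checkpoints: Callable[[], Iterable[int]],
    evaluate: Callable[[int], float],
    final_step: int,
    poll_interval_s: float = 0.0,
) -> dict[int, float]:
    """Evaluates each new checkpoint step once, until `final_step` has been evaluated."""
    results: dict[int, float] = {}
    while final_step not in results:
        # Only checkpoints that have not been evaluated yet.
        new_steps = sorted(set(list_checkpoints()) - set(results))
        for step in new_steps:
            results[step] = evaluate(step)
        if final_step not in results:
            time.sleep(poll_interval_s)  # wait for the trainer to save more checkpoints
    return results


# Simulated checkpoint directory: each poll sees the checkpoints saved so far.
listings = iter([[0], [0], [0, 100], [0, 100, 200]])
evaluated: list[int] = []


def fake_evaluate(step: int) -> float:
    evaluated.append(step)
    return step / 100


results = continuous_eval(lambda: next(listings), fake_evaluate, final_step=200)
```

Each checkpoint is evaluated exactly once, even though the same checkpoint shows up in several polls.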
- property state_specs: kauldron.train.train_step.TrainState
Returns the state specs.
- property context_specs: kauldron.train.context.Context
Shape-evaluates the model (fast) and returns the context structure.