kd.train.Trainer

class kauldron.train.Trainer(**kwargs)

Bases: kauldron.utils.config_util.BaseConfig

Base trainer class.

This class is the root object containing all the configuration options (datasets, model, optimizer, etc.).

Usage:

from kauldron import kd

trainer = kd.train.Trainer(
    train_ds=...,
    model=...,
    optimizer=...,
    # ... other fields
)
trainer.train()

seed

Seed for all rngs

Type:

int

workdir

Root dir of the experiment (usually set by XManager)

Type:

etils.epath.abstract_path.Path

train_ds

Dataset used in training

Type:

kauldron.data.pipelines.Pipeline

eval_ds

Dataset used for evaluation (see https://kauldron.rtfd.io/en/latest/eval.html for how to activate evaluation)

Type:

kauldron.data.pipelines.Pipeline | None

model

Flax linen module

Type:

flax.linen.module.Module

rng_streams

Flax RNG streams to use in addition to the defaults (params, dropout, default). If any of params, dropout or default is set here, it overrides the default value.

Type:

kauldron.train.rngs_lib.RngStreams
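The override rule for rng_streams behaves like a dictionary merge: same-named user streams replace the defaults and new names are added. The plain-Python sketch below only illustrates that rule; DEFAULT_STREAMS, merge_rng_streams and the placeholder string values are hypothetical and are not Kauldron's RngStreams API.

```python
# Hypothetical analogy for the rng_streams override rule; not Kauldron code.
DEFAULT_STREAMS = {
    "params": "default-params-config",
    "dropout": "default-dropout-config",
    "default": "default-default-config",
}


def merge_rng_streams(user_streams: dict) -> dict:
    """Returns the defaults plus user streams; user entries win on name clashes."""
    merged = dict(DEFAULT_STREAMS)  # Copy so the defaults are never mutated.
    merged.update(user_streams)  # New names are added, same names overwrite.
    return merged


streams = merge_rng_streams({"dropout": "custom-dropout", "noise": "custom-noise"})
print(sorted(streams))  # ['default', 'dropout', 'noise', 'params']
```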

sharding

Model sharding strategy (replicated sharding by default)

Type:

kauldron.utils.sharding_utils.ShardingStrategy

num_train_steps

Number of training steps. If None, train on the full dataset for the number of epochs specified in train_ds.

Type:

int | None

stop_after_steps

Optionally stop early, after running this many steps. If set, it takes precedence over num_train_steps for deciding when to stop. This allows debugging on Colab without modifying the learning-rate schedules and other values that depend on num_train_steps.

Type:

int | None
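The point of stop_after_steps is that anything derived from num_train_steps (such as a learning-rate schedule) keeps its full horizon even when the run is cut short. Below is a minimal pure-Python sketch of that interaction, assuming a simple linear-decay schedule; linear_decay and run are hypothetical helpers, not Kauldron functions.

```python
from __future__ import annotations


def linear_decay(step: int, num_train_steps: int, init_value: float = 1.0) -> float:
    """Decays linearly from init_value at step 0 to 0.0 at num_train_steps."""
    return init_value * max(0.0, 1.0 - step / num_train_steps)


def run(num_train_steps: int, stop_after_steps: int | None = None) -> list[float]:
    """Returns the learning rate used at each step that is actually executed."""
    # stop_after_steps, when set, only shortens the loop.
    last_step = num_train_steps if stop_after_steps is None else stop_after_steps
    lrs = []
    for step in range(last_step):
        # The schedule always sees the full num_train_steps horizon.
        lrs.append(linear_decay(step, num_train_steps))
    return lrs


full = run(num_train_steps=100)
debug = run(num_train_steps=100, stop_after_steps=5)
print(len(full), len(debug))  # 100 5
print(debug == full[:5])  # True
```

The debug run reproduces exactly the first learning rates of the full run, which is what makes a shortened Colab run representative.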

log_metrics_every

How often to compute and log the metrics (in TensorBoard,…)

Type:

int

log_summaries_every

How often to compute and log the summaries (in TensorBoard,…)

Type:

int
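To illustrate the two intervals, the sketch below assumes the common convention that a step is logged when it is a multiple of the interval; that exact condition (including whether step 0 is logged) is an assumption, and should_log is a hypothetical helper. With the defaults log_metrics_every=100 and log_summaries_every=1000, summaries are written ten times less often than metrics.

```python
# Hypothetical "every N steps" cadence; Kauldron's exact condition may differ.
def should_log(step: int, every: int) -> bool:
    return step % every == 0


log_metrics_every = 100
log_summaries_every = 1000
steps = range(0, 3001)
metric_steps = [s for s in steps if should_log(s, log_metrics_every)]
summary_steps = [s for s in steps if should_log(s, log_summaries_every)]
print(len(metric_steps), summary_steps)  # 31 [0, 1000, 2000, 3000]
```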

train_losses

A dict of losses

Type:

collections.abc.MutableMapping[str, kauldron.losses.base.Loss]

train_metrics

A dict of metrics

Type:

collections.abc.MutableMapping[str, kauldron.metrics.base.Metric]

train_summaries

A dict of summaries (summaries are subclasses of Metric)

Type:

collections.abc.MutableMapping[str, kauldron.metrics.base.Metric]

writer

Metric writer used for writing to TB, datatable, etc.

Type:

kauldron.train.metric_writer.WriterBase

profiler

Profiler used during training; it can be customized (see kd.inspect.Profiler)

Type:

kauldron.inspect.profile_utils.Profiler

checkify_error_categories

Set of error categories for which to enable checkify.

Type:

frozenset[Any]

schedules

optax schedules (to be used in the optimizer)

Type:

collections.abc.MutableMapping[str, collections.abc.Callable[[jax.Array | numpy.ndarray | numpy.bool | numpy.number | float | int], jax.Array | numpy.ndarray | numpy.bool | numpy.number | float | int]]
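Each entry of schedules maps a step to a value. The sketch below builds one such callable in plain Python (linear warmup followed by cosine decay); make_warmup_cosine and the "learning_rate" key are hypothetical names, and in practice optax already provides ready-made schedules such as optax.warmup_cosine_decay_schedule.

```python
import math


def make_warmup_cosine(peak_value: float, warmup_steps: int, decay_steps: int):
    """Returns a step -> value callable: linear warmup, then cosine decay to 0."""

    def schedule(step):
        step = float(step)
        if step < warmup_steps:
            # Linear ramp from 0 to peak_value.
            return peak_value * step / warmup_steps
        # Cosine decay over the remaining steps, clamped after decay_steps.
        progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
        return peak_value * 0.5 * (1.0 + math.cos(math.pi * progress))

    return schedule


schedules = {"learning_rate": make_warmup_cosine(1e-3, warmup_steps=100, decay_steps=1000)}
lr = schedules["learning_rate"]
print(lr(0), lr(100), lr(1000))  # 0.0 0.001 0.0
```

optax optimizers such as optax.adam also accept a step-to-value callable as their learning_rate argument.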

optimizer

optax optimizer

Type:

optax._src.base.GradientTransformation

checkpointer

Checkpointer used to save/restore the state

Type:

kauldron.checkpoints.checkpointer.BaseCheckpointer

init_transform

An initial state transformation. Used for partial checkpoint loading (re-use pre-trained weights).

Type:

kauldron.checkpoints.partial_loader.AbstractPartialLoader
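Conceptually, partial loading keeps the freshly initialized state and overwrites only the parts that should come from a pre-trained checkpoint (e.g. reusing an encoder while training a new head). The sketch below shows that idea on nested dicts; partial_load and the parameter layout are hypothetical and are not the AbstractPartialLoader interface.

```python
# Hypothetical illustration of partial loading on plain nested dicts.
import copy


def partial_load(init_params: dict, pretrained: dict, keys: list) -> dict:
    """Returns a copy of init_params where each top-level key in keys comes from pretrained."""
    new_params = copy.deepcopy(init_params)  # Leave the original init untouched.
    for key in keys:
        new_params[key] = copy.deepcopy(pretrained[key])  # Reuse pre-trained subtree.
    return new_params


init_params = {"encoder": {"w": [0.0, 0.0]}, "head": {"w": [0.0]}}
pretrained = {"encoder": {"w": [1.5, -2.0]}, "head": {"w": [9.9]}}
params = partial_load(init_params, pretrained, keys=["encoder"])
print(params)  # {'encoder': {'w': [1.5, -2.0]}, 'head': {'w': [0.0]}}
```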

trainstep

Training loop step. Do not set this field unless you need a custom training step.

Type:

kauldron.train.train_step.TrainStep

evals

Evaluators to use (e.g. {'eval': kd.evals.Evaluator()})

Type:

collections.abc.MutableMapping[str, kauldron.evals.evaluators.EvaluatorBase]

aux

A dict of arbitrary additional values, e.g. a value can be set once and referenced elsewhere: cfg.model.num_layer = cfg.ref.aux.num_layers.

Type:

Any

setup

Global setup options

Type:

kauldron.train.setup_utils.Setup

xm_job

XManager runtime parameters (e.g. which target the config is using)

Type:

Any

raw_cfg

Original config from which this Config was created. Automatically set during konfig.resolve()

Type:

kauldron.konfig.configdict_base.ConfigDict | None

seed: int = 0
workdir: etils.epath.abstract_path.Path = PosixGPath('.')
train_ds: kauldron.data.pipelines.Pipeline
eval_ds: kauldron.data.pipelines.Pipeline | None = None
model: flax.linen.module.Module
rng_streams: kauldron.train.rngs_lib.RngStreams
sharding: kauldron.utils.sharding_utils.ShardingStrategy
num_train_steps: int | None = None
stop_after_steps: int | None = None
log_metrics_every: int = 100
log_summaries_every: int = 1000
train_losses: collections.abc.MutableMapping[str, kauldron.losses.base.Loss]
train_metrics: collections.abc.MutableMapping[str, kauldron.metrics.base.Metric]
train_summaries: collections.abc.MutableMapping[str, kauldron.metrics.base.Metric]
writer: kauldron.train.metric_writer.WriterBase
profiler: kauldron.inspect.profile_utils.Profiler
checkify_error_categories: frozenset[Any] = frozenset({})
schedules: collections.abc.MutableMapping[str, collections.abc.Callable[[jax.Array | numpy.ndarray | numpy.bool | numpy.number | float | int], jax.Array | numpy.ndarray | numpy.bool | numpy.number | float | int]]
optimizer: optax._src.base.GradientTransformation
checkpointer: kauldron.checkpoints.checkpointer.BaseCheckpointer
init_transform: kauldron.checkpoints.partial_loader.AbstractPartialLoader
trainstep: kauldron.train.train_step.TrainStep
evals: collections.abc.MutableMapping[str, kauldron.evals.evaluators.EvaluatorBase]
aux: Any
setup: kauldron.train.setup_utils.Setup
xm_job: Any
raw_cfg: kauldron.konfig.configdict_base.ConfigDict | None = None

eval_only(**kwargs) -> Self

Returns a Trainer that only performs evaluation.

When called inside a konfig context, this function pre-populates the returned konfig.ConfigDict with the values defined in kauldron/konfig/default_values.py.

Usage:

cfg = kd.train.Trainer.eval_only()

# Should be set either here or in the CLI `--cfg.aux.xid=12345` to indicate
# which Kauldron experiment to load the model from.
cfg.aux.xid = 12345
cfg.aux.wid = 1

# WARNING: Do not overwrite the `cfg.aux` field directly. Only set its
# values like above. E.g. `cfg.aux = {"xid": 12345, 'wid': 1}` won't work.

cfg.evals = {
    ...,
}

This function should NOT be directly called outside a konfig context.

Parameters:

**kwargs – Propagated to the kd.train.Trainer constructor.

init_state(
*,
skip_transforms: bool = False,
skip_optimizer: bool = False,
element_spec: PyTree[enp.ArraySpec] | None = None,
) -> train_step.TrainState

Creates the training state via cfg.trainstep.init(element_spec).

train() -> tuple[kauldron.train.train_step.TrainState, kauldron.train.auxiliaries.AuxiliariesState]

Main method that trains (and evaluates) the model.

Similar to:

state = trainer.init_state()

for batch in trainer.train_ds:
  # Simplified: the real loop also handles logging, checkpointing and evals.
  state = trainer.trainstep.step(state, batch)

Returns:

The final model state and the auxiliaries.

continuous_eval(
names: str | list[str],
) -> dict[str, kauldron.train.auxiliaries.AuxiliariesState]

Main method that performs auxiliary tasks (evaluation, rendering,…).

Triggers an evaluation every time a new checkpoint is detected.

See https://kauldron.rtfd.io/en/latest/eval.html for details.

Parameters:

names – Names of the evaluators to run.

Returns:

A mapping from evaluator name to its auxiliaries state.

Return type:

dict[str, kauldron.train.auxiliaries.AuxiliariesState]
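The checkpoint-driven behavior of continuous_eval can be pictured as a polling loop. The sketch below is a simplified, hypothetical version (continuous_eval_sketch is not a Kauldron function): it simulates successive scans of the checkpoint directory with lists of step numbers and evaluates the latest checkpoint whenever a newer one appears. A real loop would also wait between scans and decide when to stop; both are omitted here.

```python
from __future__ import annotations

from collections.abc import Callable, Iterable


def continuous_eval_sketch(
    polls: Iterable[list[int]],
    evaluate: Callable[[int], float],
) -> dict[int, float]:
    """Evaluates the latest checkpoint each time a newer one is detected."""
    results: dict[int, float] = {}
    last_step = None
    for available_steps in polls:  # Each item simulates one scan of the workdir.
        if not available_steps:
            continue  # No checkpoint written yet.
        latest = max(available_steps)
        if last_step is None or latest > last_step:
            results[latest] = evaluate(latest)
            last_step = latest
    return results


calls = []


def fake_eval(step: int) -> float:
    calls.append(step)  # Record which checkpoints were evaluated.
    return step / 1000


results = continuous_eval_sketch([[1000], [1000], [1000, 2000], [1000, 2000, 3000]], fake_eval)
print(calls)  # [1000, 2000, 3000]
```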

property state_specs: kauldron.train.train_step.TrainState

Returns the state specs.

property context_specs: kauldron.train.context.Context

Evaluates only the shapes of the model (fast) and returns the context structure.