kd.evals.Evaluator

class kauldron.evals.Evaluator(**kwargs)

Bases: kauldron.evals.evaluators.EvaluatorBase

Evaluator that runs num_batches evaluation steps (one per batch) each time it is triggered.

Evaluators can be launched as separate XManager jobs (e.g. run=kd.evals.StandaloneEveryCheckpoint()) or alongside training (e.g. run=kd.evals.EveryNSteps(100)).
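
When the evaluator runs alongside training, a schedule such as kd.evals.EveryNSteps(100) decides at which training steps an evaluation fires. The minimal sketch below illustrates that idea in plain Python. EveryNStepsSketch and its should_eval method are hypothetical stand-ins, not Kauldron's actual schedule classes, and including step 0 is an assumption.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class EveryNStepsSketch:
    """Hypothetical stand-in for an every-N-steps evaluation schedule."""

    every: int

    def should_eval(self, step: int) -> bool:
        # Fire on steps 0, N, 2N, ... (assumption: step 0 is included).
        return step % self.every == 0


schedule = EveryNStepsSketch(every=100)
eval_steps = [s for s in range(0, 301) if schedule.should_eval(s)]
print(eval_steps)  # [0, 100, 200, 300]
```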

If not provided, losses, metrics, and summaries are reused from the training config.
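
Unset fields default to placeholders such as _FakeRootCfg('cfg.train_losses'), which point into the training config. The sketch below illustrates that fallback with a hypothetical RootRef placeholder resolved by attribute lookup on a toy config. RootRef, resolve_fields and the SimpleNamespace config are illustrative assumptions, not Kauldron internals.

```python
import dataclasses
from types import SimpleNamespace
from typing import Any


@dataclasses.dataclass(frozen=True)
class RootRef:
    """Hypothetical placeholder pointing at a path like 'cfg.train_losses'."""

    path: str

    def resolve(self, cfg: Any) -> Any:
        # Drop the leading 'cfg' and follow the remaining attribute names.
        obj = cfg
        for name in self.path.split(".")[1:]:
            obj = getattr(obj, name)
        return obj


def resolve_fields(fields: dict[str, Any], cfg: Any) -> dict[str, Any]:
    # Explicit values are kept; placeholders are replaced by train-config values.
    return {
        k: v.resolve(cfg) if isinstance(v, RootRef) else v
        for k, v in fields.items()
    }


train_cfg = SimpleNamespace(
    train_losses={"xent": "cross_entropy"},
    train_metrics={"acc": "accuracy"},
)
eval_fields = {
    "losses": RootRef("cfg.train_losses"),  # Not provided: reuse train losses.
    "metrics": {"top5": "top5_accuracy"},   # Explicitly overridden.
}
resolved = resolve_fields(eval_fields, train_cfg)
print(resolved)
# {'losses': {'xent': 'cross_entropy'}, 'metrics': {'top5': 'top5_accuracy'}}
```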

Usage:

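evaluator = kd.evals.Evaluator(
    run=kd.evals.EveryNSteps(100),  # Evaluate every 100 training steps.
    ds=test_ds,  # Dataset to evaluate on.
    base_cfg=cfg,  # Root config that unset fields are read from.
)
evaluator.maybe_eval(step=0, state=state)  # Evaluates only if `run` triggers at this step.
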
num_batches

How many batches to run evaluation on. Use None to evaluate on the full test dataset. Note that each evaluation reinitializes the dataset iterator, so setting it to 1 runs every evaluation on the same batch.

Type:

int | None

cache

Whether to cache the dataset iterator.

Type:

bool

ds

Dataset to evaluate on.

Type:

kauldron.data.pipelines.Pipeline

losses

Losses to compute during evaluation.

Type:

dict[str, kauldron.losses.base.Loss]

metrics

Metrics to compute during evaluation.

Type:

dict[str, kauldron.metrics.base.Metric]

summaries

Summaries to compute during evaluation.

Type:

dict[str, kauldron.metrics.base.Metric]

model

Model to use for evaluation (if different from train).

Type:

flax.linen.module.Module

model_method

Name of the Flax model method to call (defaults to __call__).

Type:

str | None

init_transform

Transform to apply to the state before evaluation, for example to replace the network weights with EMA weights.

Type:

kauldron.checkpoints.partial_loader.AbstractPartialLoader
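
The EMA use case amounts to a state-to-state transformation applied before evaluation. The real init_transform is a kauldron.checkpoints.partial_loader.AbstractPartialLoader; the sketch below is a hypothetical simplification in which ToyState and use_ema_weights stand in for the train state and the transform, and it only shows that the evaluated state swaps in the EMA weights while the training state is left untouched.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical, simplified stand-in for a training state."""

    step: int
    params: dict[str, float]
    ema_params: dict[str, float]


def use_ema_weights(state: ToyState) -> ToyState:
    # Return a copy whose params are the EMA weights; the input state is unchanged.
    return dataclasses.replace(state, params=dict(state.ema_params))


train_state = ToyState(step=1000, params={"w": 0.9}, ema_params={"w": 0.75})
eval_state = use_ema_weights(train_state)
print(eval_state.params, train_state.params)  # {'w': 0.75} {'w': 0.9}
```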

num_batches: int | None = None
cache: bool = False
ds: data.Pipeline = _FakeRootCfg('cfg.eval_ds')
losses: dict[str, losses_lib.Loss] = _FakeRootCfg('cfg.train_losses')
metrics: dict[str, metrics_lib.Metric] = _FakeRootCfg('cfg.train_metrics')
summaries: dict[str, metrics_lib.Metric] = _FakeRootCfg('cfg.train_summaries')
model: nn.Module = _FakeRootCfg('cfg.model')
model_method: str | None = None
init_transform: checkpoints.AbstractPartialLoader
checkify_error_categories: frozenset[CheckifyErrorCategory] = _FakeRootCfg('cfg.checkify_error_categories')
property ds_iter: kauldron.data.data_utils.IterableDataset
property aux: kauldron.train.auxiliaries.Auxiliaries
evaluate(
state: kauldron.train.train_step.TrainState,
step: int,
) → kauldron.train.auxiliaries.AuxiliariesState

Run one full evaluation.
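
Conceptually, one full evaluation calls step once per batch and merges the per-batch results into a single AuxiliariesState. The sketch below shows such a loop with plain dicts and a simple mean over batches; evaluate_sketch and batch_mean are hypothetical, and Kauldron's actual merging of AuxiliariesState is not reproduced here.

```python
from typing import Callable, Iterable


def evaluate_sketch(
    batches: Iterable[list[float]],
    step_fn: Callable[[list[float]], dict[str, float]],
) -> dict[str, float]:
    """Hypothetical evaluation loop: average per-batch metrics over all batches."""
    totals: dict[str, float] = {}
    num_batches = 0
    for batch in batches:
        # One step per batch, analogous to Evaluator.step.
        for name, value in step_fn(batch).items():
            totals[name] = totals.get(name, 0.0) + value
        num_batches += 1
    return {name: total / num_batches for name, total in totals.items()}


def batch_mean(batch: list[float]) -> dict[str, float]:
    return {"mean": sum(batch) / len(batch)}


print(evaluate_sketch([[1.0, 3.0], [5.0, 7.0]], batch_mean))  # {'mean': 4.0}
```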

step(
*,
step_nr: int,
state: kauldron.train.train_step.TrainState,
batch: Any,
) → kauldron.train.auxiliaries.AuxiliariesState
property model_with_aux: kauldron.train.train_step.ModelWithAux

Deprecated. Use a forward function directly instead.

See e.g. kd.train.train_step.forward_with_loss.