kd.evals.FewShotEvaluator


class kauldron.evals.FewShotEvaluator(**kwargs)

Bases: kauldron.evals.evaluators.EvaluatorBase

Evaluator that runs closed-form few-shot classification.

Computes features from the model, then solves a closed-form L2-regularized linear regression for few-shot classification. This is fairly fast, so it can be run regularly during training.

This implementation follows (and largely copies) google-research/big_vision.
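The closed-form step can be illustrated with a minimal NumPy sketch. It is not the evaluator's actual code: standardizing the features, appending a constant bias column, encoding labels as one-hot targets rescaled to {-1, +1}, and the helper names `fit_ridge` and `ridge_accuracy` are all assumptions chosen for illustration. The weights are the ridge solution W = (XᵀX + λI)⁻¹XᵀY, and the predicted class is the argmax of XW.

```python
import numpy as np


def fit_ridge(x: np.ndarray, y: np.ndarray, num_classes: int, l2_reg: float):
  """Closed-form L2-regularized least squares (hypothetical helper).

  Args:
    x: (N, D) training features.
    y: (N,) integer labels in [0, num_classes).
    num_classes: number of classes C.
    l2_reg: L2 regularization strength (lambda).

  Returns:
    (mean, std, w): feature statistics and the (D + 1, C) weight matrix.
  """
  # Assumption: standardize features with statistics of the training subset.
  mean = x.mean(axis=0, keepdims=True)
  std = x.std(axis=0, keepdims=True) + 1e-5
  xs = (x - mean) / std
  # Assumption: append a constant column so the model has a bias term.
  xs = np.concatenate([xs, np.ones((xs.shape[0], 1))], axis=1)
  # One-hot targets rescaled from {0, 1} to {-1, +1}.
  targets = 2.0 * np.eye(num_classes)[y] - 1.0
  d = xs.shape[1]
  # Solve (X^T X + lambda I) W = X^T Y instead of forming an explicit inverse.
  w = np.linalg.solve(xs.T @ xs + l2_reg * np.eye(d), xs.T @ targets)
  return mean, std, w


def ridge_accuracy(params, x: np.ndarray, y: np.ndarray) -> float:
  """Accuracy of argmax predictions from a fitted ridge classifier."""
  mean, std, w = params
  xs = (x - mean) / std
  xs = np.concatenate([xs, np.ones((xs.shape[0], 1))], axis=1)
  preds = np.argmax(xs @ w, axis=1)
  return float(np.mean(preds == y))
```

In the few-shot setting, `x` would hold the features of the sampled training subset, and the same fitted parameters would then be scored on validation and test features.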

ds_train

Dataset to train the few-shot classifiers on.

Type:

data.Pipeline

ds_val

Dataset to validate few-shot classification on, used to select the L2 regularization strength.

Type:

data.Pipeline

ds_test

Dataset to test few-shot classification on.

Type:

data.Pipeline

metric_prefix

String prefix used for the metrics reported by this evaluator.

Type:

str

num_classes

Number of classes in the classification task.

Type:

int

num_shots

Sequence of shot counts (training examples per class) to evaluate.

Type:

Sequence[int]

repr_names

Mapping of the representations to evaluate. Keys are the names used to refer to each representation; values are the paths in the context from which the actual features are read.

Type:

Mapping[str, str]

l2_regs

Possible values for L2 regularization.

Type:

Sequence[float]

label_name

Key by which to get the labels from the context.

Type:

str

selected_repr

Key from repr_names whose accuracies are put into the main metrics.

Type:

str

seed

Random seed (or seeds) for selecting the training data subset.

Type:

int | Sequence[int]
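The `num_shots` and `seed` attributes suggest drawing a fixed number of training examples per class with a seeded random generator. The NumPy sketch below shows one way to do that; the helper name `sample_few_shot_indices` and the choice of sampling without replacement within each class are assumptions, not Kauldron's implementation.

```python
import numpy as np


def sample_few_shot_indices(
    labels: np.ndarray, num_shots: int, num_classes: int, seed: int
) -> np.ndarray:
  """Returns indices of `num_shots` random examples per class (hypothetical).

  Raises:
    ValueError: if some class has fewer than `num_shots` examples.
  """
  rng = np.random.default_rng(seed)
  selected = []
  for c in range(num_classes):
    class_idx = np.flatnonzero(labels == c)
    if len(class_idx) < num_shots:
      raise ValueError(f"class {c} has only {len(class_idx)} examples")
    # Sample without replacement so no example is used twice within a class.
    selected.append(rng.choice(class_idx, size=num_shots, replace=False))
  return np.concatenate(selected)
```

Calling such a helper once per value in `num_shots` and per seed would yield the training subsets on which the closed-form classifier is fit; the same seed always reproduces the same subset.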

Usage example:

    "fewshot_i1k": kd.evals.FewShotEvaluator(
        run=kd.evals.EveryNSteps(10_000),
        metric_prefix="i1k",
        ds_train=_make_i1k_fewshot(split="train[:-10000]", batch_size=4096),
        ds_val=_make_i1k_fewshot(split="train[-10000:]", batch_size=4096),
        ds_test=_make_i1k_fewshot(split="validation", batch_size=4096),
        num_classes=1000,
        num_shots=(1, 2, 5, 10),
        repr_names={"pre_logits": "interms.pre_logits.__call__[0]"},
        label_name="batch.label",
    )

ds_train: data.Pipeline
ds_val: data.Pipeline
ds_test: data.Pipeline
metric_prefix: str
num_classes: int
num_shots: Sequence[int]
repr_names: Mapping[str, str]
l2_regs: Sequence[float] = (64, 128, 256, 512, 1024)
label_name: str
selected_repr: str = 'pre_logits'
seed: int | Sequence[int] = _FakeRootCfg('cfg.seed')
property seeds: list[int]
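The `seeds` property exposes `seed` as a list, so a single integer and a sequence of integers can be handled uniformly. A plausible normalization is sketched below; the helper name `normalize_seeds` is hypothetical, and the actual property may behave differently (for example, in how it resolves the `cfg.seed` default).

```python
from __future__ import annotations

from collections.abc import Sequence


def normalize_seeds(seed: int | Sequence[int]) -> list[int]:
  """Turns a single seed or a sequence of seeds into a list of ints."""
  if isinstance(seed, int):
    return [seed]
  return [int(s) for s in seed]
```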
evaluate(
state: kauldron.train.train_step.TrainState,
step: int,
)

Run one full evaluation.
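To illustrate how `ds_val` and `l2_regs` could fit into one evaluation run, the sketch below picks the L2 strength with the best mean validation accuracy across seeds and then reports the mean test accuracy at that strength. The `accuracy_fn(seed, l2_reg, split)` callback (fit on the seeded few-shot subset, score on the given split) is a hypothetical interface, and averaging validation accuracy over seeds before selecting, rather than selecting per seed, is an assumption.

```python
from collections.abc import Callable, Sequence

import numpy as np

# Hypothetical interface: accuracy_fn(seed, l2_reg, split) -> accuracy, where
# split is "val" or "test".
AccuracyFn = Callable[[int, float, str], float]


def select_l2_and_test(
    accuracy_fn: AccuracyFn,
    seeds: Sequence[int],
    l2_regs: Sequence[float],
) -> tuple[float, float]:
  """Chooses L2 by mean validation accuracy, then reports mean test accuracy.

  Returns:
    (best_l2, mean test accuracy over seeds at best_l2).
  """
  val_accs = [
      np.mean([accuracy_fn(s, l2, "val") for s in seeds]) for l2 in l2_regs
  ]
  best_l2 = l2_regs[int(np.argmax(val_accs))]
  test_acc = float(np.mean([accuracy_fn(s, best_l2, "test") for s in seeds]))
  return best_l2, test_acc
```

Keeping the test split out of the selection means the reported test accuracy is not tuned on the data it is measured on.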

compute_features(
state: kauldron.train.train_step.TrainState,
ds: kauldron.data.data_utils.IterableDataset,
split: str,
) -> tuple[dict[str, jaxtyping.Shaped[Array, '...'] | jaxtyping.Shaped[ndarray, '...']], jaxtyping.Shaped[Array, '...'] | jaxtyping.Shaped[ndarray, '...']]
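`compute_features` returns a dictionary of features keyed by representation name, together with the labels. The sketch below shows the general idea of accumulating them over a dataset, assuming each batch's context is available as a plain nested dict. `get_by_path` and `collect_features` are hypothetical helpers: `get_by_path` only understands dot-separated dict keys, not Kauldron's full path syntax (such as `__call__[0]` in the usage example), flattening trailing feature dimensions is an assumption, and the real method additionally takes the `TrainState` and a `split` name.

```python
from collections.abc import Iterable, Mapping
from typing import Any

import numpy as np


def get_by_path(tree: Mapping[str, Any], path: str) -> Any:
  """Looks up a dot-separated path such as "batch.label" in nested dicts."""
  node: Any = tree
  for key in path.split("."):
    node = node[key]
  return node


def collect_features(
    contexts: Iterable[Mapping[str, Any]],
    repr_names: Mapping[str, str],
    label_name: str,
) -> tuple[dict[str, np.ndarray], np.ndarray]:
  """Concatenates per-batch features and labels over a whole dataset."""
  feats: dict[str, list[np.ndarray]] = {name: [] for name in repr_names}
  labels: list[np.ndarray] = []
  for ctx in contexts:
    for name, path in repr_names.items():
      x = np.asarray(get_by_path(ctx, path))
      # Flatten trailing dimensions into one feature vector per example.
      feats[name].append(x.reshape(x.shape[0], -1))
    # Labels may come as (B,) or (B, 1); flatten to (B,).
    labels.append(np.asarray(get_by_path(ctx, label_name)).reshape(-1))
  return (
      {name: np.concatenate(xs, axis=0) for name, xs in feats.items()},
      np.concatenate(labels, axis=0),
  )
```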
property aux: kauldron.train.auxiliaries.Auxiliaries