kd.evals.FewShotEvaluator
- class kauldron.evals.FewShotEvaluator(**kwargs)
Bases: kauldron.evals.evaluators.EvaluatorBase
FewShotEvaluator running closed-form few-shot classification.
Computes features from the model, then solves a closed-form L2-regularized linear regression for few-shot classification. This is fairly fast, so it can be run regularly during training.
Follows (and largely copies) google-research/big_vision.
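To illustrate the closed-form step, here is a minimal NumPy sketch of an L2-regularized least-squares classifier. It assumes features `x` of shape `(n, d)`, integer labels, and one-hot targets rescaled to {-1, +1}. The helpers `fit_ridge_classifier` and `predict` are hypothetical and not part of the Kauldron API; the big_vision code this evaluator follows may differ in details such as feature normalization and bias handling.

```python
import numpy as np


def fit_ridge_classifier(x, labels, num_classes, l2_reg):
    """Closed-form L2-regularized least squares on {-1, +1} one-hot targets.

    x: features of shape (n, d); labels: integer class ids of shape (n,).
    Returns weights of shape (d, num_classes).
    """
    # One-hot targets rescaled from {0, 1} to {-1, +1}.
    targets = 2.0 * np.eye(num_classes)[labels] - 1.0
    # Normal equations: (X^T X + l2_reg * I) W = X^T Y.
    gram = x.T @ x + l2_reg * np.eye(x.shape[1])
    return np.linalg.solve(gram, x.T @ targets)


def predict(x, w):
    """Returns the class with the highest regression score for each row."""
    return np.argmax(x @ w, axis=1)


# Toy check: 3 well-separated classes in 8 dimensions, 5 shots per class.
rng = np.random.default_rng(0)
prototypes = 5.0 * rng.normal(size=(3, 8))
train_labels = np.repeat(np.arange(3), 5)
x_train = prototypes[train_labels] + 0.1 * rng.normal(size=(15, 8))
test_labels = np.repeat(np.arange(3), 10)
x_test = prototypes[test_labels] + 0.1 * rng.normal(size=(30, 8))

w = fit_ridge_classifier(x_train, train_labels, num_classes=3, l2_reg=1.0)
accuracy = np.mean(predict(x_test, w) == test_labels)
```

When there are fewer training examples than feature dimensions, the equivalent dual form X^T (X X^T + l2_reg * I)^{-1} Y needs only an n×n solve. Whether this evaluator switches between the two forms is not stated here.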
- ds_train
Dataset on which to train the few-shot classifier.
- Type:
data.Pipeline
- ds_val
Validation dataset, used to select the L2 regularization strength.
- Type:
data.Pipeline
- ds_test
Dataset on which to test the few-shot classifier.
- Type:
data.Pipeline
- metric_prefix
Prefix for the names of the metrics reported by this evaluator.
- Type:
str
- num_classes
Number of classes in the classification task.
- Type:
int
- num_shots
Numbers of shots (training examples per class) to evaluate.
- Type:
Sequence[int]
- repr_names
Representations to evaluate. Keys are names used to refer to the representations; values are paths in the context from which to take the actual features.
- Type:
Mapping[str, str]
- l2_regs
Candidate L2 regularization strengths; the best one is selected on ds_val.
- Type:
Sequence[float]
- label_name
Key by which to get the labels from the context.
- Type:
str
- selected_repr
Key from repr_names whose accuracies are added to the main metrics.
- Type:
str
- seed
Random seed(s) for selecting the training data subset.
- Type:
int | Sequence[int]
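The `num_shots` and `seed` attributes determine which training examples the classifier sees. The sketch below shows one plausible way to draw a k-shot subset (k examples per class, without replacement) once per seed, and to normalize an `int | Sequence[int]` seed into a list in the spirit of the `seeds` property. Both `normalize_seeds` and `sample_few_shot_indices` are hypothetical helpers; the actual Kauldron sampling procedure may differ.

```python
from __future__ import annotations

from collections.abc import Sequence

import numpy as np


def normalize_seeds(seed: int | Sequence[int]) -> list[int]:
    """Turns a single seed or a sequence of seeds into a list of seeds."""
    if isinstance(seed, int):
        return [seed]
    return list(seed)


def sample_few_shot_indices(labels: np.ndarray, num_shots: int, seed: int) -> np.ndarray:
    """Returns sorted indices of `num_shots` examples per class (no replacement)."""
    rng = np.random.default_rng(seed)
    chosen = []
    for cls in np.unique(labels):
        class_indices = np.flatnonzero(labels == cls)
        chosen.append(rng.choice(class_indices, size=num_shots, replace=False))
    return np.sort(np.concatenate(chosen))


# Toy labels: 3 classes with 3, 3 and 4 examples; draw a 2-shot subset per seed.
labels = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 2])
subsets = {
    s: sample_few_shot_indices(labels, num_shots=2, seed=s)
    for s in normalize_seeds((0, 1))
}
```

Because sampling is without replacement, `rng.choice` raises a `ValueError` if a class has fewer than `num_shots` examples, so every class in the training split needs at least max(num_shots) examples under this sketch.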
- Usage example:
```python
"fewshot_i1k": kd.evals.FewShotEvaluator(
    run=kd.evals.EveryNSteps(10_000),
    metric_prefix="i1k",
    ds_train=_make_i1k_fewshot(split="train[:-10000]", batch_size=4096),
    ds_val=_make_i1k_fewshot(split="train[-10000:]", batch_size=4096),
    ds_test=_make_i1k_fewshot(split="validation", batch_size=4096),
    num_classes=1000,
    num_shots=(1, 2, 5, 10),
    repr_names={"pre_logits": "interms.pre_logits.__call__[0]"},
    label_name="batch.label",
)
```
- ds_train: data.Pipeline
- ds_val: data.Pipeline
- ds_test: data.Pipeline
- metric_prefix: str
- num_classes: int
- num_shots: Sequence[int]
- repr_names: Mapping[str, str]
- l2_regs: Sequence[float] = (64, 128, 256, 512, 1024)
- label_name: str
- selected_repr: str = 'pre_logits'
- seed: int | Sequence[int] = _FakeRootCfg('cfg.seed')
- property seeds: list[int]
- evaluate(state: kauldron.train.train_step.TrainState, step: int)
Run one full evaluation.
- compute_features(state: kauldron.train.train_step.TrainState, ds: kauldron.data.data_utils.IterableDataset, split: str)
- property aux: kauldron.train.auxiliaries.Auxiliaries
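Within a full evaluation, `ds_val` serves only to pick the regularization strength from `l2_regs`. As a hedged sketch of that selection step, the function below fits one closed-form ridge classifier per candidate value on the few-shot training features, keeps the value with the best validation accuracy, and reports test accuracy at that value. The names `select_l2_and_evaluate` and `_ridge_accuracy`, their arguments, and the use of validation accuracy as the selection criterion are assumptions for illustration, not the actual body of `evaluate`.

```python
from collections.abc import Sequence

import numpy as np


def _ridge_accuracy(x_tr, y_tr, x_ev, y_ev, num_classes, l2_reg):
    """Fits closed-form ridge on (x_tr, y_tr) and returns accuracy on (x_ev, y_ev)."""
    targets = 2.0 * np.eye(num_classes)[y_tr] - 1.0  # one-hot rescaled to {-1, +1}
    gram = x_tr.T @ x_tr + l2_reg * np.eye(x_tr.shape[1])
    w = np.linalg.solve(gram, x_tr.T @ targets)
    return float(np.mean(np.argmax(x_ev @ w, axis=1) == y_ev))


def select_l2_and_evaluate(train, val, test, num_classes, l2_regs: Sequence[float]):
    """Picks the L2 strength with the best validation accuracy.

    train, val and test are (features, labels) pairs. Returns
    (best_l2, test accuracy at best_l2); ties go to the first candidate.
    """
    val_accs = [_ridge_accuracy(*train, *val, num_classes, l2) for l2 in l2_regs]
    best_l2 = l2_regs[int(np.argmax(val_accs))]
    return best_l2, _ridge_accuracy(*train, *test, num_classes, best_l2)


# Toy data: 3 classes in 16 dimensions; 5 shots per class for training.
rng = np.random.default_rng(0)
prototypes = 5.0 * rng.normal(size=(3, 16))


def make_split(per_class):
    labels = np.repeat(np.arange(3), per_class)
    feats = prototypes[labels] + 0.1 * rng.normal(size=(labels.size, 16))
    return feats, labels


best_l2, test_acc = select_l2_and_evaluate(
    make_split(5), make_split(20), make_split(20), num_classes=3,
    l2_regs=(64, 128, 256, 512, 1024),  # the documented default l2_regs
)
```

Selecting on a held-out split rather than on the test set keeps the reported test accuracy free of tuning bias, which matches the documented role of ds_val.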