kd

import kauldron as kd

Kauldron public API.

from kauldron import kd

All symbols

Module

kd

Kauldron public API.

kd._filter_logs

Helper to filter logs from verbose modules.

kd.ckpts

Checkpoints API.

kd.ckpts.items

Checkpoint handler.

kd.contrib

Contrib public API.

kd.data

Data modules.

kd.data.iters

Dataset iterators.

kd.data.py

PyGrain public API.

kd.evals

Evaluator.

kd.from_xid

Helper for loading configs,… from XManager experiments.

kd.inspect

Inspect utils.

kd.kdash

Small library for creating Flatboard dashboards.

kd.knn

Wrapper around flax.linen.Module to add torch-like API.

kd.konfig

Wrapper around ConfigDict to support auto-complete/type checking.

kd.konfig._default_values

Default values and configuration.

kd.kontext

Kontext is a small self-contained library to manipulate nested trees.

kd.losses

Losses.

kd.metrics

Metrics.

kd.nn

Collection of nn.Modules to build neural networks.

kd.optim

Optimizers etc.

kd.random

Small wrapper around jax.random for OO API.

kd.summaries

Summaries.

kd.summaries.deprecated

Deprecated summaries.

kd.testing

Testing utilities.

kd.train

Train.

kd.typing

Common Typing Annotations.

kd.utils

Utils public API.

kd.xm

XManager utils.

Class

kd.ckpts.AbstractPartialLoader

Abstract class for partial checkpoint loaders.

kd.ckpts.Checkpointer

Wrapper around Orbax CheckpointManager.

kd.ckpts.MultiTransform

Transform which applies multiple transformations sequentially.

kd.ckpts.NoopCheckpointer

Does nothing.

kd.ckpts.NoopTransform

init_transform that does nothing.

kd.ckpts.PartialKauldronLoader

Partial loader for Kauldron checkpoints.

kd.ckpts.items.CheckpointItem

Interface for a checkpoint item.

kd.ckpts.items.StandardCheckpointItem

Standard checkpoint item (for arbitrary jax.Array pytree).

kd.ckpts.items.TopLevelCheckpointItem

Checkpoint item that contains other sub-checkpoint items.

kd.data.AddConstants

Adds constant elements.

kd.data.BatchSize

Batch size.

kd.data.Cast

Cast an element to the specified dtype.

kd.data.CenterCrop

Crop the input data to the specified shape from the center.

kd.data.ElementWiseTransform

Base class for elementwise transforms.

kd.data.Elements

Modify the elements by keeping xor dropping and/or renaming and/or copying.

kd.data.FilterTransform

Abstract base class for filter transformations for individual elements.

kd.data.Gather

Gathers entries along a single dimension.

kd.data.InMemoryPipeline

Pipeline which fits in memory.

kd.data.IterableDataset

General interface for iterable datasets.

kd.data.MapTransform

Abstract base class for all 1:1 transformations of elements.

kd.data.Pipeline

Base class for kauldron data pipelines.

kd.data.Rearrange

Einops rearrange on a single element.

kd.data.Resize

Resizes an image.

kd.data.TreeFlattenWithPath

Flatten any tree-structured elements.

kd.data.ValueRange

Map the value range of an element.
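As an illustration of what "mapping the value range" means, here is a minimal sketch assuming a plain linear rescaling from an input range to an output range (e.g. uint8 pixels in [0, 255] to [-1, 1]). The helper name map_value_range and its clip flag are hypothetical and not part of the kd.data API.

```python
import numpy as np


def map_value_range(x, in_vrange, out_vrange, clip=False):
    """Linearly maps values from `in_vrange` to `out_vrange` (hypothetical helper)."""
    in_min, in_max = in_vrange
    out_min, out_max = out_vrange
    x = np.asarray(x, dtype=np.float32)
    # Normalize to [0, 1] relative to the input range, then rescale to the output range.
    scaled = (x - in_min) / (in_max - in_min)
    out = scaled * (out_max - out_min) + out_min
    if clip:
        # Optionally clamp values that fell outside the declared input range.
        out = np.clip(out, out_min, out_max)
    return out


image = np.array([0, 127.5, 255])
mapped = map_value_range(image, (0, 255), (-1, 1))  # values -1.0, 0.0, 1.0
```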

kd.data.iters.Iterator

Wrapper around a dataset iterator.

kd.data.iters.NonCheckpointableIterator

Handler that is not checkpointable.

kd.data.iters.PyGrainIterator

PyGrain iterator.

kd.data.iters.TFDataIterator

Checkpointable tf.data iterator.

kd.data.py.AddConstants

Adds constant elements.

kd.data.py.Cast

Cast an element to the specified dtype.

kd.data.py.CenterCrop

Crop the input data to the specified shape from the center.

kd.data.py.DataSource

Generic loader of arbitrary grain data source.

kd.data.py.DataSourceBase

Base class to implement a data source.

kd.data.py.ElementWiseRandomTransform

Base class for elementwise transforms.

kd.data.py.ElementWiseTransform

Base class for elementwise transforms.

kd.data.py.Elements

Modify the elements by keeping xor dropping and/or renaming and/or copying.

kd.data.py.Gather

Gathers entries along a single dimension.

kd.data.py.HuggingFace

HuggingFace loader.

kd.data.py.Json

Json pipeline.

kd.data.py.Mix

Create a dataset mixture from given weights.
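The idea behind a weighted dataset mixture can be sketched as follows: at each step, pick a source dataset with probability proportional to its weight and yield its next element. This is a minimal sketch assuming infinite (e.g. repeated) iterable datasets and Python's random.choices for sampling; the mix generator here is hypothetical and is not how kd.data.py.Mix is implemented.

```python
import itertools
import random


def mix(datasets, weights, seed=0):
    """Yields elements, choosing the source with probability proportional to its weight.

    Assumes every dataset is infinite; if one is exhausted, `next` raises
    StopIteration inside the generator, which Python turns into RuntimeError.
    """
    rng = random.Random(seed)
    iterators = [iter(ds) for ds in datasets]
    indices = range(len(iterators))
    while True:
        (i,) = rng.choices(indices, weights=weights)  # k=1 by default
        yield next(iterators[i])


ds_a = itertools.repeat("a")
ds_b = itertools.repeat("b")
samples = list(itertools.islice(mix([ds_a, ds_b], weights=[0.75, 0.25]), 1000))
# Roughly 750 of the 1000 samples come from ds_a.
```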

kd.data.py.PyGrainPipeline

Abstract base class to construct PyGrain data pipeline.

kd.data.py.RandomCrop

Randomly crop the input data to the specified shape.

kd.data.py.RandomFlipLeftRight

Flips an image horizontally with probability 50%.

kd.data.py.Rearrange

Einops rearrange on a single element.

kd.data.py.Resize

Resizes an image.

kd.data.py.SelectFromDatasets

Create a dataset mixture using a selection map.

kd.data.py.SliceDataset

Transform which selects a subset of the dataset.

kd.data.py.Tfds

Base TFDS loader.

kd.data.py.TreeFlattenWithPath

Flatten any tree-structured elements.

kd.data.py.ValueRange

Map the value range of an element.

kd.evals.CollectionKeys

Names of the metrics/summaries/losses (displayed in flatboard).

kd.evals.Evaluator

Evaluator running num_batches times.

kd.evals.EvaluatorBase

Base class for inline evaluators.

kd.evals.EveryNSteps

Run eval every N train steps.

kd.evals.FewShotEvaluator

FewShotEvaluator running closed-form few-shot classification.

kd.evals.Once

Run eval only after the XX train steps.

kd.evals.RunStrategy

Base class for info on how to run the evaluation.

kd.evals.StandaloneEveryCheckpoint

Run eval continuously every time a new checkpoint is found.

kd.evals.StandaloneLastCheckpoint

Run eval only after the last checkpoint, after train has completed.

kd.inspect.Profiler

kd.inspect.Profiler.

kd.kdash.BuildContext

Context for building the dashboard.

kd.kdash.DashboardsBase

Flatboard dashboard structure.

kd.kdash.MetricDashboards

Standard metrics & losses dashboards for a single collection.

kd.kdash.MultiDashboards

Container of multiple dashboards.

kd.kdash.NoopDashboard

Empty dashboard.

kd.kdash.Plot

Single plot inside a dashboard.

kd.kdash.SingleDashboard

Single dashboard containing multiple plots.

kd.knn.Dense

Dense(features: int, use_bias: bool = True, dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any, NoneType] = None, param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any] = <class 'jax.numpy.float32'>, precision: Union[NoneType, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]] = None, kernel_init: Union[jax.nn.initializers.Initializer, collections.abc.Callable[..., Any]] = <function variance_scaling.<locals>.init>, bias_init: Union[jax.nn.initializers.Initializer, collections.abc.Callable[..., Any]] = <function zeros>, promote_dtype: flax.linen.linear.PromoteDtypeFn = <function promote_dtype>, dot_general: collections.abc.Callable[..., typing.Union[jax.Array, typing.Any]]

kd.knn.Dropout

Dropout(rate: float, broadcast_dims: collections.abc.Sequence[int] = (), deterministic: bool

kd.knn.Module

Base Module class.

kd.knn.Sequential

Sequential(layers: collections.abc.Sequence[collections.abc.Callable[..., typing.Any]], parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None, *, _kd_state: 'Optional[_ModuleState]' = None)

kd.konfig.ConfigDict

Wrapper around ConfigDict.

kd.konfig.WithRef

Protocol to better access lazy fields.

kd.kontext.Context

kd.kontext.GlobPath

Represents a string path.

kd.kontext.Path

Represents a (non-glob) string path.

kd.losses.AbsoluteValue

Absolute value loss.

kd.losses.Huber

Huber loss.
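For reference, the standard Huber loss is quadratic for small errors and linear beyond a threshold delta. A minimal NumPy sketch of that standard definition (the same formulation used by optax.huber_loss) is below; whether kd.losses.Huber uses exactly this parameterization and default delta is an assumption, and the huber function is hypothetical.

```python
import numpy as np


def huber(pred, target, delta=1.0):
    """Elementwise Huber loss: 0.5 * e**2 if |e| <= delta, else delta * (|e| - 0.5 * delta)."""
    err = np.abs(np.asarray(pred, dtype=float) - np.asarray(target, dtype=float))
    quadratic = np.minimum(err, delta)  # part of the error inside the quadratic zone
    linear = err - quadratic            # remainder, penalized linearly
    return 0.5 * quadratic**2 + delta * linear


losses = huber([0.5, 3.0], [0.0, 0.0])  # values 0.125 and 2.5
```

Splitting the error into a clipped quadratic part and a linear remainder avoids an explicit branch, which also keeps the formulation friendly to vectorized (and JAX-traced) code.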

kd.losses.L1

L1 loss.

kd.losses.L2

L2 loss.

kd.losses.Loss

Base class for losses which handles masks, averaging, and loss-weight.
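The three responsibilities named above (masks, averaging, loss-weight) can be sketched as: zero out masked entries, average over the remaining ones, then scale by the loss weight. This is a minimal sketch under the assumption that averaging divides by the number of unmasked entries; the function masked_weighted_loss is hypothetical and not the kd.losses.Loss API.

```python
import numpy as np


def masked_weighted_loss(per_element_loss, mask=None, weight=1.0):
    """Averages `per_element_loss` over unmasked entries and scales by `weight`."""
    loss = np.asarray(per_element_loss, dtype=float)
    if mask is None:
        mask = np.ones_like(loss)
    mask = np.asarray(mask, dtype=float)
    # Average only over unmasked entries; max(..., 1.0) guards against an all-zero mask.
    mean = (loss * mask).sum() / max(mask.sum(), 1.0)
    return weight * mean


# The masked-out outlier (100.0) does not affect the result: 0.5 * (1 + 2 + 3) / 3 == 1.0
total = masked_weighted_loss([1.0, 2.0, 3.0, 100.0], mask=[1, 1, 1, 0], weight=0.5)
```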

kd.losses.NegativeCosineSimilarity

Negative Cosine Similarity loss.

kd.losses.SigmoidBinaryCrossEntropy

Sigmoid cross-entropy loss with binary labels.

kd.losses.SoftmaxCrossEntropy

Softmax cross-entropy loss.

kd.losses.SoftmaxCrossEntropyWithIntLabels

Softmax cross-entropy loss with integer labels.

kd.losses.Value

Value loss.

kd.metrics.Accuracy

Classification Accuracy.

kd.metrics.Ari

Adjusted Rand Index (ARI) computed from predictions and labels.

kd.metrics.AutoState

Flexible base class for conveniently defining custom states.

kd.metrics.AverageState

Computes the average of a scalar or a batch of tensors.
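Metric states are accumulated across steps by merging and only turned into a final value at the end. A minimal sketch of that merge-then-compute pattern for an average is below, assuming a state that keeps a running total and count; the class AverageStateSketch is hypothetical and only mirrors the idea, not kd.metrics.AverageState itself.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class AverageStateSketch:
    """Hypothetical mergeable state: running sum and count of observed values."""

    total: float = 0.0
    count: int = 0

    @classmethod
    def from_values(cls, values):
        values = list(values)
        return cls(total=float(sum(values)), count=len(values))

    def merge(self, other):
        # Both fields merge by summation, so merging is associative and order-independent.
        return AverageStateSketch(self.total + other.total, self.count + other.count)

    def compute(self):
        return self.total / self.count


step1 = AverageStateSketch.from_values([1.0, 2.0])
step2 = AverageStateSketch.from_values([6.0])
average = step1.merge(step2).compute()  # (1 + 2 + 6) / 3 == 3.0
```

Keeping the sum and count separately (rather than averaging per step) is what makes the result correct when steps have different batch sizes.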

kd.metrics.BinaryAccuracy

Classification Accuracy for Binary classification tasks.

kd.metrics.CollectFirstState

Get the first outputs (possibly) across multiple steps (no reducing).

kd.metrics.CollectingState

Accumulate outputs across multiple steps (without reducing).

kd.metrics.EmptyState

Empty state.

kd.metrics.LpipsVgg

VGG LPIPS.

kd.metrics.Metric

Base class for metrics.

kd.metrics.NoopMetric

Metric that does nothing. Can be used in sweeps to remove a metric.

kd.metrics.Norm

Wraps jnp.linalg.norm to compute the average norm for given tensors.

kd.metrics.Precision1

Precision@1 for multilabel classification.

kd.metrics.Psnr

PSNR.

kd.metrics.RocAuc

Area Under the Receiver Operating Characteristic Curve (ROC AUC).

kd.metrics.SingleDimension

Returns a single chosen dimension of the tensor.

kd.metrics.SkipIfMissing

Skip this metric if any of the keys are missing.

kd.metrics.Ssim

Structural similarity (SSIM).

kd.metrics.State

Base metric state class.

kd.metrics.Std

Compute the standard deviation for float values.
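A standard deviation can also be accumulated across batches if the state stores the count, the sum, and the sum of squares. This minimal sketch assumes the population standard deviation (ddof=0); whether kd.metrics.Std uses that convention, and the class StdStateSketch itself, are assumptions for illustration.

```python
import dataclasses
import math


@dataclasses.dataclass(frozen=True)
class StdStateSketch:
    """Hypothetical mergeable state for the population standard deviation."""

    count: int = 0
    total: float = 0.0
    total_sq: float = 0.0

    @classmethod
    def from_values(cls, values):
        values = [float(v) for v in values]
        return cls(len(values), sum(values), sum(v * v for v in values))

    def merge(self, other):
        # All three sufficient statistics merge by summation.
        return StdStateSketch(
            self.count + other.count,
            self.total + other.total,
            self.total_sq + other.total_sq,
        )

    def compute(self):
        mean = self.total / self.count
        # Var = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
        variance = max(self.total_sq / self.count - mean * mean, 0.0)
        return math.sqrt(variance)


a = StdStateSketch.from_values([2, 4, 4, 4])
b = StdStateSketch.from_values([5, 5, 7, 9])
std = a.merge(b).compute()  # population std of all eight values: 2.0
```

The E[x^2] - E[x]^2 form is simple but can lose precision for large values with small spread; a Welford-style update is the more numerically robust alternative.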

kd.metrics.TreeMap

Maps an inner metric to a pytree and returns a pytree of results.

kd.metrics.TreeReduce

Applies a metric to a pytree and returns the aggregated result.

kd.nn.AddEmbedding

Helper Module for adding a PositionEmbedding e.g. in a knn.Sequential.

kd.nn.AddLearnedEmbedding

Adds learned positional embeddings to the inputs.

kd.nn.AttentionModule

Interface specification for Attention modules.

kd.nn.Dropout

Wrapper around nn.Dropout but using kd.nn.train_property.

kd.nn.DummyModel

Empty model that ignores inputs and always produces a single logit of 42.

kd.nn.ExternalModule

Module that is defined outside Kauldron.

kd.nn.FlatAutoencoder

Very simple auto-encoder class to showcase using keys and submodules.

kd.nn.FourierEmbedding

Apply Fourier position embedding to a grid of coordinates.

kd.nn.Identity

Module that applies the identity function to a single tensor.

kd.nn.ImageTokenizer

Interface for modules that convert images into tokens.

kd.nn.ImprovedMultiHeadDotProductAttention

Multi-head dot-product attention.

kd.nn.LearnedEmbedding

Learned positional embeddings.

kd.nn.MultiHeadDotProductAttention

Wrapper around nn.MultiHeadDotProductAttention using knn.train_property.

kd.nn.NormModule

Interface specification for norm modules (to be used as type annotation).

kd.nn.ParallelAttentionBlock

Parallel self attention (see ViT-22B paper: arxiv.org/abs/2302.05442).

kd.nn.Patchify

Patchify an image, as in ViT (without linear embedding).
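ViT-style patchification splits an image into non-overlapping patches and flattens each patch into a vector. A minimal NumPy sketch is below; the channels-last input layout and the flattened (batch, num_patches, patch_dim) output are assumptions, and kd.nn.Patchify may instead keep the 2D patch grid. The function patchify is hypothetical.

```python
import numpy as np


def patchify(images, patch_size):
    """(batch, h, w, c) -> (batch, num_patches, patch_h * patch_w * c), row-major patch order."""
    b, h, w, c = images.shape
    ph, pw = patch_size
    assert h % ph == 0 and w % pw == 0, "image size must be divisible by the patch size"
    # Split height and width into (grid, within-patch) axes.
    x = images.reshape(b, h // ph, ph, w // pw, pw, c)
    # Move the two grid axes next to each other: (b, grid_h, grid_w, ph, pw, c).
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # Flatten the grid into a sequence and each patch into a vector.
    return x.reshape(b, (h // ph) * (w // pw), ph * pw * c)


images = np.zeros((2, 224, 224, 3))
tokens = patchify(images, (16, 16))  # shape (2, 196, 768): a 14 x 14 grid of 16*16*3 values
```

PatchifyEmbed additionally applies a linear embedding to each flattened patch, which this sketch omits.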

kd.nn.PatchifyEmbed

Patchify and linearly embed an image, as in ViT.

kd.nn.PostNormBlock

Post-LN Transformer layer (not recommended).

kd.nn.PreNormBlock

Pre-LN Transformer layer (default transformer layer).

kd.nn.Rearrange

Wrapper around einops.rearrange for usage e.g. in nn.Sequential.

kd.nn.Reduce

Wrapper around einops.reduce for usage e.g. in nn.Sequential.

kd.nn.Sequential

Like nn.Sequential but allows configuring input and output keys.

kd.nn.TransformerBlock

Interface definition for transformer blocks (for use in type annotations).

kd.nn.TransformerMLP

Simple MLP with a single hidden layer for use in Transformer blocks.

kd.nn.Vit

Basic Vision Transformer classifier with GAP.

kd.nn.VitEncoder

Basic Vit Encoder.

kd.nn.WrapperModule

Base class to wrap a module.

kd.nn.ZeroEmbedding

Embedding that returns zero (for deactivating position embeddings).

kd.optim.UseEmaParams

Use the EMA parameters stored by the ema_params transform.

kd.random.PRNGKey

Small wrapper around jax.random key arrays to reduce boilerplate.

kd.summaries.Histogram

Output type for histogram summaries.

kd.summaries.HistogramSummary

Basic histogram summary.

kd.summaries.PointCloud

Output type for point cloud summaries.

kd.summaries.ShowBoxes

Show a set of boxes with optional image reshaping.

kd.summaries.ShowDifferenceImages

Show a set of difference images with optional reshaping.

kd.summaries.ShowImages

Show image summaries with optional reshaping.

kd.summaries.ShowPointCloud

Show a point cloud with optional reshaping.

kd.summaries.ShowSegmentations

Show a set of segmentations with optional reshaping.

kd.summaries.ShowTexts

Show texts.

kd.summaries.deprecated.ImageSummary

Deprecated ImageSummary. Raises an error if instantiated.

kd.summaries.deprecated.Summary

Deprecated Summaries. Raises an error if instantiated.

kd.train.Auxiliaries

Wrapper around the losses, summaries and metrics.

kd.train.AuxiliariesOutput

Auxiliaries final values (after merge and compute).

kd.train.AuxiliariesState

Auxiliaries (intermediate states to be accumulated).

kd.train.Context

Namespace for retrieving information with path-based keys.

kd.train.KDMetricWriter

Writes summaries to logs, tf_summaries and datatables.

kd.train.RngStream

Info on one rng stream.

kd.train.RngStreams

Manager of rng streams.

kd.train.Setup

Setup/environment options.

kd.train.TqdmInfo

TqdmInfo(*, desc: 'str' = 'train', log_xm: 'bool' = True)

kd.train.TrainState

Data structure for checkpointing the model.

kd.train.TrainStep

Base Training Step.

kd.train.Trainer

Base trainer class.

kd.typing.Any

kd.typing.Array

kd.typing.ArraySpec

Describes an array via its dtype and shape.

kd.typing.AxisName

kd.typing.Bool

kd.typing.Complex

kd.typing.Complex64

kd.typing.Float

kd.typing.Float32

kd.typing.Hashable

kd.typing.Int

kd.typing.Integer

kd.typing.Memo

Jaxtyping information about the shapes in the current scope.

kd.typing.Num

kd.typing.Sequence

All the operations on a read-only sequence.

kd.typing.Shape

Helper to construct concrete shape tuples from shape-specs.

kd.typing.TfArray

kd.typing.TfFloat

kd.typing.TfFloat32

kd.typing.TfInt

kd.typing.TfUInt8

kd.typing.TypeCheckError

Indicates a runtime typechecking error from the @typechecked decorator.

kd.typing.UInt32

kd.typing.UInt8

kd.typing.XArray

kd.xm.Experiment

XManager experiment wrapper.

kd.xm.WorkUnit

XManager work unit wrapper.

Function

kd._filter_logs.add_filter

Add a filter to the absl logging handler.

kd.ckpts.workdir_from_xid

kd.from_xid.get_cfg

Returns the config/sub-config from an xmanager experiment.

kd.from_xid.get_element_spec

Returns the element_spec of the train dataset of an xmanager experiment.

kd.from_xid.get_resolved

Returns the resolved config/sub-config from an xmanager experiment.

kd.from_xid.get_workdir

Returns the workdir of an xmanager experiment.

kd.inspect.get_batch_stats

Return pd.DataFrame containing the batch stats.

kd.inspect.get_colab_model_overview

Return pd.DataFrame for displaying the model params, inputs,…

kd.inspect.get_connection_graph

Build the graphviz.

kd.inspect.json_spec_like

Convert etree.spec_like output to JSON and display it in Colab form.

kd.inspect.lower_trainstep

Returns lowered trainerstep.step.

kd.inspect.plot_batch

Display batch images.

kd.inspect.plot_context

Display the context structure.

kd.inspect.plot_schedules

Overview plot for (nested) dict of schedules.

kd.inspect.plot_sharding

Plot sharding.

kd.inspect.show_trainer_info

Display various plots about the trainer.

kd.kdash.build_and_upload

Create the dashboards.

kd.knn.convert

Decorator that converts a flax class into klinen.

kd.konfig.DEFINE_config_file

Defines flag for ConfigDict.

kd.konfig.imports

Context manager which replaces import statements with ConfigDicts.

kd.konfig.mock_modules

Context manager which replaces a list of modules with ConfigDictProxyObjects.

kd.konfig.placeholder

Defines an entry in a ConfigDict that has no value yet.

kd.konfig.ref_copy

One-way recursive copy of the ConfigDict.

kd.konfig.ref_fn

Wrap a function for lazy-evaluation.

kd.konfig.register_aliases

Register module aliases for nicer display.

kd.konfig.register_default_values

Register default values when creating the ConfigDict.

kd.konfig.required

Defines a required attribute in the config that has no value yet.

kd.konfig.resolve

Recursively parses a nested ConfigDict and resolves module constructors.

kd.konfig.set_lazy_imported_modules

Set which modules imported inside a with konfig.imports() block will be lazy-imported.

kd.kontext.filter_by_path

Filters a context by a path.

kd.kontext.flatten_with_path

Flatten any PyTree / ConfigDict into a dict with 'keys.like[0].this'.
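The 'keys.like[0].this' path format can be illustrated with a small pure-Python flattener over nested dicts, lists, and tuples: dict keys are joined with dots and sequence indices are written in brackets. This is a minimal sketch; it does not handle ConfigDicts, dataclasses, or other PyTree node types that kd.kontext.flatten_with_path supports, and the function shown is a hypothetical stand-in.

```python
def flatten_with_path(tree, prefix=""):
    """Flattens nested dicts/lists/tuples into {'a.b[0].c': leaf} (illustrative sketch)."""
    flat = {}
    if isinstance(tree, dict):
        for key, value in tree.items():
            # Dict keys are joined with a dot (no leading dot at the root).
            path = f"{prefix}.{key}" if prefix else str(key)
            flat.update(flatten_with_path(value, path))
    elif isinstance(tree, (list, tuple)):
        for i, value in enumerate(tree):
            # Sequence indices are appended in brackets.
            flat.update(flatten_with_path(value, f"{prefix}[{i}]"))
    else:
        flat[prefix] = tree  # leaf
    return flat


flat = flatten_with_path({"keys": {"like": [{"this": 1}, {"this": 2}]}})
# {'keys.like[0].this': 1, 'keys.like[1].this': 2}
```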

kd.kontext.get_by_path

Get (nested) item or attribute by given path.

kd.kontext.get_keypaths

Return a dictionary mapping Key-annotated fieldnames to their paths.

kd.kontext.is_key_annotated

Check if a given class or instance has fields annotated with Key.

kd.kontext.path_builder_from

Create a path builder from a class.

kd.kontext.resolve_from_keyed_obj

Resolve the Key annotations of an object for given context.

kd.kontext.resolve_from_keypaths

Get values for key_paths from context with useful errors when failing.

kd.kontext.set_by_path

Mutate the obj to set the value.

kd.losses.compute_losses

Compute all losses based on given context.

kd.metrics.concat_field

Defines an AutoState data-field that is merged by concatenation.

kd.metrics.state_field

Defines an AutoState data-field that is merged by calling its merge method.

kd.metrics.static_field

Define an AutoState static field.

kd.metrics.sum_field

Define an AutoState data-field that is merged by summation (a + b).

kd.metrics.truncate_field

Defines an AutoState data-field that is merged by truncation.

kd.nn.convert_to_fourier_features

Convert inputs to Fourier features, e.g. for positional encoding.
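A common form of Fourier features (as used for NeRF-style positional encodings) maps each coordinate x to sin and cos of x at geometrically increasing frequencies. The minimal sketch below assumes frequencies pi * 2^k for k = 0..num_bands-1 and a sin-then-cos output layout; the exact frequencies, scaling, and layout used by kd.nn.convert_to_fourier_features may differ, and the function fourier_features is hypothetical.

```python
import numpy as np


def fourier_features(x, num_bands):
    """(..., d) -> (..., 2 * d * num_bands): [sin(pi * 2^k * x), cos(pi * 2^k * x)]."""
    x = np.asarray(x, dtype=np.float64)
    freqs = np.pi * 2.0 ** np.arange(num_bands)  # (num_bands,)
    angles = x[..., None] * freqs                # (..., d, num_bands)
    angles = angles.reshape(*x.shape[:-1], -1)   # (..., d * num_bands)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


coords = np.zeros((4, 2))  # e.g. 4 points with (y, x) coordinates
features = fourier_features(coords, num_bands=3)  # shape (4, 12)
```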

kd.nn.interms_property

interms property that makes storing intermediates more convenient.

kd.nn.set_train_property

Set the self.is_training state to the given value.

kd.nn.train_property

is_training property.

kd.optim.decay_to_init

Add (params - init_params) scaled by weight_decay.
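The rule above is the "decay toward initialization" counterpart of ordinary weight decay: instead of adding weight_decay * params to the gradient-like updates (as optax.add_decayed_weights does), it adds weight_decay * (params - init_params), so a later scaling by the negative learning rate pulls the parameters back toward their initial values rather than toward zero. A minimal sketch over flat dicts of arrays follows; the function decay_to_init_updates is hypothetical.

```python
import numpy as np


def decay_to_init_updates(updates, params, init_params, weight_decay):
    """Adds weight_decay * (params - init_params) to each gradient-like update."""
    return {
        name: updates[name] + weight_decay * (params[name] - init_params[name])
        for name in updates
    }


updates = {"w": np.array([0.0, 1.0])}
params = {"w": np.array([3.0, 1.0])}
init_params = {"w": np.array([1.0, 1.0])}
new_updates = decay_to_init_updates(updates, params, init_params, weight_decay=0.1)
# new_updates["w"] is [0.2, 1.0]: only the entry that drifted from its init is affected.
```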

kd.optim.ema_params

Store an EMA version of model parameters.
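An exponential moving average (EMA) of parameters is updated after every step as ema = decay * ema + (1 - decay) * params. The minimal sketch below shows that update on a flat dict of arrays; the decay argument name and value are assumptions, and kd.optim.ema_params stores its EMA inside the optimizer state rather than in a separate dict as done here.

```python
import numpy as np


def ema_update(ema, params, decay=0.999):
    """One EMA step: ema <- decay * ema + (1 - decay) * params."""
    return {k: decay * ema[k] + (1.0 - decay) * params[k] for k in params}


ema = {"w": np.array(0.0)}
for _ in range(3):
    ema = ema_update(ema, {"w": np.array(1.0)}, decay=0.5)
# ema["w"] goes 0.5 -> 0.75 -> 0.875 as it approaches the constant parameter value 1.0.
```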

kd.optim.exclude

Create a mask which selects all nodes except the ones matching the pattern.

kd.optim.named_chain

Wraps optax.named_chain and allows passing transformations as kwargs.

kd.optim.partial_updates

Applies the optimizer to a subset of the parameters.

kd.optim.select

Create a mask which selects only the sub-pytree matching the pattern.
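Masks like these are boolean trees with the same structure as the parameters, which optax transformations such as optax.masked accept to restrict where an update applies. The sketch below assumes the pattern is a regular expression matched against dot-joined parameter paths; the actual pattern syntax of kd.optim.select may differ, and select_mask is a hypothetical helper. Under the same assumption, kd.optim.exclude would correspond to the negated mask.

```python
import re


def select_mask(params, pattern, prefix=""):
    """Returns a tree of bools shaped like `params`: True where the leaf path matches."""
    mask = {}
    for key, value in params.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            mask[key] = select_mask(value, pattern, path)  # recurse into sub-trees
        else:
            mask[key] = re.search(pattern, path) is not None
    return mask


params = {"encoder": {"kernel": 1.0, "bias": 0.0}, "head": {"kernel": 2.0}}
head_mask = select_mask(params, r"^head\.")
# {'encoder': {'kernel': False, 'bias': False}, 'head': {'kernel': True}}
```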

kd.testing.assert_step_specs

Check that the train step runs correctly (fast).

kd.train.forward

Forward pass of the model.

kd.train.forward_with_loss

Forward pass of the model, including losses.

kd.typing.Dim

Helper to construct concrete Dim (for single-axis Shape).

kd.typing.check_type

Ensure that value matches expected_type, alias for typeguard.check_type.

kd.typing.enable_kd_type_checking

Enable custom type checking for Kauldron types.

kd.typing.set_shape

Validates the given shape and sets any previously unknown shapes.

kd.typing.typechecked

Decorator to enable runtime type-checking and shape-checking.

kd.xm.add_colab_artifacts

Add a link to the kd-infer colab.

kd.xm.add_log_artifacts

Add XManager artifacts for easy access to the Python logs.

kd.xm.add_tags_to_xm

Add tags to the xmanager experiment.

kd.xm.load_config_from_path

Loads a config from a path.

Attribute

Typing