# kd

[[[Source]]](https://github.com/google-research/kauldron/tree/main/kauldron/kd.py)

```{code-block}
import kauldron as kd
```

```{eval-rst}
.. automodule:: kauldron.kd
  :no-members:
```

## All symbols


### Module

|  |  |
--- | ---
[kd](index) | Kauldron public API.
[kd._filter_logs](_filter_logs/index) | Helper to filter logs from verbose modules.
[kd.ckpts](ckpts/index) | Checkpoints API.
[kd.ckpts.items](ckpts/items/index) | Checkpoint handler.
[kd.contrib](contrib/index) | Contrib public API.
[kd.data](data/index) | Data modules.
[kd.data.iters](data/iters/index) | Dataset iterators.
[kd.data.py](data/py/index) | PyGrain public API.
[kd.evals](evals/index) | Evaluator.
[kd.from_xid](from_xid/index) | Helper for loading configs, ... from XManager experiments.
[kd.inspect](inspect/index) | Inspect utils.
[kd.kdash](kdash/index) | Small library for creating Flatboard dashboards.
[kd.knn](knn/index) | Wrapper around `flax.linen.Module` to add torch-like API.
[kd.konfig](konfig/index) | Wrapper around `ConfigDict` to support auto-complete/type checking.
[kd.konfig._default_values](konfig/_default_values/index) | Default values and configuration.
[kd.kontext](kontext/index) | Kontext is a small self-contained library to manipulate nested trees.
[kd.losses](losses/index) | Losses.
[kd.metrics](metrics/index) | Metrics.
[kd.nn](nn/index) | Collection of nn.Modules to build neural networks.
[kd.optim](optim/index) | Optimizers etc.
[kd.random](random/index) | Small wrapper around `jax.random` for OO API.
[kd.summaries](summaries/index) | Summaries.
[kd.summaries.deprecated](summaries/deprecated/index) | Deprecated summaries.
[kd.testing](testing/index) | Testing utilities.
[kd.train](train/index) | Train.
[kd.typing](typing/index) | Common Typing Annotations.
[kd.utils](utils/index) | Utils public API.
[kd.xm](xm/index) | XManager utils.
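
`kd.kontext` manipulates nested trees addressed by string paths such as `'keys.like[0].this'`, the format produced by `kd.kontext.flatten_with_path`. The following is a minimal pure-Python sketch of that flattening idea, handling only nested `dict`s, `list`s and `tuple`s. The name `flatten_with_path_sketch` is hypothetical and this is not Kauldron's implementation, which also accepts arbitrary PyTrees and `ConfigDict`s.

```python
from typing import Any


def flatten_with_path_sketch(tree: Any, prefix: str = "") -> dict[str, Any]:
    """Flattens nested dicts/lists/tuples into {'a.b[0].c': leaf} (hypothetical helper)."""
    flat: dict[str, Any] = {}
    if isinstance(tree, dict):
        for key, value in tree.items():
            # Dict keys become dotted, attribute-like segments.
            child = f"{prefix}.{key}" if prefix else str(key)
            flat.update(flatten_with_path_sketch(value, child))
    elif isinstance(tree, (list, tuple)):
        for index, value in enumerate(tree):
            # Sequence positions become bracketed index segments.
            flat.update(flatten_with_path_sketch(value, f"{prefix}[{index}]"))
    else:
        # Anything that is not a dict/list/tuple is treated as a leaf.
        flat[prefix] = tree
    return flat


batch = {"image": "img", "labels": [{"cls": 3}, {"cls": 7}]}
print(flatten_with_path_sketch(batch))
# {'image': 'img', 'labels[0].cls': 3, 'labels[1].cls': 7}
```

Flat path-keyed dicts like this are convenient for display and for matching keys by string, which is the kind of manipulation `kd.kontext` is described as supporting.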

### Class

|  |  |
--- | ---
[kd.ckpts.AbstractPartialLoader](ckpts/AbstractPartialLoader) | Abstract class for partial checkpoint loaders.
[kd.ckpts.Checkpointer](ckpts/Checkpointer) | Wrapper around Orbax CheckpointManager.
[kd.ckpts.MultiTransform](ckpts/MultiTransform) | Transform which applies multiple transformations sequentially.
[kd.ckpts.NoopCheckpointer](ckpts/NoopCheckpointer) | Does nothing.
[kd.ckpts.NoopTransform](ckpts/NoopTransform) | `init_transform` that does nothing.
[kd.ckpts.PartialKauldronLoader](ckpts/PartialKauldronLoader) | Partial loader for Kauldron checkpoints.
[kd.ckpts.items.CheckpointItem](ckpts/items/CheckpointItem) | Interface for a checkpoint item.
[kd.ckpts.items.StandardCheckpointItem](ckpts/items/StandardCheckpointItem) | Standard checkpoint item (for arbitrary `jax.Array` pytree).
[kd.ckpts.items.TopLevelCheckpointItem](ckpts/items/TopLevelCheckpointItem) | Checkpoint item that contains other sub-checkpoint items.
[kd.data.AddConstants](data/AddConstants) | Adds constant elements.
[kd.data.BatchSize](data/BatchSize) | Batch size.
[kd.data.Cast](data/Cast) | Cast an element to the specified dtype.
[kd.data.ElementWiseTransform](data/ElementWiseTransform) | Base class for elementwise transforms.
[kd.data.Elements](data/Elements) | Modify the elements by keeping xor dropping and/or renaming and/or copying.
[kd.data.FilterTransform](data/FilterTransform) | Abstract base class for filter transformations for individual elements.
[kd.data.Gather](data/Gather) | Gathers entries along a single dimension.
[kd.data.InMemoryPipeline](data/InMemoryPipeline) | Pipeline that fits in memory.
[kd.data.IterableDataset](data/IterableDataset) | General interface for iterable datasets.
[kd.data.MapTransform](data/MapTransform) | Abstract base class for all 1:1 transformations of elements.
[kd.data.Pipeline](data/Pipeline) | Base class for kauldron data pipelines.
[kd.data.Rearrange](data/Rearrange) | Einops rearrange on a single element.
[kd.data.Resize](data/Resize) | Resizes an image.
[kd.data.TreeFlattenWithPath](data/TreeFlattenWithPath) | Flatten any tree-structured elements.
[kd.data.ValueRange](data/ValueRange) | Map the value range of an element.
[kd.data.iters.Iterator](data/iters/Iterator) | Wrapper around a dataset iterator.
[kd.data.iters.NonCheckpointableIterator](data/iters/NonCheckpointableIterator) | Handler that is not-checkpointable.
[kd.data.iters.PyGrainIterator](data/iters/PyGrainIterator) | PyGrain iterator.
[kd.data.iters.TFDataIterator](data/iters/TFDataIterator) | Checkpointable `tf.data` iterator.
[kd.data.py.AddConstants](data/py/AddConstants) | Adds constant elements.
[kd.data.py.Cast](data/py/Cast) | Cast an element to the specified dtype.
[kd.data.py.DataSource](data/py/DataSource) | Generic loader of arbitrary grain data source.
[kd.data.py.DataSourceBase](data/py/DataSourceBase) | Base class to implement a data source.
[kd.data.py.ElementWiseTransform](data/py/ElementWiseTransform) | Base class for elementwise transforms.
[kd.data.py.Elements](data/py/Elements) | Modify the elements by keeping xor dropping and/or renaming and/or copying.
[kd.data.py.Gather](data/py/Gather) | Gathers entries along a single dimension.
[kd.data.py.HuggingFace](data/py/HuggingFace) | HuggingFace loader.
[kd.data.py.Json](data/py/Json) | Json pipeline.
[kd.data.py.Mix](data/py/Mix) | Create a dataset mixture.
[kd.data.py.PyGrainPipeline](data/py/PyGrainPipeline) | Abstract base class to construct PyGrain data pipeline.
[kd.data.py.Rearrange](data/py/Rearrange) | Einops rearrange on a single element.
[kd.data.py.Resize](data/py/Resize) | Resizes an image.
[kd.data.py.SliceDataset](data/py/SliceDataset) | Transform which selects a subset of the dataset.
[kd.data.py.Tfds](data/py/Tfds) | Base TFDS loader.
[kd.data.py.TreeFlattenWithPath](data/py/TreeFlattenWithPath) | Flatten any tree-structured elements.
[kd.data.py.ValueRange](data/py/ValueRange) | Map the value range of an element.
[kd.evals.CollectionKeys](evals/CollectionKeys) | Names of the metrics/summaries/losses (displayed in Flatboard).
[kd.evals.Evaluator](evals/Evaluator) | Evaluator running `num_batches` times.
[kd.evals.EvaluatorBase](evals/EvaluatorBase) | Base class for inline evaluators.
[kd.evals.EveryNSteps](evals/EveryNSteps) | Run eval every N train steps.
[kd.evals.FewShotEvaluator](evals/FewShotEvaluator) | FewShotEvaluator running closed-form few-shot classification.
[kd.evals.Once](evals/Once) | Run eval only after the `XX` train steps.
[kd.evals.RunStrategy](evals/RunStrategy) | Base class for info on how to run the evaluation.
[kd.evals.StandaloneEveryCheckpoint](evals/StandaloneEveryCheckpoint) | Run eval continuously every time a new checkpoint is found.
[kd.evals.StandaloneLastCheckpoint](evals/StandaloneLastCheckpoint) | Run eval only after the last checkpoint, after train has completed.
[kd.inspect.Profiler](inspect/Profiler) | `kd.inspect.Profiler`.
[kd.kdash.BuildContext](kdash/BuildContext) | Context for building the dashboard.
[kd.kdash.DashboardsBase](kdash/DashboardsBase) | Flatboard dashboard structure.
[kd.kdash.MetricDashboards](kdash/MetricDashboards) | Standard `metrics` & `losses` dashboards for a single collection.
[kd.kdash.MultiDashboards](kdash/MultiDashboards) | Container of multiple dashboards.
[kd.kdash.NoopDashboard](kdash/NoopDashboard) | Empty dashboard.
[kd.kdash.Plot](kdash/Plot) | Single plot inside a dashboard.
[kd.kdash.SingleDashboard](kdash/SingleDashboard) | Single dashboard containing multiple plots.
[kd.knn.Dense](knn/Dense) | `Dense(features: int, use_bias: bool = True, dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any, NoneType] = None, param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any] = <class 'jax.numpy.float32'>, precision: Union[NoneType, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]] = None, kernel_init: Union[jax.nn.initializers.Initializer, collections.abc.Callable[..., Any]] = <function variance_scaling.<locals>.init>, bias_init: Union[jax.nn.initializers.Initializer, collections.abc.Callable[..., Any]] = <function zeros>, promote_dtype: flax.linen.linear.PromoteDtypeFn = <function promote_dtype>, dot_general: collections.abc.Callable[..., typing.Union[jax.Array, typing.Any]] \| None = None, dot_general_cls: Any = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None, *, _kd_state: 'Optional[_ModuleState]' = None)`
[kd.knn.Dropout](knn/Dropout) | `Dropout(rate: float, broadcast_dims: collections.abc.Sequence[int] = (), deterministic: bool \| None = None, rng_collection: str = 'dropout', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None, *, _kd_state: 'Optional[_ModuleState]' = None)`
[kd.knn.Module](knn/Module) | Base Module class.
[kd.knn.Sequential](knn/Sequential) | `Sequential(layers: collections.abc.Sequence[collections.abc.Callable[..., typing.Any]], parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None, *, _kd_state: 'Optional[_ModuleState]' = None)`
[kd.konfig.ConfigDict](konfig/ConfigDict) | Wrapper around ConfigDict.
[kd.konfig.WithRef](konfig/WithRef) | Protocol to better access lazy fields.
[kd.kontext.Context](kontext/Context) | 
[kd.kontext.GlobPath](kontext/GlobPath) | Represents a string path.
[kd.kontext.Path](kontext/Path) | Represents a (non-glob) string path.
[kd.losses.AbsoluteValue](losses/AbsoluteValue) | Absolute value loss.
[kd.losses.Huber](losses/Huber) | Huber loss.
[kd.losses.L1](losses/L1) | L1 loss.
[kd.losses.L2](losses/L2) | L2 loss.
[kd.losses.Loss](losses/Loss) | Base class for losses which handles masks, averaging, and loss-weight.
[kd.losses.NegativeCosineSimilarity](losses/NegativeCosineSimilarity) | Negative Cosine Similarity loss.
[kd.losses.SigmoidBinaryCrossEntropy](losses/SigmoidBinaryCrossEntropy) | Sigmoid cross-entropy loss with binary labels.
[kd.losses.SoftmaxCrossEntropy](losses/SoftmaxCrossEntropy) | Softmax cross-entropy loss.
[kd.losses.SoftmaxCrossEntropyWithIntLabels](losses/SoftmaxCrossEntropyWithIntLabels) | Softmax cross-entropy loss with integer labels.
[kd.losses.Value](losses/Value) | Value loss.
[kd.metrics.Accuracy](metrics/Accuracy) | Classification Accuracy.
[kd.metrics.Ari](metrics/Ari) | Adjusted Rand Index (ARI) computed from predictions and labels.
[kd.metrics.AutoState](metrics/AutoState) | Flexible base class for conveniently defining custom states.
[kd.metrics.AverageState](metrics/AverageState) | Computes the average of a scalar or a batch of tensors.
[kd.metrics.BinaryAccuracy](metrics/BinaryAccuracy) | Classification Accuracy for Binary classification tasks.
[kd.metrics.CollectFirstState](metrics/CollectFirstState) | Collect the first outputs, possibly across multiple steps (without reducing).
[kd.metrics.CollectingState](metrics/CollectingState) | Accumulate outputs across multiple steps (without reducing).
[kd.metrics.EmptyState](metrics/EmptyState) | Empty state.
[kd.metrics.LpipsVgg](metrics/LpipsVgg) | VGG LPIPS.
[kd.metrics.Metric](metrics/Metric) | Base class for metrics.
[kd.metrics.NoopMetric](metrics/NoopMetric) | Metric that does nothing. Can be used in sweeps to remove a metric.
[kd.metrics.Norm](metrics/Norm) | Wraps `jnp.linalg.norm` to compute the average norm of the given tensors.
[kd.metrics.Precision1](metrics/Precision1) | Precision@1 for multilabel classification.
[kd.metrics.Psnr](metrics/Psnr) | PSNR.
[kd.metrics.RocAuc](metrics/RocAuc) | Area Under the Receiver Operating Characteristic Curve (ROC AUC).
[kd.metrics.SingleDimension](metrics/SingleDimension) | Returns a single chosen dimension of the tensor.
[kd.metrics.SkipIfMissing](metrics/SkipIfMissing) | Skip this metric if any of the keys are missing.
[kd.metrics.Ssim](metrics/Ssim) | Structural similarity (SSIM).
[kd.metrics.State](metrics/State) | Base metric state class.
[kd.metrics.Std](metrics/Std) | Compute the standard deviation for float values.
[kd.metrics.TreeMap](metrics/TreeMap) | Maps an inner metric to a pytree and returns a pytree of results.
[kd.metrics.TreeReduce](metrics/TreeReduce) | Applies a metric to a pytree and returns the aggregated result.
[kd.nn.AddEmbedding](nn/AddEmbedding) | Helper Module for adding a PositionEmbedding e.g. in a `knn.Sequential`.
[kd.nn.AddLearnedEmbedding](nn/AddLearnedEmbedding) | Adds learned positional embeddings to the inputs.
[kd.nn.AttentionModule](nn/AttentionModule) | Interface specification for Attention modules.
[kd.nn.Dropout](nn/Dropout) | Wrapper around `nn.Dropout` but using `kd.nn.train_property`.
[kd.nn.DummyModel](nn/DummyModel) | Empty model that ignores inputs and always produces a single logit of 42.
[kd.nn.ExternalModule](nn/ExternalModule) | Module that is defined outside Kauldron.
[kd.nn.FlatAutoencoder](nn/FlatAutoencoder) | Very simple auto-encoder class to showcase using keys and submodules.
[kd.nn.FourierEmbedding](nn/FourierEmbedding) | Apply Fourier position embedding to a grid of coordinates.
[kd.nn.Identity](nn/Identity) | Module that applies the identity function to a single tensor.
[kd.nn.ImageTokenizer](nn/ImageTokenizer) | Interface for modules that convert images into tokens.
[kd.nn.ImprovedMultiHeadDotProductAttention](nn/ImprovedMultiHeadDotProductAttention) | Multi-head dot-product attention.
[kd.nn.LearnedEmbedding](nn/LearnedEmbedding) | Learned positional embeddings.
[kd.nn.MultiHeadDotProductAttention](nn/MultiHeadDotProductAttention) | Wrapper around `nn.MultiHeadDotProductAttention` using `knn.train_property`.
[kd.nn.NormModule](nn/NormModule) | Interface specification for norm modules (to be used as type annotation).
[kd.nn.ParallelAttentionBlock](nn/ParallelAttentionBlock) | Parallel self-attention block (see the ViT-22B paper: arxiv.org/abs/2302.05442).
[kd.nn.Patchify](nn/Patchify) | Patchify an image, as in ViT (without linear embedding).
[kd.nn.PatchifyEmbed](nn/PatchifyEmbed) | Patchify and linearly embed an image, as in ViT.
[kd.nn.PostNormBlock](nn/PostNormBlock) | Post-LN Transformer layer (not recommended).
[kd.nn.PreNormBlock](nn/PreNormBlock) | Pre-LN Transformer layer (default transformer layer).
[kd.nn.Rearrange](nn/Rearrange) | Wrapper around `einops.rearrange` for usage e.g. in `nn.Sequential`.
[kd.nn.Reduce](nn/Reduce) | Wrapper around `einops.reduce` for usage e.g. in `nn.Sequential`.
[kd.nn.Sequential](nn/Sequential) | Like nn.Sequential but allows configuring input and output keys.
[kd.nn.TransformerBlock](nn/TransformerBlock) | Interface definition for transformer blocks (for use in type annotations).
[kd.nn.TransformerMLP](nn/TransformerMLP) | Simple MLP with a single hidden layer for use in Transformer blocks.
[kd.nn.Vit](nn/Vit) | Basic Vision Transformer classifier with GAP.
[kd.nn.VitEncoder](nn/VitEncoder) | Basic Vit Encoder.
[kd.nn.WrapperModule](nn/WrapperModule) | Base class to wrap a module.
[kd.nn.ZeroEmbedding](nn/ZeroEmbedding) | Embedding that returns zero (for deactivating position embeddings).
[kd.random.PRNGKey](random/PRNGKey) | Small wrapper around `jax.random` key arrays to reduce boilerplate.
[kd.summaries.Histogram](summaries/Histogram) | Output type for histogram summaries.
[kd.summaries.HistogramSummary](summaries/HistogramSummary) | Basic histogram summary.
[kd.summaries.PointCloud](summaries/PointCloud) | Output type for point cloud summaries.
[kd.summaries.ShowBoxes](summaries/ShowBoxes) | Show a set of boxes with optional image reshaping.
[kd.summaries.ShowDifferenceImages](summaries/ShowDifferenceImages) | Show a set of difference images with optional reshaping.
[kd.summaries.ShowImages](summaries/ShowImages) | Show image summaries with optional reshaping.
[kd.summaries.ShowPointCloud](summaries/ShowPointCloud) | Show a point cloud with optional reshaping.
[kd.summaries.ShowSegmentations](summaries/ShowSegmentations) | Show a set of segmentations with optional reshaping.
[kd.summaries.ShowTexts](summaries/ShowTexts) | Show texts.
[kd.summaries.deprecated.ImageSummary](summaries/deprecated/ImageSummary) | Deprecated ImageSummary. Raises an error if instantiated.
[kd.summaries.deprecated.Summary](summaries/deprecated/Summary) | Deprecated Summaries. Raises an error if instantiated.
[kd.train.Auxiliaries](train/Auxiliaries) | Wrapper around the losses, summaries and metrics.
[kd.train.AuxiliariesOutput](train/AuxiliariesOutput) | Auxiliaries final values (after merge and compute).
[kd.train.AuxiliariesState](train/AuxiliariesState) | Auxiliaries (intermediate states to be accumulated).
[kd.train.Context](train/Context) | Namespace for retrieving information with path-based keys.
[kd.train.RngStream](train/RngStream) | Info on one `rng` stream.
[kd.train.RngStreams](train/RngStreams) | Manager of rng streams.
[kd.train.Setup](train/Setup) | Setup/environment options.
[kd.train.TqdmInfo](train/TqdmInfo) | TqdmInfo(*, desc: 'str' = 'train', log_xm: 'bool' = True)
[kd.train.TrainState](train/TrainState) | Data structure for checkpointing the model.
[kd.train.TrainStep](train/TrainStep) | Base Training Step.
[kd.train.Trainer](train/Trainer) | Base trainer class.
[kd.typing.Any](typing/Any) | 
[kd.typing.Array](typing/Array) | 
[kd.typing.ArraySpec](typing/ArraySpec) | Describes an array via its dtype and shape.
[kd.typing.AxisName](typing/AxisName) | 
[kd.typing.Bool](typing/Bool) | 
[kd.typing.Complex](typing/Complex) | 
[kd.typing.Complex64](typing/Complex64) | 
[kd.typing.Float](typing/Float) | 
[kd.typing.Float32](typing/Float32) | 
[kd.typing.Hashable](typing/Hashable) | 
[kd.typing.Int](typing/Int) | 
[kd.typing.Integer](typing/Integer) | 
[kd.typing.Memo](typing/Memo) | Jaxtyping information about the shapes in the current scope.
[kd.typing.Num](typing/Num) | 
[kd.typing.Sequence](typing/Sequence) | All the operations on a read-only sequence.
[kd.typing.Shape](typing/Shape) | Helper to construct concrete shape tuples from shape-specs.
[kd.typing.TfArray](typing/TfArray) | 
[kd.typing.TfFloat](typing/TfFloat) | 
[kd.typing.TfFloat32](typing/TfFloat32) | 
[kd.typing.TfInt](typing/TfInt) | 
[kd.typing.TfUInt8](typing/TfUInt8) | 
[kd.typing.TypeCheckError](typing/TypeCheckError) | Indicates a runtime typechecking error from the @typechecked decorator.
[kd.typing.UInt32](typing/UInt32) | 
[kd.typing.UInt8](typing/UInt8) | 
[kd.typing.XArray](typing/XArray) | 
[kd.xm.Experiment](xm/Experiment) | XManager experiment wrapper.
[kd.xm.WorkUnit](xm/WorkUnit) | XManager work unit wrapper.
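
`kd.losses.Loss` is described as a base class that handles masks, averaging, and loss-weight. A minimal sketch of how those three concerns can combine, in plain Python over a flat list of per-example loss values, might look like the following. The function `masked_weighted_mean` and its parameters are hypothetical, and the handling of a fully masked batch (returning `0.0`) is an assumption, not necessarily what Kauldron does.

```python
from __future__ import annotations

from collections.abc import Sequence


def masked_weighted_mean(
    values: Sequence[float],
    mask: Sequence[float] | None = None,
    weight: float = 1.0,
) -> float:
    """Averages per-example losses over unmasked entries, then scales by weight.

    Hypothetical helper illustrating mask + averaging + loss-weight; not the
    Kauldron implementation.
    """
    if mask is None:
        # No mask: every entry contributes equally.
        mask = [1.0] * len(values)
    if len(mask) != len(values):
        raise ValueError("mask and values must have the same length")
    total = sum(v * m for v, m in zip(values, mask))
    count = sum(mask)
    # Assumption: a fully masked batch yields 0.0 instead of dividing by zero.
    mean = total / count if count > 0 else 0.0
    return weight * mean


# Only the first and last examples count: (2.0 + 4.0) / 2 * 0.5 = 1.5
print(masked_weighted_mean([2.0, 100.0, 4.0], mask=[1, 0, 1], weight=0.5))
```

Normalizing by the number of unmasked entries (rather than the batch size) keeps the loss scale independent of how many examples are masked out.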

### Function

|  |  |
--- | ---
[kd._filter_logs.add_filter](_filter_logs/add_filter) | Add a filter to the absl logging handler.
[kd.ckpts.workdir_from_xid](ckpts/workdir_from_xid) | 
[kd.from_xid.get_cfg](from_xid/get_cfg) | Returns the config/sub-config from an xmanager experiment.
[kd.from_xid.get_element_spec](from_xid/get_element_spec) | Returns the element_spec of the train dataset of an xmanager experiment.
[kd.from_xid.get_resolved](from_xid/get_resolved) | Returns the resolved config/sub-config from an xmanager experiment.
[kd.from_xid.get_workdir](from_xid/get_workdir) | Returns the workdir of an xmanager experiment.
[kd.inspect.get_batch_stats](inspect/get_batch_stats) | Return `pd.DataFrame` containing the batch stats.
[kd.inspect.get_colab_model_overview](inspect/get_colab_model_overview) | Return `pd.DataFrame` for displaying the model params, inputs,...
[kd.inspect.get_connection_graph](inspect/get_connection_graph) | Build the Graphviz connection graph.
[kd.inspect.json_spec_like](inspect/json_spec_like) | Convert `etree.spec_like` output to JSON and display it in Colab.
[kd.inspect.lower_trainstep](inspect/lower_trainstep) | Returns the lowered `trainerstep.step`.
[kd.inspect.plot_batch](inspect/plot_batch) | Display batch images.
[kd.inspect.plot_context](inspect/plot_context) | Display the context structure.
[kd.inspect.plot_schedules](inspect/plot_schedules) | Overview plot for (nested) dict of schedules.
[kd.inspect.plot_sharding](inspect/plot_sharding) | Plot sharding.
[kd.inspect.show_trainer_info](inspect/show_trainer_info) | Display various plots about the trainer.
[kd.kdash.build_and_upload](kdash/build_and_upload) | Create the dashboards.
[kd.knn.convert](knn/convert) | Decorator that converts a Flax class into a `klinen` module.
[kd.konfig.DEFINE_config_file](konfig/DEFINE_config_file) | Defines flag for `ConfigDict`.
[kd.konfig.imports](konfig/imports) | Context manager which replaces import statements with ConfigDicts.
[kd.konfig.mock_modules](konfig/mock_modules) | Context manager which replaces a list of modules with `ConfigDictProxyObject`s.
[kd.konfig.placeholder](konfig/placeholder) | Defines an entry in a ConfigDict that has no value yet.
[kd.konfig.ref_copy](konfig/ref_copy) | One-way recursive copy of the `ConfigDict`.
[kd.konfig.ref_fn](konfig/ref_fn) | Wrap a function for lazy-evaluation.
[kd.konfig.register_aliases](konfig/register_aliases) | Register module aliases for nicer display.
[kd.konfig.register_default_values](konfig/register_default_values) | Register default values when creating the ConfigDict.
[kd.konfig.required](konfig/required) | Defines a required attribute in the config that has no value yet.
[kd.konfig.resolve](konfig/resolve) | Recursively parses a nested ConfigDict and resolves module constructors.
[kd.konfig.set_lazy_imported_modules](konfig/set_lazy_imported_modules) | Set which modules inside `with konfig.imports()` will be lazy-imported.
[kd.kontext.filter_by_path](kontext/filter_by_path) | Filters a context by a path.
[kd.kontext.flatten_with_path](kontext/flatten_with_path) | Flatten any PyTree / ConfigDict into a dict with 'keys.like[0].this'.
[kd.kontext.get_by_path](kontext/get_by_path) | Get (nested) item or attribute by given path.
[kd.kontext.get_keypaths](kontext/get_keypaths) | Return a dictionary mapping Key-annotated fieldnames to their paths.
[kd.kontext.is_key_annotated](kontext/is_key_annotated) | Check if a given class or instance has fields annotated with `Key`.
[kd.kontext.path_builder_from](kontext/path_builder_from) | Create a path builder from a class.
[kd.kontext.resolve_from_keyed_obj](kontext/resolve_from_keyed_obj) | Resolve the Key annotations of an object for given context.
[kd.kontext.resolve_from_keypaths](kontext/resolve_from_keypaths) | Get values for key_paths from context with useful errors when failing.
[kd.kontext.set_by_path](kontext/set_by_path) | Mutate the `obj` to set the value.
[kd.losses.compute_losses](losses/compute_losses) | Compute all losses based on given context.
[kd.metrics.concat_field](metrics/concat_field) | Defines an AutoState data-field that is merged by concatenation.
[kd.metrics.static_field](metrics/static_field) | Define an AutoState static field.
[kd.metrics.sum_field](metrics/sum_field) | Define an AutoState data-field that is merged by summation (a + b).
[kd.metrics.truncate_field](metrics/truncate_field) | Defines an AutoState data-field that is merged by truncation.
[kd.nn.convert_to_fourier_features](nn/convert_to_fourier_features) | Convert inputs to Fourier features, e.g. for positional encoding.
[kd.nn.interms_property](nn/interms_property) | `interms` property that makes storing intermediates more convenient.
[kd.nn.set_train_property](nn/set_train_property) | Set the `self.is_training` state to the given value.
[kd.nn.train_property](nn/train_property) | `is_training` property.
[kd.optim.decay_to_init](optim/decay_to_init) | Add (params - init_params) scaled by `weight_decay`.
[kd.optim.exclude](optim/exclude) | Create a mask which selects all nodes except the ones matching the pattern.
[kd.optim.named_chain](optim/named_chain) | Wraps optax.named_chain and allows passing transformations as kwargs.
[kd.optim.partial_updates](optim/partial_updates) | Applies the optimizer to a subset of the parameters.
[kd.optim.select](optim/select) | Create a mask which selects only the sub-pytree matching the pattern.
[kd.testing.assert_step_specs](testing/assert_step_specs) | Check that the train step runs correctly (fast).
[kd.train.forward](train/forward) | Forward pass of the model.
[kd.train.forward_with_loss](train/forward_with_loss) | Forward pass of the model, including losses.
[kd.typing.Dim](typing/Dim) | Helper to construct concrete Dim (for single-axis Shape).
[kd.typing.check_type](typing/check_type) | Ensure that value matches expected_type, alias for typeguard.check_type.
[kd.typing.enable_kd_type_checking](typing/enable_kd_type_checking) | Enable custom type checking for Kauldron types.
[kd.typing.set_shape](typing/set_shape) | Validates the given shape and sets any previously unknown shapes.
[kd.typing.typechecked](typing/typechecked) | Decorator to enable runtime type-checking and shape-checking.
[kd.xm.add_colab_artifacts](xm/add_colab_artifacts) | Add a link to the kd-infer colab.
[kd.xm.add_log_artifacts](xm/add_log_artifacts) | Add XManager artifacts for easy access to the Python logs.
[kd.xm.add_tags_to_xm](xm/add_tags_to_xm) | Add tags to the xmanager experiment.
[kd.xm.load_config_from_path](xm/load_config_from_path) | Loads a config from a path.
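
`kd.kontext.get_by_path` and `kd.kontext.set_by_path` read and mutate a nested object through a string path. The sketch below assumes the `'keys.like[0].this'` path format listed for `flatten_with_path`: it parses the path with a regular expression, then walks dict keys, sequence indices, and attributes. The names `parse_path`, `get_by_path_sketch` and `set_by_path_sketch` are hypothetical, and Kauldron's own path grammar (for example the glob patterns of `GlobPath`) is likely richer than this.

```python
from __future__ import annotations

import re
from typing import Any

# A path segment is either an identifier (`name`) or a list index (`[0]`).
_SEGMENT = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)|\[(\d+)\]")


def parse_path(path: str) -> list[str | int]:
    """Splits 'a.b[0].c' into ['a', 'b', 0, 'c'] (hypothetical helper)."""
    parts: list[str | int] = []
    for segment in path.split("."):
        if not segment:
            raise ValueError(f"Empty segment in path {path!r}")
        pos = 0
        while pos < len(segment):
            match = _SEGMENT.match(segment, pos)
            if match is None:
                raise ValueError(f"Invalid path {path!r}")
            name, index = match.groups()
            parts.append(name if name is not None else int(index))
            pos = match.end()
    return parts


def _step(obj: Any, key: str | int) -> Any:
    """Indexes dicts and sequences; falls back to attribute access."""
    if isinstance(key, int) or isinstance(obj, dict):
        return obj[key]
    return getattr(obj, key)


def get_by_path_sketch(obj: Any, path: str) -> Any:
    """Reads the nested value at `path` (hypothetical helper)."""
    for key in parse_path(path):
        obj = _step(obj, key)
    return obj


def set_by_path_sketch(obj: Any, path: str, value: Any) -> None:
    """Mutates `obj` in place so that `path` points to `value` (hypothetical)."""
    *parents, last = parse_path(path)
    for key in parents:
        obj = _step(obj, key)
    if isinstance(last, int) or isinstance(obj, dict):
        obj[last] = value
    else:
        setattr(obj, last, value)


# Toy nested context.
ctx = {"batch": {"image": [[0.0]]}, "preds": {"logits": [1.0, 2.0]}}
print(get_by_path_sketch(ctx, "preds.logits[1]"))  # 2.0
set_by_path_sketch(ctx, "batch.image[0]", [1.0])
print(ctx["batch"]["image"])  # [[1.0]]
```

Resolving all but the last segment and then assigning on the parent is what lets the setter mutate the original tree in place rather than returning a copy.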

### Attribute

|  |  |
--- | ---
[kd.kontext.REQUIRED](kontext/REQUIRED) | 
[kd.sharding](sharding) | 
[kd.typing.Union](typing/Union) | 
[kd.xm.xmanager_api](xm/xmanager_api) | 

### Typing

|  |  |
--- | ---
[kd.knn.Intermediate](knn/Intermediate) | 
[kd.konfig.ConfigDictLike](konfig/ConfigDictLike) | 
[kd.kontext.Key](kontext/Key) | 
[kd.kontext.KeyTree](kontext/KeyTree) | 
[kd.typing.Axes](typing/Axes) | 
[kd.typing.Callable](typing/Callable) | 
[kd.typing.DType](typing/DType) | 
[kd.typing.ElementSpec](typing/ElementSpec) | 
[kd.typing.Initializer](typing/Initializer) | 
[kd.typing.PRNGKey](typing/PRNGKey) | 
[kd.typing.PRNGKeyLike](typing/PRNGKeyLike) | 
[kd.typing.PyTree](typing/PyTree) | 
[kd.typing.Scalar](typing/Scalar) | 
[kd.typing.ScalarFloat](typing/ScalarFloat) | 
[kd.typing.ScalarInt](typing/ScalarInt) | 
[kd.typing.Schedule](typing/Schedule) | 
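
`kd.typing.Shape` is listed as a helper that constructs concrete shape tuples from shape-specs, with `kd.typing.Memo` holding the jaxtyping shape information of the current scope. As a rough illustration only, the sketch below resolves a space-separated spec such as `'b h w 3'` against a plain dict of known dimension sizes. The function `shape_from_spec` and the dict standing in for the memo are hypothetical; the real helpers work from jaxtyping's recorded shapes and likely support a richer spec syntax.

```python
def shape_from_spec(spec: str, known_dims: dict[str, int]) -> tuple[int, ...]:
    """Turns a spec like 'b h w 3' into a concrete tuple (hypothetical helper).

    Integer tokens are used as-is; named tokens are looked up in `known_dims`,
    which stands in for the shapes memoized in the current type-checking scope.
    """
    shape = []
    for token in spec.split():
        if token.isdigit():
            shape.append(int(token))
        elif token in known_dims:
            shape.append(known_dims[token])
        else:
            raise KeyError(f"Unknown dimension {token!r} in spec {spec!r}")
    return tuple(shape)


print(shape_from_spec("b h w 3", {"b": 8, "h": 32, "w": 32}))  # (8, 32, 32, 3)
```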


```{toctree}
:hidden:

_filter_logs/index
ckpts/index
contrib/index
data/index
evals/index
from_xid/index
inspect/index
kdash/index
knn/index
konfig/index
kontext/index
losses/index
metrics/index
nn/index
optim/index
random/index
sharding
summaries/index
testing/index
train/index
typing/index
utils/index
xm/index
```