kd
import kauldron as kd
Kauldron public API.
from kauldron import kd
All symbols
Module
Kauldron public API.
Helper to filter logs from verbose modules.
Checkpoints API.
Checkpoint handler.
Contrib public API.
Data modules.
Dataset iterators.
PyGrain public API.
Evaluator.
Helper for loading configs,… from XManager experiments.
Inspect utils.
Small library for creating Flatboard dashboards.
Wrapper around
Wrapper around
Default values and configuration.
Kontext is a small self-contained library to manipulate nested trees.
Losses.
Metrics.
Collection of nn.Modules to build neural networks.
Optimizers etc.
Small wrapper around
Summaries.
Deprecated summaries.
Testing utilities.
Train.
Common typing annotations.
Utils public API.
XManager utils. |
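The Kontext module above is described as a library for manipulating nested trees. To illustrate path-based access, here is a minimal pure-Python sketch that resolves a dotted path with integer indices (for example "batch.label[1]") against nested mappings, sequences and plain objects. The names parse_path and get_by_path and the exact path grammar are assumptions made for illustration. They are not Kontext's actual API.

```python
import re
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any, Union

# Hypothetical path grammar (an assumption, not Kontext's parser):
# dotted identifiers, each optionally followed by integer indices, e.g. "a.b[0]".
_TOKEN = re.compile(r"\.?([A-Za-z_]\w*)|\[(\d+)\]")


def parse_path(path: str) -> list[Union[str, int]]:
    """Split "a.b[0]" into ["a", "b", 0]."""
    parts: list[Union[str, int]] = []
    pos = 0
    while pos < len(path):
        match = _TOKEN.match(path, pos)
        if match is None:
            raise ValueError(f"Cannot parse {path!r} at position {pos}")
        name, index = match.groups()
        # Identifiers stay strings; bracketed indices become ints.
        parts.append(name if name is not None else int(index))
        pos = match.end()
    return parts


def get_by_path(tree: Any, path: str) -> Any:
    """Walk `tree` following `path`, using item access or getattr."""
    node = tree
    for part in parse_path(path):
        if isinstance(part, int) or isinstance(node, Mapping):
            node = node[part]  # Sequence index or mapping key.
        else:
            node = getattr(node, part)  # Attribute on a plain object.
    return node


# Usage: a context mixing dicts, lists and an object with attributes.
context = {
    "batch": {"image": [[0.1, 0.2]], "label": [3, 7]},
    "step_info": SimpleNamespace(step=100),
}
print(get_by_path(context, "batch.label[1]"))  # 7
print(get_by_path(context, "step_info.step"))  # 100
```

Falling back to getattr for non-mapping nodes lets one path syntax reach both dict entries and object attributes. A fuller implementation would likely also need quoted string keys in brackets and more helpful error messages.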
Class
Abstract class for partial checkpoint loaders.
Wrapper around Orbax CheckpointManager.
Transform which applies multiple transformations sequentially.
Does nothing.
Partial loader for Kauldron checkpoints.
Interface for a checkpoint item.
Standard checkpoint item (for arbitrary
Checkpoint item that contains other sub-checkpoint items.
Adds constant elements.
Batch size.
Cast an element to the specified dtype.
Base class for elementwise transforms.
Modify the elements by keeping xor dropping and/or renaming and/or copying.
Abstract base class for filter transformations for individual elements.
Gathers entries along a single dimension.
Pipeline which fits in memory.
General interface for iterable datasets.
Abstract base class for all 1:1 transformations of elements.
Base class for Kauldron data pipelines.
Einops rearrange on a single element.
Resizes an image.
Flatten any tree-structured elements.
Map the value range of an element.
Wrapper around a dataset iterator.
Handler that is not checkpointable.
PyGrain iterator.
Checkpointable
Adds constant elements.
Cast an element to the specified dtype.
Generic loader of arbitrary Grain data sources.
Base class to implement a data source.
Base class for elementwise transforms.
Modify the elements by keeping xor dropping and/or renaming and/or copying.
Gathers entries along a single dimension.
HuggingFace loader.
JSON pipeline.
Create a dataset mixture.
Abstract base class to construct a PyGrain data pipeline.
Einops rearrange on a single element.
Resizes an image.
Transform which selects a subset of the dataset.
Base TFDS loader.
Flatten any tree-structured elements.
Map the value range of an element.
Names of the metrics/summaries/losses (displayed in Flatboard).
Evaluator running
Base class for inline evaluators.
Run eval every N train steps.
FewShotEvaluator running closed-form few-shot classification.
Run eval only after the
Base class for info on how to run the evaluation.
Run eval continuously every time a new checkpoint is found.
Run eval only after the last checkpoint, after train has completed.
Context for building the dashboard.
Flatboard dashboard structure.
Standard
Container of multiple dashboards.
Empty dashboard.
Single plot inside a dashboard.
Single dashboard containing multiple plots.
Dense(features: int, use_bias: bool = True, dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any, NoneType] = None, param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any] = <class 'jax.numpy.float32'>, precision: Union[NoneType, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]] = None, kernel_init: Union[jax.nn.initializers.Initializer, collections.abc.Callable[..., Any]] = <function variance_scaling.
Dropout(rate: float, broadcast_dims: collections.abc.Sequence[int] = (), deterministic: bool
Base Module class.
Sequential(layers: collections.abc.Sequence[collections.abc.Callable[..., typing.Any]], parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None, *, _kd_state: 'Optional[_ModuleState]' = None)
Wrapper around ConfigDict.
Protocol to better access lazy fields.
Represents a string path.
Represents a (non-glob) string path.
Absolute value loss.
Huber loss.
L1 loss.
L2 loss.
Base class for losses which handles masks, averaging, and loss-weight.
Negative cosine similarity loss.
Sigmoid cross-entropy loss with binary labels.
Softmax cross-entropy loss.
Softmax cross-entropy loss with integer labels.
Value loss.
Classification accuracy.
Adjusted Rand Index (ARI) computed from predictions and labels.
Flexible base class for conveniently defining custom states.
Computes the average of a scalar or a batch of tensors.
Classification accuracy for binary classification tasks.
Get the first outputs (possibly) across multiple steps (no reducing).
Accumulate outputs across multiple steps (without reducing).
Empty state.
VGG LPIPS.
Base class for metrics.
Metric that does nothing. Can be used in sweeps to remove a metric.
Wraps jnp.linalg.norm to compute the average norm for given tensors.
Precision@1 for multilabel classification.
PSNR.
Area Under the Receiver Operating Characteristic Curve (ROC AUC).
Returns a single chosen dimension of the tensor.
Skip this metric if any of the keys are missing.
Structural similarity (SSIM).
Base metric state class.
Compute the standard deviation for float values.
Maps an inner metric to a pytree and returns a pytree of results.
Applies a metric to a pytree and returns the aggregated result.
Helper Module for adding a PositionEmbedding e.g. in a
Adds learned positional embeddings to the inputs.
Interface specification for Attention modules.
Wrapper around
Empty model that ignores inputs and always produces a single logit of 42.
Module that is defined outside Kauldron.
Very simple auto-encoder class to showcase using keys and submodules.
Apply Fourier position embedding to a grid of coordinates.
Module that applies the identity function to a single tensor.
Interface for modules that convert images into tokens.
Multi-head dot-product attention.
Learned positional embeddings.
Wrapper around
Interface specification for norm modules (to be used as type annotation).
Parallel self attention (see ViT-22B paper: arxiv.org/abs/2302.05442).
Patchify an image, as in ViT (without linear embedding).
Patchify and linearly embed an image, as in ViT.
Post-LN Transformer layer (not recommended).
Pre-LN Transformer layer (default transformer layer).
Wrapper around
Wrapper around
Like nn.Sequential but allows configuring input and output keys.
Interface definition for transformer blocks (for use in type annotations).
Simple MLP with a single hidden layer for use in Transformer blocks.
Basic Vision Transformer classifier with GAP.
Basic ViT encoder.
Base class to wrap a module.
Embedding that returns zero (for deactivating position embeddings).
Small wrapper around
Output type for histogram summaries.
Basic histogram summary.
Output type for point cloud summaries.
Show a set of boxes with optional image reshaping.
Show a set of difference images with optional reshaping.
Show image summaries with optional reshaping.
Show a point cloud with optional reshaping.
Show a set of segmentations with optional reshaping.
Show texts.
Deprecated ImageSummary. Raises an error if instantiated.
Deprecated Summaries. Raises an error if instantiated.
Wrapper around the losses, summaries and metrics.
Final values of the auxiliaries (after merge and compute).
Auxiliaries (intermediate states to be accumulated).
Namespace for retrieving information with path-based keys.
Info on one
Manager of rng streams.
Setup/environment options.
TqdmInfo(*, desc: 'str' = 'train', log_xm: 'bool' = True)
Data structure for checkpointing the model.
Base training step.
Base trainer class.
Describes an array via its dtype and shape.
Jaxtyping information about the shapes in the current scope.
All the operations on a read-only sequence.
Helper to construct concrete shape tuples from shape-specs.
Indicates a runtime typechecking error from the @typechecked decorator.
XManager experiment wrapper.
XManager work unit wrapper. |
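Several classes above separate an intermediate state, which is accumulated across steps, from a final value obtained after merge and compute. Examples include the base metric state class, the averaging metric and the two auxiliaries containers. Below is a minimal pure-Python sketch of that pattern for an average. It assumes a state that stores a running sum and count. The AverageState name and its methods are hypothetical, and Kauldron's real states operate on JAX arrays rather than Python floats.

```python
import dataclasses
import math
from typing import Optional, Sequence


@dataclasses.dataclass(frozen=True)
class AverageState:
    """Hypothetical metric state: running sum and count of (unmasked) values."""

    total: float = 0.0
    count: int = 0

    @classmethod
    def from_values(
        cls, values: Sequence[float], mask: Optional[Sequence[bool]] = None
    ) -> "AverageState":
        """Builds a state from one step's values; masked-out values are ignored."""
        if mask is None:
            mask = [True] * len(values)
        kept = [v for v, keep in zip(values, mask) if keep]
        return cls(total=float(sum(kept)), count=len(kept))

    def merge(self, other: "AverageState") -> "AverageState":
        """Combines two states; sums and counts add, so merging is associative."""
        return AverageState(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        """Final value: the mean over everything merged so far (NaN if empty)."""
        return self.total / self.count if self.count else math.nan


# Usage: accumulate two "steps", then compute once at the end.
step1 = AverageState.from_values([1.0, 2.0, 3.0])
step2 = AverageState.from_values([10.0, 20.0], mask=[True, False])
print(step1.merge(step2).compute())  # 4.0
```

Storing (total, count) instead of a per-step mean is what keeps merging correct. Averaging the two per-step means above would give (2.0 + 10.0) / 2 = 6.0 instead of 4.0.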
Function
Add a filter to the absl logging handler.
Returns the config/sub-config from an XManager experiment.
Returns the element_spec of the train dataset of an XManager experiment.
Returns the resolved config/sub-config from an XManager experiment.
Returns the workdir of an XManager experiment.
Return
Return
Build the graphviz.
Convert
Returns lowered trainerstep.step.
Display batch images.
Display the context structure.
Overview plot for (nested) dict of schedules.
Plot sharding.
Display various plots on the trainer.
Create the dashboards.
Decorator that converts a flax class into klinen.
Defines flag for
Context manager which replaces import statements with ConfigDicts.
Context manager which replaces a list of modules with ConfigDictProxyObjects.
Defines an entry in a ConfigDict that has no value yet.
One-way recursive copy of the
Wrap a function for lazy evaluation.
Register module aliases for nicer display.
Register default values when creating the ConfigDict.
Defines a required attribute in the config that has no value yet.
Recursively parses a nested ConfigDict and resolves module constructors.
Set which modules inside
Filters a context by a path.
Flatten any PyTree / ConfigDict into a dict with 'keys.like[0].this'.
Get (nested) item or attribute by given path.
Return a dictionary mapping Key-annotated fieldnames to their paths.
Check if a given class or instance has fields annotated with
Create a path builder from a class.
Resolve the Key annotations of an object for given context.
Get values for key_paths from context with useful errors when failing.
Mutate the
Compute all losses based on given context.
Defines an AutoState data-field that is merged by concatenation.
Define an AutoState static field.
Define an AutoState data-field that is merged by summation (a + b).
Defines an AutoState data-field that is merged by truncation.
Convert inputs to Fourier features, e.g. for positional encoding.
Set the
Add (params - init_params) scaled by
Create a mask which selects all nodes except the ones matching the pattern.
Wraps optax.named_chain and allows passing transformations as kwargs.
Applies the optimizer to a subset of the parameters.
Create a mask which selects only the sub-pytree matching the pattern.
Check that the train step runs correctly (fast).
Forward pass of the model.
Forward pass of the model, including losses.
Helper to construct concrete Dim (for single-axis Shape).
Ensure that value matches expected_type, alias for typeguard.check_type.
Enable custom type checking for Kauldron types.
Validates the given shape and sets any previously unknown shapes.
Decorator to enable runtime type-checking and shape-checking.
Add a link to the kd-infer colab.
Add XManager artifacts for easy access to the Python logs.
Add tags to the XManager experiment.
Loads a config from a path. |
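One of the functions above flattens a PyTree or ConfigDict into a dict keyed by paths such as keys.like[0].this. The minimal sketch below reproduces that key format for plain dicts, lists and tuples only. The real function also handles arbitrary PyTrees and ConfigDict, and flatten_with_paths is a hypothetical name chosen for illustration.

```python
from collections.abc import Mapping
from typing import Any


def flatten_with_paths(tree: Any, prefix: str = "") -> dict[str, Any]:
    """Hypothetical helper: flattens dicts/lists/tuples into {"a.b[0].c": leaf}."""
    flat: dict[str, Any] = {}
    if isinstance(tree, Mapping):
        for key, value in tree.items():
            # Mapping keys are joined with dots; the first key has no leading dot.
            path = f"{prefix}.{key}" if prefix else str(key)
            flat.update(flatten_with_paths(value, path))
    elif isinstance(tree, (list, tuple)):
        for i, value in enumerate(tree):
            # Sequence positions are written as [i] directly after the prefix.
            flat.update(flatten_with_paths(value, f"{prefix}[{i}]"))
    else:
        # Anything else (numbers, strings, arrays, ...) is treated as a leaf.
        flat[prefix] = tree
    return flat


config = {"keys": {"like": [{"this": 1}, 2]}, "lr": 0.1}
print(flatten_with_paths(config))
# {'keys.like[0].this': 1, 'keys.like[1]': 2, 'lr': 0.1}
```

In this sketch, empty containers produce no entries, so the flattening cannot be exactly inverted. The source does not say whether the real function behaves the same way.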