kd.nn.VitEncoder

class kauldron.modules.VitEncoder(
layers: typing.Sequence[kauldron.modules.knn_types.TransformerBlock],
embedding: kauldron.modules.knn_types.ImageTokenizer,
pos_embedding: kauldron.modules.knn_types.PositionEmbedding = LearnedEmbedding(dtype=float32, emb_init=init, emb_name='embeddings'),
encoder_norm: kauldron.modules.knn_types.NormModule | None = RMSNorm(epsilon=1e-06, dtype=None, param_dtype=float32, use_scale=True, scale_init=ones, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True),
prepend_cls_token: bool = False,
cls_token_init: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = <function zeros>,
image: typing.Annotated[typing.Any, <object object at 0x76412092fb90>] = '__KEY_REQUIRED__',
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None,
)

Bases: flax.linen.module.Module

Basic ViT (Vision Transformer) encoder.

Implements:

- extracting and linearly embedding patches
- adding a position embedding
- a sequence of Transformer (self-attention) layers
- a final RMS normalization
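The four stages listed above can be illustrated with a small NumPy sketch. This is not Kauldron's implementation: the helper names (`patchify`, `rms_norm`, `encode`), the parameter dictionary, and the identity functions standing in for transformer blocks are all assumptions made for illustration. Only the order of operations and the RMSNorm default `epsilon = 1e-06` are taken from this page.

```python
import numpy as np

def patchify(images, patch):
    """Split (B, H, W, C) images into flattened, non-overlapping patches."""
    b, h, w, c = images.shape
    gh, gw = h // patch, w // patch
    x = images.reshape(b, gh, patch, gw, patch, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, gh, gw, patch, patch, C)
    return x.reshape(b, gh * gw, patch * patch * c)

def rms_norm(x, scale, epsilon=1e-6):
    """Normalize by the root mean square over the feature axis, then scale."""
    mean_sq = np.mean(np.square(x), axis=-1, keepdims=True)
    return x / np.sqrt(mean_sq + epsilon) * scale

def encode(images, params, blocks, patch):
    # 1. Extract patches and embed them linearly.
    tokens = patchify(images, patch) @ params["w_embed"] + params["b_embed"]
    # 2. Add the position embedding (broadcast over the batch axis).
    tokens = tokens + params["pos_emb"]
    # 3. Apply the sequence of transformer blocks.
    for block in blocks:
        tokens = block(tokens)
    # 4. Apply the final RMS normalization.
    return rms_norm(tokens, params["norm_scale"])

rng = np.random.default_rng(0)
patch, hidden = 4, 8
images = rng.normal(size=(2, 16, 16, 3))
num_tokens = (16 // patch) ** 2
params = {
    "w_embed": rng.normal(size=(patch * patch * 3, hidden)) * 0.02,
    "b_embed": np.zeros(hidden),
    "pos_emb": rng.normal(size=(num_tokens, hidden)) * 0.02,
    "norm_scale": np.ones(hidden),
}
blocks = [lambda t: t, lambda t: t]  # identity stand-ins for real transformer blocks
out = encode(images, params, blocks, patch)
print(out.shape)  # (2, 16, 8)
```

A 16x16 image with 4x4 patches yields 16 tokens per image, so the output has shape (batch, 16, hidden).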

embedding

Submodule for embedding images into a flat set of tokens.

Type:

kauldron.modules.knn_types.ImageTokenizer

pos_embedding

Position Embeddings to add to the embedded tokens.

Type:

kauldron.modules.knn_types.PositionEmbedding

layers

Sequence of transformer blocks to apply.

Type:

Sequence[kauldron.modules.knn_types.TransformerBlock]

encoder_norm

Normalization to be applied at the end of the encoder.

Type:

kauldron.modules.knn_types.NormModule | None

prepend_cls_token

Whether to prepend a CLS token to the token sequence after the position embeddings have been added.

Type:

bool
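As a rough illustration of what prepending a CLS token does to the token sequence, the NumPy sketch below concatenates one shared vector in front of the patch tokens. The helper `prepend_cls` and the parameter shape are assumptions, not Kauldron's code; the zero initialization mirrors the documented default of `cls_token_init`.

```python
import numpy as np

def prepend_cls(tokens, cls_token):
    """Prepend one shared CLS vector to every sequence in the batch."""
    b, _, d = tokens.shape
    cls = np.broadcast_to(cls_token.reshape(1, 1, d), (b, 1, d))
    return np.concatenate([cls, tokens], axis=1)

# Patch tokens, assumed to already include their position embeddings.
tokens = np.ones((2, 16, 8))
# Zero-initialized CLS vector (hypothetical shape: one vector of size hidden).
cls_token = np.zeros(8)

with_cls = prepend_cls(tokens, cls_token)
print(with_cls.shape)  # (2, 17, 8)
```

The sequence length grows by one, and in this sketch the CLS token itself receives no position embedding because it is added afterwards.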

cls_token_init

Initializer for the CLS token (if present).

Type:

Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array]

layers: Sequence[knn_types.TransformerBlock]
embedding: knn_types.ImageTokenizer
pos_embedding: knn_types.PositionEmbedding = LearnedEmbedding(dtype=float32, emb_init=init, emb_name='embeddings')
encoder_norm: knn_types.NormModule | None = RMSNorm(epsilon=1e-06, dtype=None, param_dtype=float32, use_scale=True, scale_init=ones, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True)
prepend_cls_token: bool = False
cls_token_init(
key,
shape: collections.abc.Sequence[int | Any],
dtype: Any | None = None,
out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None,
) -> jax.Array

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
image: kontext.Key = '__KEY_REQUIRED__'
classmethod from_variant_str(
variant_str: str,
**kwargs,
) -> kauldron.modules.vit.VitEncoder
classmethod from_spec(
num_heads: int,
hidden_size: int,
num_layers: int,
patch_size: int | tuple[int, int],
mlp_size: int | None = None,
block_type=<class 'kauldron.modules.transformers.PreNormBlock'>,
dtype=<class 'jax.numpy.float32'>,
**kwargs,
)
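To make the `patch_size: int | tuple[int, int]` parameter concrete, the sketch below computes how many patch tokens an image would produce for either form. `num_patch_tokens` is a hypothetical helper, not part of Kauldron, and the divisibility check is an assumption about how non-overlapping patches are usually handled rather than documented behavior.

```python
def num_patch_tokens(image_hw, patch_size):
    """Hypothetical helper: count non-overlapping patches in an image.

    patch_size may be a single int (square patches) or a (height, width) tuple,
    matching the two forms accepted by the from_spec signature.
    """
    if isinstance(patch_size, int):
        ph, pw = patch_size, patch_size
    else:
        ph, pw = patch_size
    h, w = image_hw
    # Assumption: the image must divide evenly into patches.
    if h % ph or w % pw:
        raise ValueError(f"image size {image_hw} is not divisible by patch size {(ph, pw)}")
    return (h // ph) * (w // pw)

print(num_patch_tokens((224, 224), 16))        # 196
print(num_patch_tokens((224, 160), (16, 32)))  # 70
```

With `prepend_cls_token=True`, the sequence seen by the transformer blocks would be one token longer than these counts.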
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None