kd.nn.VitEncoder
- class kauldron.modules.VitEncoder(layers: typing.Sequence[kauldron.modules.knn_types.TransformerBlock], embedding: kauldron.modules.knn_types.ImageTokenizer, pos_embedding: kauldron.modules.knn_types.PositionEmbedding = LearnedEmbedding( # attributes dtype = float32 emb_init = init emb_name = 'embeddings' ), encoder_norm: kauldron.modules.knn_types.NormModule | None = RMSNorm( # attributes epsilon = 1e-06 dtype = None param_dtype = float32 use_scale = True scale_init = ones reduction_axes = -1 feature_axes = -1 axis_name = None axis_index_groups = None use_fast_variance = True force_float32_reductions = True ), prepend_cls_token: bool = False, cls_token_init: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = <function zeros>, image: typing.Annotated[typing.Any, <object object at 0x7824c478ba80>] = '__KEY_REQUIRED__', parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: flax.linen.module.Module
Basic ViT Encoder.
Implements:
- extracting and linearly embedding patches
- adding a position embedding
- a sequence of Transformer (self-attention) layers
- a final RMS normalization
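The four stages listed above can be sketched in plain NumPy. This is an illustration of the data flow only, not Kauldron's implementation: the helper names (`patchify`, `rms_norm`, `self_attention_block`, `init_block`, `vit_encode`) and the parameter dictionary layout are hypothetical, and the block is simplified to a single attention head with a ReLU MLP and RMS pre-normalization.

```python
import numpy as np

def patchify(images, patch):
    """Split (B, H, W, C) images into flattened, non-overlapping patches."""
    b, h, w, c = images.shape
    gh, gw = h // patch, w // patch
    x = images.reshape(b, gh, patch, gw, patch, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, gh, gw, patch, patch, C)
    return x.reshape(b, gh * gw, patch * patch * c)

def rms_norm(x, scale, eps=1e-6):
    """Divide each token by the root-mean-square of its features."""
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps) * scale

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def self_attention_block(x, p):
    """Simplified single-head pre-norm block (assumption, not PreNormBlock)."""
    h = rms_norm(x, p["norm1"])
    q, k, v = h @ p["wq"], h @ p["wk"], h @ p["wv"]
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1]))
    x = x + (attn @ v) @ p["wo"]                       # attention + residual
    h = rms_norm(x, p["norm2"])
    return x + np.maximum(h @ p["w1"], 0.0) @ p["w2"]  # ReLU MLP + residual

def init_block(rng, d, mlp):
    w = lambda *shape: rng.normal(0.0, 0.02, shape)
    return dict(wq=w(d, d), wk=w(d, d), wv=w(d, d), wo=w(d, d),
                w1=w(d, mlp), w2=w(mlp, d),
                norm1=np.ones(d), norm2=np.ones(d))

def vit_encode(images, params, patch):
    tokens = patchify(images, patch) @ params["embed"]  # 1. linear patch embedding
    tokens = tokens + params["pos"]                     # 2. position embedding
    for block in params["blocks"]:                      # 3. transformer layers
        tokens = self_attention_block(tokens, block)
    return rms_norm(tokens, params["final_norm"])       # 4. final RMS normalization

rng = np.random.default_rng(0)
patch, d, mlp, num_layers = 4, 16, 32, 2
images = rng.normal(size=(2, 8, 8, 3))
num_tokens = (8 // patch) ** 2
params = dict(
    embed=rng.normal(0.0, 0.02, (patch * patch * 3, d)),
    pos=rng.normal(0.0, 0.02, (num_tokens, d)),
    blocks=[init_block(rng, d, mlp) for _ in range(num_layers)],
    final_norm=np.ones(d),
)
out = vit_encode(images, params, patch)
print(out.shape)  # (2, 4, 16)
```

Each 8x8 image yields four 4x4 patches, so the output holds four tokens of width 16 per example. In the real module each stage is a configurable submodule (`embedding`, `pos_embedding`, `layers`, `encoder_norm`) rather than a fixed function.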
- embedding
Submodule for embedding images into a flat set of tokens.
- Type:
kauldron.modules.knn_types.ImageTokenizer
- pos_embedding
Position Embeddings to add to the embedded tokens.
- Type:
kauldron.modules.knn_types.PositionEmbedding
- layers
Sequence of transformer blocks to apply.
- Type:
Sequence[kauldron.modules.knn_types.TransformerBlock]
- encoder_norm
Normalization to be applied at the end of the encoder.
- Type:
kauldron.modules.knn_types.NormModule | None
- prepend_cls_token
Whether to prepend a CLS token to the token sequence after the position embeddings have been added.
- Type:
bool
- cls_token_init
Initializer for the CLS token (if present).
- Type:
Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array]
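To illustrate how `prepend_cls_token` changes the token sequence, here is a small NumPy sketch. It assumes, based on the attribute description, that the CLS token is concatenated after the position embeddings are added and therefore receives no position embedding of its own; the variable names and shapes are made up for the example.

```python
import numpy as np

batch, num_patches, dim = 2, 4, 8
tokens = np.ones((batch, num_patches, dim))        # embedded patches
pos_emb = np.full((num_patches, dim), 0.5)         # one embedding per patch

tokens = tokens + pos_emb                          # position embedding first
cls_token = np.zeros((1, 1, dim))                  # zeros, like the default cls_token_init
cls = np.broadcast_to(cls_token, (batch, 1, dim))  # same token for every example
tokens = np.concatenate([cls, tokens], axis=1)     # then prepend: (B, N + 1, D)

print(tokens.shape)  # (2, 5, 8)
```

Under this assumption the sequence grows by one token, so downstream code that reshapes tokens back into a spatial grid needs to drop index 0 first.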
- layers: Sequence[knn_types.TransformerBlock]
- embedding: knn_types.ImageTokenizer
- pos_embedding: knn_types.PositionEmbedding = LearnedEmbedding( # attributes dtype = float32 emb_init = init emb_name = 'embeddings' )
- encoder_norm: knn_types.NormModule | None = RMSNorm( # attributes epsilon = 1e-06 dtype = None param_dtype = float32 use_scale = True scale_init = ones reduction_axes = -1 feature_axes = -1 axis_name = None axis_index_groups = None use_fast_variance = True force_float32_reductions = True )
- prepend_cls_token: bool = False
- cls_token_init(shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>) -> jax.Array
An initializer that returns a constant array full of zeros.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
- image: kontext.Key = '__KEY_REQUIRED__'
- classmethod from_variant_str(variant_str: str, **kwargs)
- classmethod from_spec(num_heads: int, hidden_size: int, num_layers: int, patch_size: int | tuple[int, int], mlp_size: int | None = None, block_type=<class 'kauldron.modules.transformers.PreNormBlock'>, dtype=<class 'jax.numpy.float32'>, **kwargs)
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: Scope | None = None
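Because `from_spec` accepts `patch_size` as either an int or a `(height, width)` tuple, it can help to work out how many tokens a given configuration produces. The helper below is hypothetical (it is not part of Kauldron) and assumes non-overlapping patches and an image size divisible by the patch size, which the actual tokenizer may or may not require.

```python
def num_tokens(image_hw, patch_size, prepend_cls_token=False):
    """Count encoder tokens for non-overlapping patches (hypothetical helper)."""
    # Accept a single int or a (height, width) pair, mirroring from_spec's patch_size.
    ph, pw = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
    h, w = image_hw
    if h % ph or w % pw:
        raise ValueError("image size must be divisible by the patch size")
    # One token per patch, plus one if a CLS token is prepended.
    return (h // ph) * (w // pw) + int(prepend_cls_token)

print(num_tokens((224, 224), 16))                          # 196
print(num_tokens((224, 224), 16, prepend_cls_token=True))  # 197
print(num_tokens((224, 448), (16, 32)))                    # 196
```

The sequence length drives attention cost quadratically, so halving the patch size in both dimensions quadruples the token count.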