kd.nn.Vit
- class kauldron.modules.Vit(encoder: flax.linen.module.Module, num_classes: int = 1000, init_head_bias: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = jax.nn.initializers.zeros, mode: typing.Literal['gap', 'cls_token', 'cls_token_forced'] = 'gap', image: kontext.Key = '__KEY_REQUIRED__', parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: flax.linen.module.Module

Basic Vision Transformer classifier with GAP (global average pooling).
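A GAP classification head can be sketched as follows. This is a minimal NumPy illustration, not the Kauldron implementation. It assumes the encoder returns tokens of shape `(batch, num_tokens, dim)`, averages them over the token axis, and applies a dense layer whose bias starts at zero, mirroring the default `init_head_bias`. The function name `gap_head` and the small random kernel initialization are hypothetical.

```python
import numpy as np


def gap_head(tokens: np.ndarray, kernel: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Hypothetical GAP head: mean-pool tokens, then project to class logits.

    tokens: (batch, num_tokens, dim) encoder output (assumed layout).
    kernel: (dim, num_classes) dense weights.
    bias:   (num_classes,) dense bias.
    """
    pooled = tokens.mean(axis=1)  # global average pooling over the token axis
    return pooled @ kernel + bias  # (batch, num_classes) logits


rng = np.random.default_rng(0)
batch, num_tokens, dim, num_classes = 2, 16, 8, 1000
tokens = rng.normal(size=(batch, num_tokens, dim))
kernel = rng.normal(size=(dim, num_classes)) * 0.02  # assumed small random init
bias = np.zeros(num_classes)  # zero bias, like the default init_head_bias
logits = gap_head(tokens, kernel, bias)
print(logits.shape)  # (2, 1000)
```

With a zero bias, the initial logits depend only on the pooled features and the kernel.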
- encoder: nn.Module
- num_classes: int = 1000
- init_head_bias(key, shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>) -> jax.Array
An initializer that returns a constant array full of zeros.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
- mode: Literal['gap', 'cls_token', 'cls_token_forced'] = 'gap'
- image: kontext.Key = '__KEY_REQUIRED__'
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: Scope | None = None
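The `mode` field selects how token features are pooled before the classification head. The following is a minimal sketch of the two self-explanatory options. It assumes that `'gap'` averages all tokens and that `'cls_token'` reads a class token prepended at index 0, which is the usual ViT convention but is not confirmed by this page. `'cls_token_forced'` is not modeled because its exact behavior is not documented above. The helper `pool_tokens` is hypothetical.

```python
from typing import Literal

import numpy as np


def pool_tokens(tokens: np.ndarray, mode: Literal["gap", "cls_token"] = "gap") -> np.ndarray:
    """Hypothetical pooling step; tokens has shape (batch, num_tokens, dim)."""
    if mode == "gap":
        # Global average pooling: every token contributes equally.
        return tokens.mean(axis=1)
    if mode == "cls_token":
        # Assumes the class token was prepended by the encoder at position 0.
        return tokens[:, 0]
    raise ValueError(f"Unsupported mode in this sketch: {mode!r}")


tokens = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
print(pool_tokens(tokens, "gap")[0])        # [4. 5. 6. 7.]
print(pool_tokens(tokens, "cls_token")[0])  # [0. 1. 2. 3.]
```

Both options return a `(batch, dim)` array, so the same dense head can follow either one.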