kd.nn.AddLearnedEmbedding

class kauldron.modules.AddLearnedEmbedding(emb_init: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = <function normal.<locals>.init>, axes: int | tuple[int, ...] = (-2, -1), parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: flax.linen.module.Module

Adds learned positional embeddings to the inputs.

DEPRECATED: This module is deprecated in favor of the new LearnedEmbedding.

emb_init

Positional embedding initializer.

Type:

Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array]
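Any callable with the `(key, shape, dtype) -> Array` calling convention should fit this slot. The sketch below is a minimal illustration under that assumption: `jax.nn.initializers.normal` is a real JAX initializer factory (the class signature's default, `normal.<locals>.init`, suggests a `normal` initializer is the default, although its `stddev` is not documented here), while `zeros_init` is a hypothetical custom initializer written only for illustration.

```python
import jax
import jax.numpy as jnp

# A real JAX initializer factory; the returned function takes (key, shape, dtype).
normal_init = jax.nn.initializers.normal(stddev=0.02)


# Hypothetical custom initializer following the same (key, shape, dtype) convention.
def zeros_init(key, shape, dtype=jnp.float32):
    del key  # Unused: zeros need no randomness.
    return jnp.zeros(shape, dtype)


key = jax.random.PRNGKey(0)
emb = normal_init(key, (196, 768), jnp.float32)  # e.g. 196 tokens x 768 features
print(emb.shape, emb.dtype)  # (196, 768) float32
print(float(zeros_init(key, (2, 3)).sum()))  # 0.0
```

Either function could then be passed as `emb_init=...` when constructing the module, matching the parameter in the class signature.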

axes

One or more input axes to include in the embedding shape. The feature axis (-1) is always included automatically and should not be passed explicitly.

Type:

int | tuple[int, ...]
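One plausible reading of `axes` is that the learned embedding keeps the input's size along each listed axis (plus the feature axis) and has size 1 elsewhere, so that it broadcasts over the remaining axes such as the batch axis. The helper below sketches that reading; `embedding_shape` is a hypothetical name, and the behavior is an assumption rather than Kauldron's implementation. Because the axes are normalized into a set, listing the feature axis explicitly, as the default `(-2, -1)` does, is harmless in this sketch.

```python
def embedding_shape(input_shape, axes):
    """Hypothetical: shape of a learned embedding that broadcasts against inputs."""
    ndim = len(input_shape)
    if isinstance(axes, int):
        axes = (axes,)
    # Normalize negative axes to positive indices; always keep the feature axis.
    keep = {axis % ndim for axis in axes} | {ndim - 1}
    # Keep the input size on selected axes, use 1 (broadcast) everywhere else.
    return tuple(size if i in keep else 1 for i, size in enumerate(input_shape))


print(embedding_shape((8, 196, 768), (-2, -1)))     # (1, 196, 768)
print(embedding_shape((8, 4, 196, 768), -2))        # (1, 1, 196, 768)
print(embedding_shape((8, 14, 14, 768), (-3, -2)))  # (1, 14, 14, 768)
```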

Returns:

Array with the same shape as the input.
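The "same shape as the input" contract follows from broadcasting: an embedding with size 1 on the batch axis can be added to every example without changing the input's shape. The pure-JAX sketch below illustrates this outside Flax; `add_learned_embedding` is a hypothetical function, and the assumed parameter shape (learned over the token and feature axes) mirrors the default `axes=(-2, -1)` rather than a verified reading of the module's code.

```python
import jax
import jax.numpy as jnp


def add_learned_embedding(inputs, emb):
    """Hypothetical stand-in for the module's forward pass: a broadcast add."""
    return inputs + emb


key = jax.random.PRNGKey(0)
x = jnp.ones((8, 196, 768))  # (batch, tokens, features)

# Assumed parameter shape: size 1 on the batch axis, learned over tokens/features.
emb = jax.nn.initializers.normal(stddev=0.02)(key, (1, 196, 768), jnp.float32)

y = add_learned_embedding(x, emb)
print(y.shape)  # (8, 196, 768)
# Every batch element receives the same embedding.
print(bool(jnp.allclose(y[0], y[7])))  # True
```

In the module itself, the embedding is presumably held as a learned Flax parameter rather than passed in by hand as it is here.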

emb_init(shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>, out_sharding=None) → jax.Array
axes: Axes = (-2, -1)
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None