kd.nn.LearnedEmbedding

kd.nn.LearnedEmbedding

class kauldron.modules.LearnedEmbedding(dtype: str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType = <class 'jax.numpy.float32'>, emb_init: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = <function normal.<locals>.init>, emb_name: str = 'embeddings', parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]

Bases: flax.linen.module.Module

Learned positional embeddings.

Implements the knn_types.PositionEmbedding protocol.

emb_init

Initializer for the position embeddings.

Type:

Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array]

dtype

DType of the position embedding. Defaults to float32.

Type:

str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType

dtype

alias of jax.numpy.float32

emb_init(
shape: collections.abc.Sequence[int | Any],
dtype: Any | None = None,
out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None,
) -> jax.Array
emb_name: str = 'embeddings'
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None