kd.nn.AddEmbedding


class kauldron.modules.AddEmbedding(emb: kauldron.modules.knn_types.PositionEmbedding, axis: int | tuple[int, ...], parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: flax.linen.module.Module

Helper module that adds a PositionEmbedding to its inputs, e.g. as one layer inside a knn.Sequential.
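The wrapping pattern can be illustrated without Flax. The snippet below is a plain-NumPy analogue, not Kauldron's implementation: AddEmbeddingSketch, index_embedding and sequential are hypothetical names, and the embedding is assumed to be a callable taking (inputs, axis) and returning an array that broadcasts against the inputs.

```python
import dataclasses
from typing import Callable

import numpy as np

# Assumed interface (hypothetical): emb(inputs, axis) -> array broadcastable
# against inputs. Kauldron's real PositionEmbedding protocol may differ.
EmbeddingFn = Callable[[np.ndarray, int], np.ndarray]


def index_embedding(inputs: np.ndarray, axis: int) -> np.ndarray:
  """Toy embedding for axis=-2: row i of the (n, d) result is filled with i."""
  n, d = inputs.shape[axis], inputs.shape[-1]
  positions = np.arange(n, dtype=inputs.dtype)[:, None]  # shape (n, 1)
  return np.repeat(positions, d, axis=1)  # shape (n, d)


@dataclasses.dataclass
class AddEmbeddingSketch:
  """Adds emb(inputs, axis) to the inputs, mirroring AddEmbedding's role."""

  emb: EmbeddingFn
  axis: int

  def __call__(self, inputs: np.ndarray) -> np.ndarray:
    return inputs + self.emb(inputs, self.axis)


def sequential(*layers: Callable[[np.ndarray], np.ndarray]):
  """Hypothetical stand-in for a Sequential container: applies layers in order."""

  def apply(x: np.ndarray) -> np.ndarray:
    for layer in layers:
      x = layer(x)
    return x

  return apply


x = np.zeros((2, 3, 4), dtype=np.float32)  # shape (*b, n, d) with b=2, n=3, d=4
model = sequential(AddEmbeddingSketch(emb=index_embedding, axis=-2), lambda y: y * 2)
out = model(x)
print(out[1, :, 0])  # [0. 2. 4.]
```

Wrapping the embedding in its own single-input layer is what lets it sit in a sequential pipeline: the embedding itself needs the extra axis argument, and the wrapper supplies it from its stored attribute.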

emb

The position embedding to be added to the inputs.

Type:

kauldron.modules.knn_types.PositionEmbedding

axis

The axis parameter passed to the position embedding, which uses it to determine the embedding's shape. Usually set to -2, giving embeddings of shape (n, d) for inputs of shape (*b, n, d); when added, these broadcast over the batch dimensions *b.

Type:

int | tuple[int, …]
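How axis translates into an embedding shape is up to the PositionEmbedding implementation. As an illustration only, the hypothetical helper embedding_shape below assumes the embedding spans the selected axes plus the trailing feature dimension, with size-1 placeholders for unselected axes in between, so that the result broadcasts back to the input shape. For axis=-2 this reproduces the (n, d) case described above.

```python
from __future__ import annotations

import numpy as np


def embedding_shape(input_shape: tuple[int, ...], axis: int | tuple[int, ...]) -> tuple[int, ...]:
  """Hypothetical rule for the shape of a position embedding.

  Selected axes keep their size, unselected axes between the first selected
  axis and the feature axis become 1, and the feature size d is appended.
  """
  ndim = len(input_shape)
  axes = {axis % ndim} if isinstance(axis, int) else {a % ndim for a in axis}
  if ndim - 1 in axes:
    raise ValueError("axis must not include the trailing feature dimension")
  start = min(axes)
  spatial = tuple(input_shape[i] if i in axes else 1 for i in range(start, ndim - 1))
  return spatial + (input_shape[-1],)


print(embedding_shape((8, 16, 32), -2))        # (16, 32): the "n d" case
print(embedding_shape((8, 4, 5, 32), (1, 2)))  # (4, 5, 32)
print(embedding_shape((8, 4, 5, 32), 1))       # (4, 1, 32)
# Each result broadcasts back to the input shape:
print(np.broadcast_shapes((8, 4, 5, 32), embedding_shape((8, 4, 5, 32), 1)))  # (8, 4, 5, 32)
```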

emb: knn_types.PositionEmbedding
axis: Axes
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None