kd.nn.ParallelAttentionBlock

kd.nn.ParallelAttentionBlock

class kauldron.modules.ParallelAttentionBlock(
attention: kauldron.modules.knn_types.AttentionModule,
mlp: flax.linen.module.Module = <factory>,
attention_norm: kauldron.modules.knn_types.NormModule = <factory>,
mlp_norm: kauldron.modules.knn_types.NormModule = <factory>,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None,
)[source]

Bases: flax.linen.module.Module

Parallel self-attention block (see the ViT-22B paper: https://arxiv.org/abs/2302.05442).

attention: knn_types.AttentionModule
mlp: nn.Module
attention_norm: knn_types.NormModule
mlp_norm: knn_types.NormModule
classmethod from_spec(
num_heads: int,
mlp_size: int | None = None,
normalize_qk: bool = True,
attn_kwargs: dict[str, Any] | None = None,
mlp_kwargs: dict[str, Any] | None = None,
**kwargs,
) → kauldron.modules.transformers.ParallelAttentionBlock[source]
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None