kd.nn.PreNormBlock
- class kauldron.modules.PreNormBlock(
- attention: kauldron.modules.knn_types.AttentionModule,
- mlp: flax.linen.module.Module = <factory>,
- attention_norm: kauldron.modules.knn_types.NormModule = <factory>,
- mlp_norm: kauldron.modules.knn_types.NormModule = <factory>,
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
- name: str | None = None,
)
Bases: flax.linen.module.Module

Pre-LN Transformer layer (default transformer layer).
- attention: knn_types.AttentionModule
- mlp: nn.Module
- attention_norm: knn_types.NormModule
- mlp_norm: knn_types.NormModule
- classmethod from_spec(
- num_heads: int,
- mlp_size: int | None = None,
- normalize_qk: bool = True,
- attn_kwargs: dict[str, Any] | None = None,
- mlp_kwargs: dict[str, Any] | None = None,
- **kwargs,
)
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: Scope | None = None