kd.nn.MultiHeadDotProductAttention
- class kauldron.modules.MultiHeadDotProductAttention(
- num_heads: int,
- dtype: Dtype | None = None,
- param_dtype: Dtype = <class 'jax.numpy.float32'>,
- qkv_features: int | None = None,
- out_features: int | None = None,
- broadcast_dropout: bool = True,
- dropout_rate: float = 0.0,
- deterministic: bool | None = None,
- precision: PrecisionLike = None,
- kernel_init: Initializer = <function variance_scaling.<locals>.init>,
- out_kernel_init: Initializer | None = None,
- bias_init: Initializer = <function zeros>,
- out_bias_init: Initializer | None = None,
- use_bias: bool = True,
- attention_fn: Callable[..., Array] = <function dot_product_attention>,
- decode: bool = False,
- normalize_qk: bool = False,
- force_fp32_for_softmax: bool = False,
- qkv_dot_general: DotGeneralT | None = None,
- out_dot_general: DotGeneralT | None = None,
- qkv_dot_general_cls: Any = None,
- out_dot_general_cls: Any = None,
- qk_attn_weights_einsum_cls: Callable[..., Callable[..., Array]] | None = None,
- attn_weights_value_einsum_cls: Callable[..., Callable[..., Array]] | None = None,
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
- name: str | None = None,
)
Bases: flax.linen.attention.MultiHeadDotProductAttention
Wrapper around nn.MultiHeadDotProductAttention using knn.train_property.
- property is_training: bool
Whether the module is in training mode, as provided by knn.train_property.
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: Scope | None = None