kd.nn.MultiHeadDotProductAttention

kd.nn.MultiHeadDotProductAttention

class kauldron.modules.MultiHeadDotProductAttention(
num_heads: int,
dtype: Dtype | None = None,
param_dtype: Dtype = jax.numpy.float32,
qkv_features: int | None = None,
out_features: int | None = None,
broadcast_dropout: bool = True,
dropout_rate: float = 0.0,
deterministic: bool | None = None,
precision: PrecisionLike = None,
kernel_init: Initializer = <function variance_scaling.<locals>.init>,
out_kernel_init: Initializer | None = None,
bias_init: Initializer = <function zeros>,
out_bias_init: Initializer | None = None,
use_bias: bool = True,
attention_fn: Callable[..., Array] = <function dot_product_attention>,
decode: bool = False,
normalize_qk: bool = False,
force_fp32_for_softmax: bool = False,
qkv_dot_general: DotGeneralT | None = None,
out_dot_general: DotGeneralT | None = None,
qkv_dot_general_cls: Any = None,
out_dot_general_cls: Any = None,
qk_attn_weights_einsum_cls: Callable[..., Callable[..., Array]] | None = None,
attn_weights_value_einsum_cls: Callable[..., Callable[..., Array]] | None = None,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None,
)

Bases: flax.linen.attention.MultiHeadDotProductAttention

Wrapper around `flax.linen.MultiHeadDotProductAttention` that adds an `is_training` property backed by `knn.train_property`.

property is_training: bool

Whether the module is in training mode, as provided by `knn.train_property`.

name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None