kd.nn.ImprovedMultiHeadDotProductAttention
- class kauldron.modules.ImprovedMultiHeadDotProductAttention(num_heads: int, qk_features: int | None = None, v_features: int | None = None, out_features: int | None = None, softmax_axis: int | tuple[int, ...] = -1, normalize_qk: bool = False, kernel_init: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = <function variance_scaling.<locals>.init>, bias_init: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = <function zeros>, use_bias: bool = True, attn_weights_fn: typing.Callable[[...], jaxtyping.Float[Array, '...'] | jaxtyping.Float[ndarray, '...']] = <function dot_product_attention_weights>, decode: bool = False, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: flax.linen.module.Module
Multi-head dot-product attention.
- Simplified nn.MultiHeadDotProductAttention with a few modifications:
includes a configurable softmax axis
accepts an (additive) bias for the attention weights (in addition to a mask)
drops support for dropout
adds the attention weights to interms as “interms.PATH.TO.LAYER.attn_weights”
- num_heads
Number of attention heads.
- Type:
int
- qk_features
Total dimension of the keys and queries.
- v_features
Total dimension of the values. Defaults to qk_features.
- softmax_axis
The axis (or axes) over which the softmax is taken. Defaults to -1, which is the keys axis. For Slot Attention, set it to -2 (the queries axis).
- Type:
int | tuple[int, …]
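The effect of softmax_axis can be illustrated with plain JAX, independent of this module. The sketch below assumes attention logits laid out as [h, q, k] (the logits array and its sizes are made up for illustration) and only compares the two normalization choices; it does not call the module itself.

```python
import jax
import jax.numpy as jnp

# Hypothetical logits with shape [h, q, k]: 2 heads, 3 queries, 5 keys.
logits = jax.random.normal(jax.random.key(0), (2, 3, 5))

# softmax_axis=-1: normalize over keys, so each query's weights sum to 1.
over_keys = jax.nn.softmax(logits, axis=-1)

# softmax_axis=-2: normalize over queries, so each key's weights sum to 1
# (the queries compete for each key, as in Slot Attention).
over_queries = jax.nn.softmax(logits, axis=-2)

print(over_keys.sum(axis=-1).shape)     # one sum per (head, query)
print(over_queries.sum(axis=-2).shape)  # one sum per (head, key)
```

With axis -1 every query distributes a unit of attention across the keys; with axis -2 every key distributes a unit of attention across the queries.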
- num_heads: int
- qk_features: int | None = None
- v_features: int | None = None
- out_features: int | None = None
- softmax_axis: Axes = -1
- normalize_qk: bool = False
- kernel_init(shape: collections.abc.Sequence[int | Any], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None)
- bias_init(shape: collections.abc.Sequence[int | Any], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None)
An initializer that returns a constant array full of zeros.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
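Both kernel_init and bias_init are callables taking a PRNG key, a shape, and a dtype, as the class signature indicates. The sketch below calls two standard jax.nn.initializers directly to show that calling convention. The scale, mode, and distribution passed to variance_scaling are assumptions chosen for illustration; this page does not state which values the default uses.

```python
import jax
import jax.numpy as jnp

# Assumed settings for illustration only; the module's actual default
# variance_scaling parameters are not given on this page.
kernel_init = jax.nn.initializers.variance_scaling(
    scale=1.0, mode="fan_in", distribution="truncated_normal"
)
bias_init = jax.nn.initializers.zeros

# Initializers are called as init(key, shape, dtype).
kernel = kernel_init(jax.random.key(0), (8, 4), jnp.float32)
bias = bias_init(jax.random.key(1), (4,), jnp.float32)

print(kernel.shape, bias.shape)
```

Any callable with the same (key, shape, dtype) convention should be usable in place of either initializer.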
- use_bias: bool = True
- attn_weights_fn(query: jaxtyping.Float[Array, '*b q h d'] | jaxtyping.Float[ndarray, '*b q h d'], key: jaxtyping.Float[Array, '*b k h d'] | jaxtyping.Float[ndarray, '*b k h d'], softmax_axis: int | tuple[int, ...] = -1, bias: jaxtyping.Float[Array, '*b #h #q #k'] | jaxtyping.Float[ndarray, '*b #h #q #k'] | None = None, mask: jaxtyping.Bool[Array, '*b #h #q #k'] | jaxtyping.Bool[ndarray, '*b #h #q #k'] | None = None, softmax_dtype: str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType | None = <class 'jax.numpy.float32'>) -> jaxtyping.Float[Array, '*b h q k'] | jaxtyping.Float[ndarray, '*b h q k']
Computes dot-product attention weights given query and key.
Shape symbols: q is the number of queries, k the number of keys, h the number of heads, and d the dimension of the keys and queries.
- Parameters:
query – Queries for calculating attention.
key – Keys for calculating attention.
softmax_axis – The axis (or axes) over which the softmax is taken. Defaults to -1, which is the keys axis. For Slot Attention, set it to -2 (the queries axis).
bias – Bias for the attention weights. This should be broadcastable to the shape [*b h q k]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.
mask – Mask for the attention weights. This should be broadcastable to the shape [*b h q k]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.
softmax_dtype – The dtype for the softmax operation. If None, the dtype of the input is used.
- Returns:
Attention weights of shape [*b h q k].
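The documented behavior can be approximated in a few lines of JAX. The following is a minimal sketch under stated assumptions, not the library's implementation: the function name attention_weights_sketch is hypothetical, the 1/sqrt(d) scaling is the usual scaled dot-product convention rather than something this page confirms, and masking by substituting a very negative logit is one common approach.

```python
import jax
import jax.numpy as jnp


def attention_weights_sketch(query, key, softmax_axis=-1, bias=None, mask=None):
    """Illustrative only: query is [*b q h d], key is [*b k h d]."""
    depth = query.shape[-1]
    # Scaled dot product per head gives logits of shape [*b h q k].
    # The 1/sqrt(d) scaling is an assumption (standard convention).
    logits = jnp.einsum("...qhd,...khd->...hqk", query, key) / jnp.sqrt(depth)
    if bias is not None:
        # Additive bias, broadcast against [*b h q k].
        logits = logits + bias
    if mask is not None:
        # Entries where mask is False receive a very negative logit.
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    return jax.nn.softmax(logits.astype(jnp.float32), axis=softmax_axis)


q_rng, k_rng = jax.random.split(jax.random.key(0))
query = jax.random.normal(q_rng, (2, 4, 3, 8))  # b=2, q=4, h=3, d=8
key = jax.random.normal(k_rng, (2, 4, 3, 8))    # b=2, k=4, h=3, d=8

# Causal mask of shape [q, k]; it broadcasts over batch and heads.
causal_mask = jnp.tril(jnp.ones((4, 4), dtype=bool))
weights = attention_weights_sketch(query, key, mask=causal_mask)
print(weights.shape)  # [b, h, q, k]
```

Because the mask is lower triangular, each query places weight only on keys at or before its own position, and each row of weights still sums to 1 along the keys axis.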
- decode: bool = False
- property interms
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: Scope | None = None