kd.nn.ImprovedMultiHeadDotProductAttention


class kauldron.modules.ImprovedMultiHeadDotProductAttention(
    num_heads: int,
    qk_features: int | None = None,
    v_features: int | None = None,
    out_features: int | None = None,
    softmax_axis: int | tuple[int, ...] = -1,
    normalize_qk: bool = False,
    kernel_init: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = <function variance_scaling.<locals>.init>,
    bias_init: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = <function zeros>,
    use_bias: bool = True,
    attn_weights_fn: typing.Callable[[...], jaxtyping.Float[Array, '...'] | jaxtyping.Float[ndarray, '...']] = <function dot_product_attention_weights>,
    decode: bool = False,
    parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
    name: str | None = None,
)

Bases: flax.linen.module.Module

Multi-head dot-product attention.

A simplified version of nn.MultiHeadDotProductAttention with the following modifications:
  • adds a configurable softmax axis

  • accepts an (additive) bias for the attention weights, in addition to a mask

  • drops support for dropout

  • adds the attention weights to interms as “interms.PATH.TO.LAYER.attn_weights”
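The modifications above can be illustrated with a minimal plain-JAX sketch of a multi-head attention forward pass. This is not Kauldron's implementation: the function name `multi_head_attention`, the `params` dictionary of projection matrices, and the 1/sqrt(d) logit scaling (taken from standard scaled dot-product attention) are assumptions made for illustration. The sketch also assumes that the total query/key and value feature sizes are split evenly across `num_heads`. It shows the configurable softmax axis and the additive bias, and it has no dropout.

```python
import jax
import jax.numpy as jnp


def multi_head_attention(x_q, x_kv, params, num_heads, softmax_axis=-1, bias=None):
    """Hypothetical sketch: project, split into heads, attend, merge heads."""
    # Project inputs to the total query/key and value feature sizes.
    q = x_q @ params["wq"]   # [*b, q, qk_features]
    k = x_kv @ params["wk"]  # [*b, k, qk_features]
    v = x_kv @ params["wv"]  # [*b, k, v_features]

    def split_heads(t):
        # [*b, n, features] -> [*b, n, h, features // h] (assumes divisibility)
        return t.reshape(*t.shape[:-1], num_heads, t.shape[-1] // num_heads)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    head_dim = q.shape[-1]
    # Scaled dot-product logits of shape [*b, h, q, k] (scaling is an assumption).
    logits = jnp.einsum("...qhd,...khd->...hqk", q, k) / jnp.sqrt(head_dim)
    if bias is not None:
        logits = logits + bias  # additive bias, broadcast to [*b, h, q, k]
    weights = jax.nn.softmax(logits, axis=softmax_axis)
    # Weighted sum of values, then merge heads back into one feature axis.
    out = jnp.einsum("...hqk,...khd->...qhd", weights, v)
    out = out.reshape(*out.shape[:-2], -1)  # [*b, q, v_features]
    return out @ params["wo"], weights


# Usage: self-attention over 5 tokens with 16 features, 2 heads of size 4.
k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5)
x = jax.random.normal(k1, (2, 5, 16))  # [batch, tokens, features]
params = {
    "wq": jax.random.normal(k2, (16, 8)),
    "wk": jax.random.normal(k3, (16, 8)),
    "wv": jax.random.normal(k4, (16, 8)),
    "wo": jax.random.normal(k5, (8, 16)),
}
out, weights = multi_head_attention(x, x, params, num_heads=2)
print(out.shape, weights.shape)  # (2, 5, 16) (2, 2, 5, 5)
```

Returning the weights alongside the output mirrors, in spirit, how the module exposes them through interms; the actual interms mechanism is Kauldron-specific and not modeled here.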

num_heads

Number of attention heads.

Type:

int

qk_features

Total dimension of the keys and queries.

v_features

Total dimension of the values. Defaults to qk_features.

softmax_axis

The axis or axes over which the softmax is taken. Defaults to -1, the keys axis. For Slot Attention, set it to -2 (the queries axis).

Type:

int | tuple[int, …]
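To see what softmax_axis changes, the sketch below normalizes the same random logits along each choice, calling `jax.nn.softmax` directly rather than going through the module. The shapes are arbitrary illustrative values.

```python
import jax
import jax.numpy as jnp

# Attention logits for one head: [h=1, q=3, k=4].
logits = jax.random.normal(jax.random.key(0), (1, 3, 4))

# softmax_axis=-1 (default): each query's weights over the keys sum to 1.
w_keys = jax.nn.softmax(logits, axis=-1)
print(w_keys.sum(axis=-1))  # approximately 1 everywhere, shape (1, 3)

# softmax_axis=-2 (Slot Attention style): the queries (slots) compete for
# each key, so each key's weights over the queries sum to 1.
w_queries = jax.nn.softmax(logits, axis=-2)
print(w_queries.sum(axis=-2))  # approximately 1 everywhere, shape (1, 4)

# A tuple of axes normalizes jointly over all (query, key) pairs of a head.
w_joint = jax.nn.softmax(logits, axis=(-2, -1))
print(w_joint.sum(axis=(-2, -1)))  # approximately 1, shape (1,)
```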

num_heads: int
qk_features: int | None = None
v_features: int | None = None
out_features: int | None = None
softmax_axis: Axes = -1
normalize_qk: bool = False
kernel_init(key, shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>, out_sharding=None) → jax.Array
bias_init(key, shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>) → jax.Array

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
use_bias: bool = True
attn_weights_fn(
    query: jaxtyping.Float[Array, '*b q h d'] | jaxtyping.Float[ndarray, '*b q h d'],
    key: jaxtyping.Float[Array, '*b k h d'] | jaxtyping.Float[ndarray, '*b k h d'],
    softmax_axis: int | tuple[int, ...] = -1,
    bias: jaxtyping.Float[Array, '*b #h #q #k'] | jaxtyping.Float[ndarray, '*b #h #q #k'] | None = None,
    mask: jaxtyping.Bool[Array, '*b #h #q #k'] | jaxtyping.Bool[ndarray, '*b #h #q #k'] | None = None,
    softmax_dtype: str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType | None = <class 'jax.numpy.float32'>,
) → jaxtyping.Float[Array, '*b h q k'] | jaxtyping.Float[ndarray, '*b h q k']

Computes dot-product attention weights given query and key.

Shape letters: *b: any leading batch dimensions, q: number of queries, k: number of keys, h: number of heads, d: per-head dimension of the keys/queries.
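Under these shape conventions, the attention logits contract the per-head dimension d of the queries and keys. A minimal `jnp.einsum` sketch of just that contraction (no scaling, bias, mask, or softmax), with arbitrary sizes:

```python
import jax.numpy as jnp

b, q, k, h, d = 2, 3, 5, 4, 8  # batch, queries, keys, heads, head dim
query = jnp.ones((b, q, h, d))  # [*b q h d]
key = jnp.ones((b, k, h, d))    # [*b k h d]

# Contract over d; "..." stands for the *b batch dimensions.
logits = jnp.einsum("...qhd,...khd->...hqk", query, key)
print(logits.shape)  # (2, 4, 3, 5), i.e. [*b h q k]
```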

Parameters:
  • query – Queries for calculating attention.

  • key – Keys for calculating attention.

  • softmax_axis – The axis or axes over which the softmax is taken. Defaults to -1, the keys axis. For Slot Attention, set it to -2 (the queries axis).

  • bias – Bias for the attention weights. This should be broadcastable to the shape [*b h q k]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.

  • mask – Mask for the attention weights. This should be broadcastable to the shape [*b h q k]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.

  • softmax_dtype – The dtype for the softmax operation. Defaults to jnp.float32. If None, the dtype of the input is used.

Returns:

Attention weights of shape [*b h q k].
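As a rough illustration of how these parameters interact, here is a hypothetical plain-JAX function `attention_weights` that follows the documented signature. It is a sketch, not Kauldron's dot_product_attention_weights: the 1/sqrt(d) scaling and filling masked-out logits with the lowest finite value of the dtype are assumptions borrowed from common practice. The usage part checks that a boolean causal mask and an equivalent large-negative additive bias yield the same weights.

```python
import jax
import jax.numpy as jnp


def attention_weights(query, key, softmax_axis=-1, bias=None, mask=None,
                      softmax_dtype=jnp.float32):
    """Hypothetical sketch mirroring the documented parameters."""
    # Scale by 1/sqrt(d), as in standard scaled dot-product attention (assumption).
    query = query * query.shape[-1] ** -0.5
    logits = jnp.einsum("...qhd,...khd->...hqk", query, key)  # [*b h q k]
    if bias is not None:
        logits = logits + bias  # additive, broadcast to [*b h q k]
    if mask is not None:
        # Entries where mask is False get the lowest finite value (weight ~0).
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    if softmax_dtype is not None:
        logits = logits.astype(softmax_dtype)  # None keeps the input dtype
    return jax.nn.softmax(logits, axis=softmax_axis)


q_len = k_len = 4
k1, k2 = jax.random.split(jax.random.key(0))
query = jax.random.normal(k1, (2, q_len, 3, 8))  # [*b q h d]
key = jax.random.normal(k2, (2, k_len, 3, 8))    # [*b k h d]

# Causal mask: query i may attend to keys j <= i; [q k] broadcasts to [*b h q k].
causal = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
w_mask = attention_weights(query, key, mask=causal)

# The same constraint expressed as an additive bias.
causal_bias = jnp.where(causal, 0.0, -1e9)
w_bias = attention_weights(query, key, bias=causal_bias)

print(w_mask.shape)  # (2, 3, 4, 4)
print(jnp.allclose(w_mask, w_bias, atol=1e-6))  # True
```

Expressing a constraint through bias may be convenient when it has to be summed with other additive terms such as a proximity bias, while mask states a purely boolean constraint directly.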

decode: bool = False
property interms
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None