kd.nn.Rearrange


class kauldron.modules.Rearrange(
    pattern: str,
    axes_lengths: dict[str, int] = <factory>,
    parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
    name: str | None = None,
)

Bases: flax.linen.module.Module

Wrapper around einops.rearrange, so that a rearrangement can be used as a layer (e.g. in nn.Sequential).

Example:

from flax import linen as nn
from kauldron import kd

cfg.model = kd.nn.Sequential(
    inputs="batch.image",
    layers=[
        # patchify: non-overlapping 8x8 patches, each embedded into 192 features
        nn.Conv(features=192, kernel_size=(8, 8), strides=(8, 8)),
        # flatten the image dimensions before applying the transformer blocks
        kd.nn.Rearrange(pattern="... h w d -> ... (h w) d"),
        kd.nn.PreNormBlock.from_spec(num_heads=12),
        ...
    ],
)
pattern

Pattern passed to einops.rearrange, e.g. "b h w c -> b c (h w)".

Type:

str

axes_lengths

Sizes of additional axes, keyed by axis name, that cannot be inferred from the pattern and the input shape alone (e.g. when splitting one input axis into several).

Type:

dict[str, int]

pattern: str
axes_lengths: dict[str, int]
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None