kd.nn.Rearrange
class kauldron.modules.Rearrange(
    pattern: str,
    axes_lengths: dict[str, int] = <factory>,
    parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
    name: str | None = None,
)
Bases: flax.linen.module.Module

Wrapper around einops.rearrange for use in, e.g., nn.Sequential.
Example:
cfg.model = kd.nn.Sequential(
    inputs="batch.image",
    layers=[
        nn.Conv(features=192, kernel_size=(8, 8), strides=(8, 8)),
        # flatten the image dimensions before applying the transformer blocks
        kd.nn.Rearrange(pattern="... h w d -> ... (h w) d"),
        kd.nn.PreNormBlock.from_spec(num_heads=12),
        ...
    ]
)
Attributes:
- pattern (str): einops.rearrange pattern, e.g. "b h w c -> b c (h w)".
- axes_lengths (dict[str, int]): A dictionary specifying the sizes of additional axes that cannot be inferred from the pattern and the tensor shape alone (e.g. when one input axis is split into two).
- pattern: str
- axes_lengths: dict[str, int]
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: Scope | None = None