kd.nn.Reduce

class kauldron.modules.Reduce(pattern: str, reduction: typing.Literal['min', 'max', 'sum', 'mean', 'prod'], axes_lengths: dict[str, int] = <factory>, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: flax.linen.module.Module

Wrapper around einops.reduce that lets it be used as a layer, e.g. in nn.Sequential.

Example:

cfg.model = kd.nn.Sequential(
    inputs="batch.image",
    layers=[
        ...
        kd.nn.PreNormBlock.from_spec(num_heads=12),
        # use Reduce to implement Global Average Pooling
        kd.nn.Reduce(pattern="... n d -> ... d", reduction="mean"),
        nn.Dense(features=1000),
    ]
)
pattern

The einops.reduce pattern, e.g. "b h w c -> b c".

Type:

str

reduction

One of the available reductions: 'min', 'max', 'sum', 'mean' or 'prod'.

Type:

Literal['min', 'max', 'sum', 'mean', 'prod']

axes_lengths

A dictionary specifying additional axis sizes that cannot be inferred from the pattern and the input tensor alone.

Type:

dict[str, int]

pattern: str
reduction: Literal['min', 'max', 'sum', 'mean', 'prod']
axes_lengths: dict[str, int]
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None