kd.nn.WrapperModule


class kauldron.modules.WrapperModule(
model: flax.linen.module.Module,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None,
)

Bases: flax.linen.module.Module

Base class to wrap a module.

The wrapper module is transparent with respect to the inner parameters: its variables are stored as {'params': inner_params} instead of being nested as {'params': {'model': inner_params}}.
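Because the wrapper adds no level of nesting, the variables of WrapperModule(model=m) have the same tree structure as those of m alone. The pure-Python sketch below makes the two layouts concrete, with plain dicts standing in for a Flax variable tree. unnest_wrapper_params is a hypothetical helper, not a kauldron API. It converts the nested layout that a naive wrapper would produce into the flat layout, for example when migrating such variables.

```python
def unnest_wrapper_params(variables: dict, name: str = "model") -> dict:
    """Hypothetical helper: turns {'params': {name: inner}} into {'params': inner}.

    Applied to every collection ('params', 'batch_stats', ...). A collection
    is unwrapped only when `name` is its sole entry; anything else is
    returned unchanged.
    """
    out = {}
    for collection, tree in variables.items():
        if isinstance(tree, dict) and set(tree) == {name}:
            # Drop the extra level introduced by a naive wrapper.
            out[collection] = tree[name]
        else:
            out[collection] = tree
    return out


inner_params = {"Dense_0": {"kernel": [[1.0]], "bias": [0.0]}}
nested = {"params": {"model": inner_params}}  # naive wrapper layout
flat = {"params": inner_params}  # layout used by WrapperModule
print(unnest_wrapper_params(nested) == flat)  # True
```

The heuristic has one blind spot: an already-flat tree whose only top-level submodule happens to be named model would also be unwrapped.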

The keys of the wrapped model are automatically propagated to the wrapper, so the module can be configured as:

cfg.model = kd.nn.WrapperModule(
    model=MyModel(
        input='batch.input',  # keys propagated to the `WrapperModule`
    ),
)
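The toy sketch below illustrates what key propagation means, using plain dataclasses instead of Flax and kauldron. It is an analogy under stated assumptions, not kauldron's implementation. lookup, key_fields, call_with_context, ToyWrapper and MyModel are all hypothetical. Treating every string-valued field as a key is also a simplification of how kauldron actually identifies key fields.

```python
import dataclasses
from typing import Any


def lookup(context: dict, path: str) -> Any:
    """Resolves a dotted key such as 'batch.input' in a nested dict."""
    value = context
    for part in path.split("."):
        value = value[part]
    return value


def key_fields(module: Any) -> dict:
    """Returns {argument_name: key} for every str-valued dataclass field.

    Simplification (assumption): any string field is treated as a key.
    """
    return {
        field.name: getattr(module, field.name)
        for field in dataclasses.fields(module)
        if isinstance(getattr(module, field.name), str)
    }


@dataclasses.dataclass
class MyModel:
    """Hypothetical inner model with a single key field."""

    input: str

    def __call__(self, input):
        return input * 2


@dataclasses.dataclass
class ToyWrapper:
    """Hypothetical wrapper that forwards both keys and calls."""

    model: Any

    def keys(self) -> dict:
        # Propagation: the wrapper reports the inner model's keys as its own,
        # so the caller resolves them exactly as for the unwrapped model.
        return key_fields(self.model)

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)


def call_with_context(module: ToyWrapper, context: dict) -> Any:
    """Looks up each key in `context` and passes it as a keyword argument."""
    kwargs = {name: lookup(context, path) for name, path in module.keys().items()}
    return module(**kwargs)


wrapper = ToyWrapper(model=MyModel(input="batch.input"))
print(wrapper.keys())  # {'input': 'batch.input'}
print(call_with_context(wrapper, {"batch": {"input": 21}}))  # 42
```

The wrapper itself declares no keys. It only reports the ones of its inner model and forwards *args and **kwargs, which is why any model can be wrapped without changing how its inputs are configured.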

Example of a wrapper that adds gradient checkpointing to any model:

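from flax import linen as nn
from kauldron import kd


class CheckpointWrapper(kd.nn.WrapperModule):

  # Rematerialize the wrapped model's activations during the backward pass
  # instead of storing them, trading compute for memory.
  @nn.checkpoint
  @nn.compact
  def __call__(self, *args, **kwargs):
    # Forward all inputs to `WrapperModule.__call__`, which runs the wrapped model.
    return super().__call__(*args, **kwargs)


# `MyModel` is a placeholder for any module whose fields are keys.
model = CheckpointWrapper(model=MyModel(x='batch.input'))
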
model: flax.linen.module.Module
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None