kd.nn.WrapperModule
class kauldron.modules.WrapperModule(
    model: flax.linen.module.Module,
    parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
    name: str | None = None,
)
Bases: flax.linen.module.Module

Base class to wrap a module.
The wrapper module is transparent with respect to the inner parameters: they are stored as {'params': inner_params} instead of being nested as {'params': {'model': inner_params}}.
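To make the two layouts concrete, the sketch below writes both parameter trees out as plain nested dicts. The leaf names (Dense_0, kernel, bias) and the helper strip_wrapper_level are hypothetical, chosen only for illustration. The helper removes the single extra level that a naively nesting wrapper would add under its model attribute, which is the level WrapperModule avoids; something like it would only be needed to load weights saved from such a nesting wrapper.

```python
from typing import Any

# Hypothetical inner parameters; real leaves would be arrays, not lists.
inner_params: dict[str, Any] = {
    "Dense_0": {"kernel": [[1.0, 2.0]], "bias": [0.0]},
}

# Layout produced by a transparent wrapper such as WrapperModule.
transparent = {"params": inner_params}

# Layout a naive wrapper, storing the inner module under `model`, would produce.
nested = {"params": {"model": inner_params}}


def strip_wrapper_level(variables: dict[str, Any], attr: str = "model") -> dict[str, Any]:
    """Hypothetical helper: drop the extra `attr` level from every collection."""
    return {collection: tree[attr] for collection, tree in variables.items()}


print(strip_wrapper_level(nested) == transparent)  # True
```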
The keys from the wrapped model are auto-propagated to the wrapper, so the module can be initialized as:
    cfg.model = kd.nn.WrapperModule(
        model=MyModel(
            input='batch.input',  # keys propagated to the `WrapperModule`
        ),
    )
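The key propagation described above can be pictured in plain Python: the inner model declares which context paths it reads, the wrapper reports its inner model's keys as its own, and a resolver looks each dotted path up in a nested context before calling the module. This is a minimal sketch under stated assumptions: the keys() method, resolve_path and call_with_context are hypothetical stand-ins, not Kauldron's actual key-resolution API.

```python
import dataclasses
from typing import Any


def resolve_path(context: dict[str, Any], path: str) -> Any:
    """Look up a dotted path such as 'batch.input' in a nested dict."""
    value: Any = context
    for part in path.split("."):
        value = value[part]
    return value


@dataclasses.dataclass
class MyModel:
    # Hypothetical: the field value is a key path into the context.
    input: str

    def keys(self) -> dict[str, str]:
        return {"input": self.input}

    def __call__(self, input: Any) -> Any:
        return [2 * v for v in input]


@dataclasses.dataclass
class Wrapper:
    model: MyModel

    def keys(self) -> dict[str, str]:
        # Propagate the inner model's keys; the wrapper declares none of its own.
        return self.model.keys()

    def __call__(self, **kwargs: Any) -> Any:
        return self.model(**kwargs)


def call_with_context(module: Any, context: dict[str, Any]) -> Any:
    """Resolve every declared key path, then call the module with the values."""
    kwargs = {name: resolve_path(context, path) for name, path in module.keys().items()}
    return module(**kwargs)


wrapped = Wrapper(model=MyModel(input="batch.input"))
print(call_with_context(wrapped, {"batch": {"input": [1, 2, 3]}}))  # [2, 4, 6]
```

Because the wrapper only forwards keys, the configuration stays the same whether or not the model is wrapped.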
Example: a wrapper that adds gradient checkpointing to any model:
    class CheckpointWrapper(kd.nn.WrapperModule):

        @nn.checkpoint
        @nn.compact
        def __call__(self, *args, **kwargs):
            return super().__call__(*args, **kwargs)

    model = CheckpointWrapper(model=MyModel(x='batch.input'))
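In Flax, nn.checkpoint is the lifted version of jax.checkpoint: intermediate activations of the decorated call are recomputed during the backward pass instead of being stored, trading extra compute for lower memory, while the computed values stay the same. The sketch below shows that property on a plain JAX function rather than a Kauldron module; the function block and the array shapes are made up for illustration.

```python
import jax
import jax.numpy as jnp


def block(w: jax.Array, x: jax.Array) -> jax.Array:
    # Hypothetical two-layer computation standing in for a wrapped model.
    h = jnp.tanh(x @ w)  # (4, 2) @ (2, 3) -> (4, 3)
    return jnp.sum(jnp.tanh(h @ w.T) ** 2)  # (4, 3) @ (3, 2) -> scalar


# Same function, but intermediates are rematerialized on the backward pass.
block_remat = jax.checkpoint(block)

w = jnp.arange(6.0).reshape(2, 3) / 10.0
x = jnp.ones((4, 2))

g_plain = jax.grad(block)(w, x)  # gradient w.r.t. w
g_remat = jax.grad(block_remat)(w, x)
print(bool(jnp.allclose(g_plain, g_remat)))  # True
```

Only memory use and compute differ between the two versions, which is why the checkpointing wrapper can be dropped around any model without changing its outputs or gradients.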
- model: flax.linen.module.Module
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: Scope | None = None