kd.nn.ExternalModule
- class kauldron.modules.ExternalModule(model: flax.linen.module.Module, keys: str | dict[str, str], train_kwarg_name: str | None = None, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: kauldron.modules.adapter.WrapperModule
Wrapper for a module that is defined outside Kauldron.
This is a very thin wrapper around flax.linen.Module that adds:
- Keys: connect the model inputs to the dataset batch (or other context paths).
- Training property compatibility: pass train=True (or the equivalent kwarg) when calling the model, rather than relying on kd.nn.train_property().
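Example:
    import flax.linen as nn
    from kauldron import kd

    cfg.model = kd.nn.ExternalModule(
        # nn.Dropout requires a `rate`; its __call__ takes the input as `inputs`.
        model=nn.Dropout(rate=0.5),
        keys={
            'inputs': 'batch.image',
        },
        # Dropout uses deterministic=True for eval, so the flag is inverted.
        train_kwarg_name='~deterministic',
    )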
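The `keys` argument resolves dotted context paths such as 'batch.image' and feeds the values to the model, positionally when `keys` is a str and by keyword when it is a dict. Below is a minimal pure-Python sketch of that lookup, assuming the context is a plain nested dict; `get_by_path` and `resolve_inputs` are hypothetical helpers for illustration, not Kauldron's actual implementation.

```python
from __future__ import annotations

from typing import Any, Mapping


def get_by_path(context: Mapping[str, Any], path: str) -> Any:
    """Looks up a dotted path such as 'batch.image' in a nested dict (hypothetical helper)."""
    value: Any = context
    for part in path.split("."):
        value = value[part]
    return value


def resolve_inputs(
    keys: str | Mapping[str, str], context: Mapping[str, Any]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
    """Returns the (args, kwargs) a wrapped model would be called with (hypothetical helper)."""
    if isinstance(keys, str):
        # A single str: the resolved value becomes the only positional argument.
        return (get_by_path(context, keys),), {}
    # A dict: each model kwarg name is filled from its context path.
    return (), {name: get_by_path(context, path) for name, path in keys.items()}


# Usage: a toy context where a string stands in for the image array.
context = {"batch": {"image": "img", "label": 3}}
print(resolve_inputs({"inputs": "batch.image"}, context))  # ((), {'inputs': 'img'})
print(resolve_inputs("batch.image", context))  # (('img',), {})
```

The dict form is the safer choice for third-party models, since it does not depend on the order of positional parameters in their `__call__`.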
- model
The flax model to wrap
- Type:
flax.linen.module.Module
- keys
Mapping from model.__call__ kwarg names to context paths (e.g. keys={'x': 'batch.image'} calls the model as model.apply(variables, x=batch['image'])). If a str is given, the input is passed as a positional arg; if a dict is given, the inputs are passed as kwargs.
- Type:
str | dict[str, str]
- train_kwarg_name
If provided, the model is called with model.apply(…, <train_kwarg_name>=True). Flax models have no standard way to specify train/eval mode, so each codebase uses a different convention (deterministic=, train=, is_training=, …). The kwarg can be inverted with ~ (e.g. train_kwarg_name='~deterministic' passes deterministic=False instead of deterministic=True).
- Type:
str | None
- keys: str | dict[str, str]
- train_kwarg_name: str | None = None
- property is_training: bool
Whether the module is running in training mode.
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: Scope | None = None
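The `~` inversion described for `train_kwarg_name` can be sketched in a few lines of plain Python. This assumes the flag simply mirrors `is_training` in both train and eval mode; `train_kwargs` is a hypothetical helper for illustration, not part of Kauldron's API.

```python
from __future__ import annotations


def train_kwargs(train_kwarg_name: str | None, is_training: bool) -> dict[str, bool]:
    """Builds the extra kwarg that signals train/eval mode (hypothetical helper)."""
    if train_kwarg_name is None:
        # No convention given: no extra kwarg is passed to the model.
        return {}
    if train_kwarg_name.startswith("~"):
        # '~' inverts the flag, e.g. '~deterministic' -> deterministic=not is_training.
        return {train_kwarg_name[1:]: not is_training}
    return {train_kwarg_name: is_training}


print(train_kwargs("~deterministic", True))  # {'deterministic': False}
print(train_kwargs("train", False))  # {'train': False}
print(train_kwargs(None, True))  # {}
```

Encoding the inversion in the name string keeps the wrapper's config a single field, instead of needing a separate boolean to say whether the model's flag means "train" or "eval".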