kd.nn.ExternalModule


class kauldron.modules.ExternalModule(model: flax.linen.module.Module, keys: str | dict[str, str], train_kwarg_name: str | None = None, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: kauldron.modules.adapter.WrapperModule

Wraps a Flax module that is defined outside Kauldron.

This is a very thin wrapper around flax.linen.Module that adds:

  • Keys: connect the model inputs to the dataset batch (or other context paths)

  • Training-mode compatibility: the train/eval mode is passed as a kwarg (e.g. train=True) when calling the model, for models that don't read kd.nn.train_property()

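from flax import linen as nn
from kauldron import kd

cfg.model = kd.nn.ExternalModule(
    model=nn.Dropout(rate=0.1),  # `rate` is a required field of nn.Dropout.
    keys={
        # nn.Dropout.__call__ names its input argument `inputs`.
        'inputs': 'batch.image',
    },
    # Passes deterministic=False in training and deterministic=True in eval.
    train_kwarg_name='~deterministic',
)

model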

The Flax model to wrap.

Type:

flax.linen.module.Module

keys

Mapping from model.__call__ keyword-argument names to context paths (e.g. keys={'x': 'batch.image'} calls the model as model.apply(params, x=batch['image'])). If a str is given, the value at that path is passed as a positional argument; if a dict is given, the inputs are passed as keyword arguments.

Type:

str | dict[str, str]

train_kwarg_name

If provided, the model is called with model.apply(..., <train_kwarg_name>=True) in training mode (and False in eval mode). Flax models have no standard way to specify train/eval mode, so each codebase uses its own convention (deterministic=, train=, is_training=, ...). Prefix the name with ~ to invert the value (e.g. train_kwarg_name='~deterministic' passes deterministic=False during training).

Type:

str | None

keys: str | dict[str, str]
train_kwarg_name: str | None = None
property is_training: bool

Whether the module is in training mode.

name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None