Konfig - Demo#

konfig is a wrapper around ml_collections.ConfigDict.

It turns any Python function call or class construction into a nested dict that can be serialized and deserialized.

Among other things, this unlocks type checking and autocomplete inside configuration files.

Imports#

Import konfig:

from etils.lazy_imports import *

with ecolab.adhoc():
  from kauldron import konfig

End-to-end example#

Here is a short example demonstrating how konfig works:

# Step 1: Import the modules you want to be configurable (can be any module)
with konfig.imports():
  from flax import linen as nn

# Step 2: Create your configuration. **Nothing is executed yet!!**
model_cfg = nn.Dense(features=32)  # This creates a `ConfigDict` !!!

# Step 3: The `ConfigDict` object can optionally be mutated
model_cfg.use_bias = True

# Step 4: Resolve the `ConfigDict` -> Python object
# This is where the actual `nn.Dense` object is created
model = konfig.resolve(model_cfg)

Let’s look at those steps in more detail.

Basic usage#

Create a config (from any Python objects)#

Here is a nested Python call that gets executed immediately (it can be any function, class, module,…):

from flax import linen as nn

model = nn.Sequential(layers=[
    nn.Dense(features=32),
    nn.Dropout(rate=0.5),
])

Wrapping the call in a with konfig.mock_modules(): context manager turns the nested call into a nested konfig.ConfigDict object instead of executing it:

with konfig.mock_modules():
  model_cfg = nn.Sequential(layers=[  # model_cfg is a ConfigDict ! Not nn.Sequential!
      nn.Dense(features=32),
      nn.Dropout(rate=0.5),
  ]);
<ConfigDict[nn.Sequential(
    layers=[
        nn.Dense(features=32),
        nn.Dropout(rate=0.5),
    ],
)]>

An equivalent way to build the config is to import the module inside a with konfig.imports(): context manager:

with konfig.imports():
  from flax import linen as nn

model_cfg = nn.Sequential(layers=[
    nn.Dense(features=32),
    nn.Dropout(rate=0.5),
]);
<ConfigDict[nn.Sequential(
    layers=[
        nn.Dense(features=32),
        nn.Dropout(rate=0.5),
    ],
)]>

To summarize: make any Python call configurable with either:

  • with konfig.imports(): Modules imported inside the block create ConfigDict objects whenever they are used

  • with konfig.mock_modules(): Modules create ConfigDict objects only locally, inside the block
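
To build intuition for how a call can be recorded rather than executed, here is a minimal sketch in plain Python. It is not konfig's implementation: `RecordingModule` is a hypothetical stand-in for a mocked module, and it only handles keyword arguments. The `__qualname__` key mirrors the serialized format shown later on this page.

```python
class RecordingModule:
  """Hypothetical stand-in for a mocked module (not konfig's real class)."""

  def __init__(self, import_path: str):
    self._import_path = import_path

  def __getattr__(self, name: str):
    # Only reached for attributes that do not exist, e.g. `nn.Dense`.
    def record_call(**kwargs):
      # Describe the call as a dict instead of running it.
      return {'__qualname__': f'{self._import_path}:{name}', **kwargs}

    return record_call


nn = RecordingModule('flax.linen')  # flax itself is never imported here

model_cfg = nn.Sequential(layers=[
    nn.Dense(features=32),
    nn.Dropout(rate=0.5),
])
print(model_cfg)
```

Because the recorder never imports the real module, building a config this way stays cheap; the cost of importing and constructing objects is deferred to the resolve step.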

Mutate the config#

The ConfigDict object can be mutated (directly or through command line), like:

model_cfg.layers[0].features = 64
model_cfg.layers[0].use_bias = False
model_cfg.layers[1].rate = 0.9
model_cfg
<ConfigDict[nn.Sequential(
    layers=[
        nn.Dense(
            features=64,
            use_bias=False,
        ),
        nn.Dropout(rate=0.9),
    ],
)]>

See the Advanced usage section below for mutating args, changing the class,…
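
Since a config is plain nested data until it is resolved, producing several variants of it is cheap. The sketch below illustrates the idea with ordinary dicts standing in for ConfigDict objects; `make_variant` is a hypothetical helper, not part of konfig.

```python
import copy

# Plain dicts standing in for a ConfigDict (illustration only).
base_cfg = {
    '__qualname__': 'flax.linen:Sequential',
    'layers': [
        {'__qualname__': 'flax.linen:Dense', 'features': 32},
        {'__qualname__': 'flax.linen:Dropout', 'rate': 0.5},
    ],
}


def make_variant(cfg, features, rate):
  """Returns a mutated deep copy, leaving the base config untouched."""
  variant = copy.deepcopy(cfg)
  variant['layers'][0]['features'] = features
  variant['layers'][1]['rate'] = rate
  return variant


# One variant per (features, rate) combination.
variants = [
    make_variant(base_cfg, features, rate)
    for features in (32, 64)
    for rate in (0.5, 0.9)
]
print(len(variants))
```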

Resolve the config#

Once the config has been updated, it can be resolved into the actual object:

model = konfig.resolve(model_cfg);
Sequential(
    # attributes
    layers = (Dense(
        # attributes
        features = 64
        use_bias = False
        dtype = None
        param_dtype = float32
        precision = None
        kernel_init = init
        bias_init = zeros
        dot_general = None
        dot_general_cls = None
    ), Dropout(
        # attributes
        rate = 0.9
        broadcast_dims = ()
        deterministic = None
        rng_collection = 'dropout'
    ))
)

To summarize:

  • model_cfg is the ConfigDict (mutable), created inside the with konfig.mock_modules(): contextmanager

  • model is the resolved object (after calling konfig.resolve)

assert isinstance(model_cfg, konfig.ConfigDict)
assert isinstance(model, flax.linen.Module)

Serialize the config#

ConfigDict objects are simple nested dicts that are human-readable and easy to serialize:

serialized_dict = json.loads(model_cfg.to_json());p
{
    '__qualname__': 'flax.linen:Sequential',
    'layers': [
        {
            '__qualname__': 'flax.linen:Dense',
            'features': 64,
            'use_bias': False,
        },
        {
            '__qualname__': 'flax.linen:Dropout',
            'rate': 0.9,
        },
    ],
}

Deserialization is done by passing the JSON dict to konfig.ConfigDict:

konfig.ConfigDict(serialized_dict)
<ConfigDict[nn.Sequential(
    layers=[
        nn.Dense(
            features=64,
            use_bias=False,
        ),
        nn.Dropout(rate=0.9),
    ],
)]>
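
The `'<import>:<name>'` strings in the serialized form hint at how such a dict can be turned back into objects. The following is a simplified sketch using only the standard library, not konfig's actual resolver: `resolve` here is a hypothetical function that handles keyword arguments and lists only, and the example targets (`types:SimpleNamespace`, `datetime:timedelta`) are chosen so that it runs without flax.

```python
import importlib


def resolve(cfg):
  """Recursively builds objects from dicts carrying a `__qualname__` key."""
  if isinstance(cfg, list):
    return [resolve(item) for item in cfg]
  if not isinstance(cfg, dict):
    return cfg  # Leaf value (int, str, ...), returned as-is.
  # Resolve children first, so nested calls are built inside-out.
  kwargs = {k: resolve(v) for k, v in cfg.items() if k != '__qualname__'}
  if '__qualname__' not in cfg:
    return kwargs  # A plain dict, not a call.
  module_name, attr_name = cfg['__qualname__'].split(':')
  target = getattr(importlib.import_module(module_name), attr_name)
  return target(**kwargs)


cfg = {
    '__qualname__': 'types:SimpleNamespace',
    'name': 'demo',
    'delay': {'__qualname__': 'datetime:timedelta', 'days': 1, 'hours': 12},
}
resolved = resolve(cfg)
print(resolved)
```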

Caveat:

  • JSON does not preserve the tuple/list distinction: tuples are restored as lists
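
This is a property of JSON itself rather than of konfig, as a round trip through the standard `json` module shows (the `broadcast_dims` key is just an illustrative name):

```python
import json

original = {'broadcast_dims': (0, 1)}  # A tuple before serialization
restored = json.loads(json.dumps(original))

print(type(original['broadcast_dims']).__name__)
print(type(restored['broadcast_dims']).__name__)
```

If the resolved object requires a tuple, the value has to be converted back (for example with `tuple(...)`) after loading.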

Benefits#

  • Everything is configurable by default (flax, optax, your custom dataclasses,…) without any change in your code.

  • No duplication between code and the config system (Model vs ModelParams), as code is directly customizable

  • Config files natively support type checking and auto-complete, and benefit from all IDE overlays (hover over a class to see its documentation, click on an attribute to see where it is defined in the code,…).

Advanced usage#

Mutating a config#

All parts of a Python call can be modified or overwritten (e.g. swept over):

# Create a dummy function to demonstrate how to overwrite configs
def some_fn(*args, **kwargs):
  print(f'some_fn called with {args}, {kwargs}')


def other_fn(*args, **kwargs):
  print(f'other_fn called with {args}, {kwargs}')


# In Colab, import `__main__` to make the Colab itself configurable
import __main__


with konfig.mock_modules():
  cfg = __main__.some_fn(dtype=np.int32);
<ConfigDict[__main__.some_fn(dtype=np.int32)]>

  • kwargs - by assigning to cfg.kwarg_name:

cfg.x = 1
cfg.y = 2
cfg
<ConfigDict[__main__.some_fn(dtype=np.int32, x=1, y=2)]>

  • args - by assigning to cfg[arg_id]:

cfg[0] = 'a'  # Append additional arguments
cfg[1] = 'b'
cfg[-1] = 'b_overwritten'
cfg
<ConfigDict[__main__.some_fn(
    'a',
    'b_overwritten',
    dtype=np.int32,
    x=1,
    y=2,
)]>

  • class/function - by assigning an <import>:<name> string to cfg.__qualname__:

f'Previous qualname: {cfg.__qualname__}';

cfg.__qualname__ = '__main__:other_fn'  # Use `other_fn()` instead of `some_fn()`
cfg
'Previous qualname: __main__:some_fn'
<ConfigDict[__main__.other_fn(
    'a',
    'b_overwritten',
    dtype=np.int32,
    x=1,
    y=2,
)]>

  • constants - by assigning to cfg.__const__:

cfg.dtype.__const__ = 'numpy:float64'
cfg
<ConfigDict[__main__.other_fn(
    'a',
    'b_overwritten',
    dtype=np.float64,
    x=1,
    y=2,
)]>

konfig.resolve(cfg)
other_fn called with ('a', 'b_overwritten'), {'dtype': <class 'numpy.float64'>, 'x': 1, 'y': 2}
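
Note that the constant reaches the function as the object itself (`<class 'numpy.float64'>` in the output above); it is looked up, not called. A minimal sketch of such a lookup from an `'<import>:<name>'` string, using only the standard library (`resolve_const` is a hypothetical helper, not konfig's API):

```python
import importlib
import math


def resolve_const(path: str):
  """Looks up `<import>:<name>` and returns the object without calling it."""
  module_name, attr_name = path.split(':')
  return getattr(importlib.import_module(module_name), attr_name)


value = resolve_const('math:pi')  # A plain constant
fn = resolve_const('math:sqrt')  # A callable, returned uncalled

print(value == math.pi)
print(fn is math.sqrt)
```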

Protocols#

Skip this section unless you’re interested in Konfig internals

In some cases, there can be interactions between a Python object and its config version. These interactions go through the __konfig_xxx__ protocols and the konfig.register_xxx() functions.

konfig.register_alias: Purely cosmetic; rewrites module names when a config is displayed (jax.numpy -> jnp, kauldron.kd -> kd,…)

konfig.register_default_values: When a ConfigDict object is created, it sometimes needs to be initialized with default values. This can be used to initialize child fields, for example:

cfg = kd.train.Trainer()  # Implicitly create `cfg.setup`,...
cfg.setup.tags = ['text', 'tpu_v3']

__konfig_resolve_exclude_fields__: Fields to exclude from being resolved. Can be used if the config contains objects that should not always be resolved (e.g. to avoid importing heavy dependencies when they are not needed).

__post_konfig_resolve__(self, cfg: konfig.ConfigDict): Method called after an object is resolved. Can be used so that a resolved object keeps track of the config it was created from:

class A:

  def __post_konfig_resolve__(self, cfg: konfig.ConfigDict):
    self.cfg = cfg
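
To illustrate the calling convention, here is a minimal sketch of how a resolver could invoke this hook. It is an assumption about the general pattern, not konfig's actual code: `resolve_with_hook` is a hypothetical function, and a plain dict stands in for the ConfigDict.

```python
def resolve_with_hook(cls, cfg):
  """Builds `cls(**cfg)`, then calls the post-resolve hook if it is defined."""
  obj = cls(**cfg)
  hook = getattr(obj, '__post_konfig_resolve__', None)
  if hook is not None:
    hook(cfg)  # The object receives the config it was created from.
  return obj


class WithHook:

  def __init__(self, x=0):
    self.x = x

  def __post_konfig_resolve__(self, cfg):
    self.cfg = cfg


class WithoutHook:

  def __init__(self, x=0):
    self.x = x


a = resolve_with_hook(WithHook, {'x': 1})
b = resolve_with_hook(WithoutHook, {'x': 2})
print(a.cfg, hasattr(b, 'cfg'))
```

Looking the hook up with `getattr` keeps the protocol opt-in: classes that do not define it are resolved exactly as before.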