Konfig - Demo#
konfig is a wrapper around ml_collections.ConfigDict.
It turns any Python function call or class construction into a nested dict that can be serialized and deserialized.
This unlocks (among others) type checking and autocomplete inside configuration files.
Imports#
Import konfig:
from etils.lazy_imports import *

with ecolab.adhoc():
    from kauldron import konfig
End-to-end example#
Here is a short example demonstrating how konfig works:
# Step 1: Import the modules you want to be configurable (can be any module)
with konfig.imports():
    from flax import linen as nn

# Step 2: Create your configuration. **Nothing is executed yet!!**
model_cfg = nn.Dense(features=32)  # This creates a `ConfigDict`!

# Step 3: The `ConfigDict` object can optionally be mutated
model_cfg.use_bias = True

# Step 4: Resolve the `ConfigDict` -> Python object
# This is where the actual `nn.Dense` object is created
model = konfig.resolve(model_cfg)
Let’s look at those steps in more detail.
Basic usage#
Create a config (from any Python objects)#
Here is a nested Python call that gets executed (it can be any function, module,…):
from flax import linen as nn

model = nn.Sequential(layers=[
    nn.Dense(features=32),
    nn.Dropout(rate=0.5),
])
By wrapping the call inside a with konfig.mock_modules(): context manager, the nested call is turned into a nested konfig.ConfigDict object instead of being executed:
with konfig.mock_modules():
    model_cfg = nn.Sequential(layers=[  # model_cfg is a ConfigDict! Not nn.Sequential!
        nn.Dense(features=32),
        nn.Dropout(rate=0.5),
    ]);
<ConfigDict[nn.Sequential(
layers=[
nn.Dense(features=32),
nn.Dropout(rate=0.5),
],
)]>
An equivalent way to build the config is to import the module inside a with konfig.imports(): context manager:
with konfig.imports():
    from flax import linen as nn

model_cfg = nn.Sequential(layers=[
    nn.Dense(features=32),
    nn.Dropout(rate=0.5),
]);
<ConfigDict[nn.Sequential(
layers=[
nn.Dense(features=32),
nn.Dropout(rate=0.5),
],
)]>
To summarize, any Python call can be made configurable with either:
with konfig.imports(): imported modules then create a ConfigDict when used
with konfig.mock_modules(): locally makes modules create a ConfigDict
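To make the "record instead of execute" idea concrete, here is a toy, stdlib-only sketch. ToyMockModule is hypothetical and is not how konfig is implemented (konfig returns ConfigDict objects, supports positional arguments, and hooks into real imports); here, calling an attribute simply returns a plain nested dict in the same '<module>:<name>' format used by the serialized config further down.

```python
from typing import Any


class ToyMockModule:
    """Hypothetical mocked module: calls are recorded as dicts, never executed."""

    def __init__(self, module: str, attr: str = '') -> None:
        self._module = module
        self._attr = attr

    def __getattr__(self, name: str) -> 'ToyMockModule':
        # Only called for unknown attributes, e.g. `nn.Dense` -> attr 'Dense'.
        attr = f'{self._attr}.{name}' if self._attr else name
        return ToyMockModule(self._module, attr)

    def __call__(self, **kwargs: Any) -> dict[str, Any]:
        # Nothing is executed: the call is captured as a nested dict.
        return {'__qualname__': f'{self._module}:{self._attr}', **kwargs}


nn = ToyMockModule('flax.linen')  # flax itself is never imported
model_cfg = nn.Sequential(layers=[
    nn.Dense(features=32),
    nn.Dropout(rate=0.5),
])
print(model_cfg)
```

Because the inner calls are evaluated first, each layer is already a dict by the time nn.Sequential records its own call, which is why the result nests naturally.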
Mutate the config#
The ConfigDict object can be mutated (directly or through the command line), for example:
model_cfg.layers[0].features = 64
model_cfg.layers[0].use_bias = False
model_cfg.layers[1].rate = 0.9
model_cfg
<ConfigDict[nn.Sequential(
layers=[
nn.Dense(
features=64,
use_bias=False,
),
nn.Dropout(rate=0.9),
],
)]>
See the Advanced usage section below for mutating args, changing the class,…
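Mutating the config through the command line ultimately means setting nested fields from path-like strings. The following stdlib-only sketch applies such overrides to the plain nested-dict form of the Sequential config; apply_override and its 'path=value' syntax are hypothetical and do not reproduce konfig's or kauldron's actual flag parsing.

```python
import ast
import re
from typing import Any


def apply_override(cfg: dict, override: str) -> None:
    """Hypothetical helper: apply e.g. 'layers[0].features=64' to a nested config."""
    path, _, raw_value = override.partition('=')
    # 'layers[0].features' -> ['layers', 0, 'features']
    keys: list[Any] = [
        int(index) if index else name
        for name, index in re.findall(r'(\w+)|\[(\d+)\]', path)
    ]
    node: Any = cfg
    for key in keys[:-1]:
        node = node[key]
    # literal_eval parses '64' -> 64, 'False' -> False, '0.9' -> 0.9.
    node[keys[-1]] = ast.literal_eval(raw_value)


model_cfg = {
    '__qualname__': 'flax.linen:Sequential',
    'layers': [
        {'__qualname__': 'flax.linen:Dense', 'features': 32},
        {'__qualname__': 'flax.linen:Dropout', 'rate': 0.5},
    ],
}
for override in ['layers[0].features=64', 'layers[0].use_bias=False', 'layers[1].rate=0.9']:
    apply_override(model_cfg, override)
print(model_cfg)
```

The three overrides reproduce the same mutations as the attribute assignments above, only expressed as strings.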
Resolve the config#
Once the config has been updated, it can be resolved into the actual object:
model = konfig.resolve(model_cfg);
Sequential(
# attributes
layers = (Dense(
# attributes
features = 64
use_bias = False
dtype = None
param_dtype = float32
precision = None
kernel_init = init
bias_init = zeros
dot_general = None
dot_general_cls = None
), Dropout(
# attributes
rate = 0.9
broadcast_dims = ()
deterministic = None
rng_collection = 'dropout'
))
)
To summarize:
model_cfg is the (mutable) ConfigDict, created inside the with konfig.mock_modules(): context manager
model is the resolved object (after calling konfig.resolve)
assert isinstance(model_cfg, konfig.ConfigDict)
assert isinstance(model, flax.linen.Module)
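Conceptually, resolving walks the nested config and, for every node with a __qualname__, imports the '<module>:<name>' target and calls it with its already-resolved children. The following is a minimal stdlib-only sketch under that assumption, not konfig's implementation: toy_resolve and import_qualname are hypothetical helpers, only keyword arguments are handled, and stdlib callables (types.SimpleNamespace, datetime.timedelta, fractions.Fraction) stand in for the flax layers.

```python
import importlib
from typing import Any


def import_qualname(qualname: str) -> Any:
    """Import an object from a '<module>:<name>' string, e.g. 'datetime:timedelta'."""
    module_name, _, attr_path = qualname.partition(':')
    obj = importlib.import_module(module_name)
    for part in attr_path.split('.'):  # Supports nested names such as 'A.B'.
        obj = getattr(obj, part)
    return obj


def toy_resolve(node: Any) -> Any:
    """Recursively turn a nested config dict into Python objects (kwargs only)."""
    if isinstance(node, dict):
        resolved = {k: toy_resolve(v) for k, v in node.items() if k != '__qualname__'}
        if '__qualname__' in node:
            # Children are resolved first, then the target is imported and called.
            return import_qualname(node['__qualname__'])(**resolved)
        return resolved
    if isinstance(node, (list, tuple)):
        return type(node)(toy_resolve(v) for v in node)
    return node


# Same shape as the Sequential example, with stdlib callables instead of flax.
cfg = {
    '__qualname__': 'types:SimpleNamespace',
    'layers': [
        {'__qualname__': 'datetime:timedelta', 'days': 1},
        {'__qualname__': 'fractions:Fraction', 'numerator': 1, 'denominator': 2},
    ],
}
model = toy_resolve(cfg)  # A new object; `cfg` itself is left untouched.
print(model)
```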
Serialize the config#
ConfigDict objects are simple nested dicts that are human-readable and easily serializable:
serialized_dict = json.loads(model_cfg.to_json());p
{
'__qualname__': 'flax.linen:Sequential',
'layers': [
{
'__qualname__': 'flax.linen:Dense',
'features': 64,
'use_bias': False,
},
{
'__qualname__': 'flax.linen:Dropout',
'rate': 0.9,
},
],
}
Deserialization is done by passing the JSON dict to konfig.ConfigDict:
konfig.ConfigDict(serialized_dict)
<ConfigDict[nn.Sequential(
layers=[
nn.Dense(
features=64,
use_bias=False,
),
nn.Dropout(rate=0.9),
],
)]>
Caveat: JSON does not preserve the tuple/list distinction, so tuples are restored as lists.
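The caveat comes from JSON itself, as Python's standard json module shows on a plain nested dict (flax is not imported here; the '__qualname__' value is just a string, and broadcast_dims is the Dropout field shown in the resolved output above):

```python
import json

# A serialized-style Dropout config whose argument is a tuple.
cfg = {'__qualname__': 'flax.linen:Dropout', 'rate': 0.5, 'broadcast_dims': (0, 1)}

restored = json.loads(json.dumps(cfg))
print(restored['broadcast_dims'])  # [0, 1]: the tuple came back as a list
```

If the resolved code requires a tuple, one option is to convert the field back explicitly after loading, e.g. tuple(restored['broadcast_dims']).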
Benefits#
Everything is configurable by default (flax, optax, your custom dataclasses,…) without any change to your code.
No duplication between the code and the config system (Model vs ModelParams), as the code is directly customizable.
Config files natively support type checking and auto-complete, and benefit from all IDE features (hover over a class to see its docs, click on an attribute to see where it is defined in the code,…).
Advanced usage#
Mutating a config#
All parts of a Python call can be modified/overwritten (e.g. swept over):
# Create dummy functions to demonstrate how to overwrite configs
def some_fn(*args, **kwargs):
    print(f'some_fn called with {args}, {kwargs}')

def other_fn(*args, **kwargs):
    print(f'other_fn called with {args}, {kwargs}')

# In Colab, import `__main__` to make the Colab itself configurable
import __main__

with konfig.mock_modules():
    cfg = __main__.some_fn(dtype=np.int32);
<ConfigDict[__main__.some_fn(dtype=np.int32)]>
kwargs, by mutating cfg.kwarg_name = ...:
cfg.x = 1
cfg.y = 2
cfg
<ConfigDict[__main__.some_fn(dtype=np.int32, x=1, y=2)]>
args, by mutating cfg[arg_id] = ...:
cfg[0] = 'a' # Append additional arguments
cfg[1] = 'b'
cfg[-1] = 'b_overwritten'
cfg
<ConfigDict[__main__.some_fn(
'a',
'b_overwritten',
dtype=np.int32,
x=1,
y=2,
)]>
class/function, by mutating cfg.__qualname__ = ... (format: <import>:<name>):
f'Previous qualname: {cfg.__qualname__}';
cfg.__qualname__ = '__main__:other_fn' # Use `other_fn()` instead of `some_fn()`
cfg
'Previous qualname: __main__:some_fn'
<ConfigDict[__main__.other_fn(
'a',
'b_overwritten',
dtype=np.int32,
x=1,
y=2,
)]>
constants, by mutating the field's __const__ (here cfg.dtype.__const__ = ...):
cfg.dtype.__const__ = 'numpy:float64'
cfg
<ConfigDict[__main__.other_fn(
'a',
'b_overwritten',
dtype=np.float64,
x=1,
y=2,
)]>
konfig.resolve(cfg)
other_fn called with ('a', 'b_overwritten'), {'dtype': <class 'numpy.float64'>, 'x': 1, 'y': 2}
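The difference between __qualname__ and __const__ can be mimicked with a small stdlib-only sketch: a __qualname__ entry is imported and called, while a __const__ entry is imported and used as-is (like np.float64 above, which reaches other_fn as a class). The {'__const__': '<module>:<name>'} dict layout and the helper names are assumptions for illustration, not konfig's internal format; builtins.dict stands in for other_fn because it simply returns its kwargs.

```python
import importlib
from typing import Any


def import_qualname(qualname: str) -> Any:
    """Import an object from a '<module>:<name>' string."""
    module_name, _, attr_path = qualname.partition(':')
    obj = importlib.import_module(module_name)
    for part in attr_path.split('.'):
        obj = getattr(obj, part)
    return obj


def toy_resolve(node: Any) -> Any:
    if isinstance(node, dict) and '__const__' in node:
        return import_qualname(node['__const__'])  # Imported, never called.
    if isinstance(node, dict) and '__qualname__' in node:
        kwargs = {k: toy_resolve(v) for k, v in node.items() if k != '__qualname__'}
        return import_qualname(node['__qualname__'])(**kwargs)  # Imported, then called.
    return node


cfg = {'__qualname__': 'builtins:dict', 'dtype': {'__const__': 'fractions:Fraction'}, 'x': 1}
before = toy_resolve(cfg)
cfg['dtype']['__const__'] = 'decimal:Decimal'  # Swap the constant, as with numpy:float64.
after = toy_resolve(cfg)
print(before, after)
```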
Protocols#
Skip this section unless you’re interested in Konfig internals.
In some cases, the Python object and its config version need to interact. This interaction is done through __konfig_xxx__ protocols and konfig.register_xxx() functions.
konfig.register_alias: Cosmetic only; allows rewriting module names (jax.numpy -> jnp, kauldron.kd -> kd,…).
konfig.register_default_values: Some ConfigDict objects should be initialized with default values when they are created. This can be used to initialize child fields, for example:
cfg = kd.train.Trainer() # Implicitly create `cfg.setup`,...
cfg.setup.tags = ['text', 'tpu_v3']
__konfig_resolve_exclude_fields__: Fields to exclude from resolution. Useful when the config contains objects that should not always be resolved (e.g. to avoid importing heavy dependencies when they are not needed).
__post_konfig_resolve__(self, cfg: konfig.ConfigDict): Function called after an object is resolved. It can be used so that a resolved object keeps track of the config it was created with:
class A:
    def __post_konfig_resolve__(self, cfg: konfig.ConfigDict):
        self.cfg = cfg
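A toy resolver can show how the two hooks above might fit together. This is a hypothetical sketch, not konfig's code: it uses plain dicts instead of ConfigDict, looks classes up in a hypothetical REGISTRY dict instead of importing them, and assumes __konfig_resolve_exclude_fields__ is a class attribute whose fields are passed to the constructor unresolved.

```python
from typing import Any


class A:
    # Assumption: the excluded fields are declared as a class attribute.
    __konfig_resolve_exclude_fields__ = ('heavy',)

    def __init__(self, size: int, heavy: Any = None) -> None:
        self.size = size
        self.heavy = heavy  # Stays a config dict: never resolved.

    def __post_konfig_resolve__(self, cfg: dict) -> None:
        self.cfg = cfg  # Keep track of the config A was created from.


REGISTRY = {'toy:A': A}  # Hypothetical lookup table replacing real imports.


def toy_resolve(node: Any) -> Any:
    if not (isinstance(node, dict) and '__qualname__' in node):
        return node
    cls = REGISTRY[node['__qualname__']]
    excluded = getattr(cls, '__konfig_resolve_exclude_fields__', ())
    kwargs = {
        k: v if k in excluded else toy_resolve(v)
        for k, v in node.items()
        if k != '__qualname__'
    }
    obj = cls(**kwargs)
    hook = getattr(obj, '__post_konfig_resolve__', None)
    if hook is not None:
        hook(node)  # Called once the object has been built.
    return obj


# 'toy:NotRegistered' would raise KeyError if it were resolved; exclusion skips it.
cfg = {'__qualname__': 'toy:A', 'size': 3, 'heavy': {'__qualname__': 'toy:NotRegistered'}}
a = toy_resolve(cfg)
print(a.size, a.heavy)
```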