Klinen - Torch-like API for flax#
Klinen is a small wrapper around flax.linen. The goal is to provide a
stateless, object-oriented, supporting auto-complete and type checking.
Documentation#
Model creation#
Model creation is similar to flax, except the modules should inherit from
klinen instead of linen:
from flax import linen as nn
from kauldron import klinen as knn
class MLP(knn.Module):
@nn.compact
def __call__(self, x: jax.Array) -> jax.Array:
return nn.Dense(32)(x)
class AutoEncoder(knn.Module):
encoder: knn.Module
decoder: knn.Module
@nn.compact
def __call__(self, x: jax.Array) -> jax.Array:
return self.decoder(self.encoder(x))
Inside
knn.Module, any linen modules can be used.
Note: Alternatively, you can convert any existing flax.linen.Module to klinen
with MyKlinenModule = knn.convert(MyFlaxModule)
Model initialization / usage#
To initialize the model, use model.init_bind() instead of model.init(). It
will return a copy of the module with bind parameters.
model = AutoEncoder(
encoder=MLP(),
decoder=MLP(),
)
model = model.init_bind(rng, jnp.zeros((batch_size, 64)))
# After the model is initialized, it can be called directly
y = model(x)
# You can also call individual sub-modules
y = model.encoder(x)
Randomness#
klinen modules are stateless, this mean they are fully deterministic. Calling
the same model twice will always return the same result. If your model uses
randomness (e.g. nn.Dropout), the rng key has to be explicitly provided:
model = model.with_rng(rng) # Replace the current rng.
y0 = model(x)
y1 = model(x)
assert jnp.allclose(y0, y1) # Calling the model twice give the same output
model = model.with_rng(rng2) # Set a new rng
Multiple values are accepted:
model.with_rng({'dropout': rng}): Rng streams explicitly definedmodel.with_rng(rng): Key distributed among streams (withrng.fold_in(stream_name))model.with_rng(): no-argument provided, split the currentrngto get the next one.
Calling model(x) before a key was provided with .with_rng will yield an
error the first time.
Currently, there’s no guarantee that the encoder called in model(x) or
directly with model.encoder(x) have the same rng. This will be fixed in the
future.
Training/eval mode#
To disable determinism, models can use the self.training attribute:
class MLP(knn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(2)(x)
x = nn.Dropout(0.5)(x, deterministic=not self.training)
return x
By default, model.training == True. You can switch the model to eval mode
with .eval()
model = model.eval() # Switch to eval mode
assert not model.training
model = model.train() # Switch back to train mode
assert model.training
Parameters#
You can access the flax parameters, either at the root level or for individual modules.
model.params
model.encoder.params
model.encoder.params['Dense_0'] # nn.Dense params defined inside `nn.compact`
Jit, auto-diff#
knn.Module are compatible with jax.tree_utils to map over the parameters.
This means modules can be used nativelly inside jax.jit:
@jax.jit
def eval_step(model: knn.Model, x: jax.Array, y: jax.Array) -> jax.Array:
model = model.eval()
y_pred = model(x)
return loss(y_pred, y)
Intermediate values#
Often, it’s very convenient to be able to store/access intermediate values
in the module tree. It is possible by annotating module fields as
knn.Intermediate[T] = dataclasses.field(init=False).
class Sequential(knn.Module):
childs: list[nn.Module]
tmp_values: knn.Intermediate[list[jax.Array]] = dataclasses.field(
init=False,
default_factory=list,
)
@nn.compact
def __call__(self, x: jax.Array) -> jax.Array:
for child in childs:
x = child(x)
self.tmp_values.append(x)
return x
The intermediate values are reset at each call (calling model() twice
will create a new tmp_values list). Intermediate values are not bound to
the model object. Instead they need to be explicitly fetched:
model = AutoEncoder(
encoder=Sequential([
nn.Dense(32),
nn.Dense(32),
]),
decoder=MLP(),
)
model = model.init_bind(rng, x)
y = model(x) # Standard call (no intermediate)
with model.capture_intermediates() as intermediates:
y = model(x)
# Convenience wrapper around `model.capture_intermediates()`
y, intermediates = model.call_with_intermediate(x)
# `intermediates` has the same structure as the `model`, but only sub-modules
# and `knn.Intermediate` fields are available.
tmp_values = intermediates.encoder.tmp_values