kd.optim.select

kauldron.optim.select(
    pattern: str | collections.abc.Sequence[str],
) → collections.abc.Callable[[Any], Any]

Create a mask which selects only the sub-pytree matching the pattern.

  • xx matches every {‘xx’: …} dict anywhere inside the tree. The match is strict, so xx will NOT match {‘xxyy’: …}.

  • xx.yy matches {‘xx’: {‘yy’: …}} dicts.

  • Regular expressions are supported. When using them, make sure to escape . (e.g. xx\.yy[0-9]+).
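The matching rules above can be illustrated with a small pure-Python sketch. It is not Kauldron's implementation: select_sketch is a hypothetical name, and the sketch assumes nested dicts are flattened into dot-joined key paths (e.g. layer0.lora.a) and that a pattern must start and end on whole path components, which is one way to obtain the strict behaviour described above.

```python
import re
from typing import Any, Callable, Sequence, Union


def select_sketch(pattern: Union[str, Sequence[str]]) -> Callable[[Any], Any]:
    """Hypothetical sketch of the matching rules; not Kauldron's actual code."""
    patterns = [pattern] if isinstance(pattern, str) else list(pattern)
    # Require each pattern to start and end on a '.' boundary (or the path
    # start/end), so 'xx' matches the key 'xx' but not 'xxyy'.
    regexes = [re.compile(rf"(?:^|\.)(?:{p})(?:\.|$)") for p in patterns]

    def mask_fn(tree: Any, path: str = "") -> Any:
        if isinstance(tree, dict):
            # Recurse into dicts, extending the dot-joined key path.
            return {
                k: mask_fn(v, f"{path}.{k}" if path else str(k))
                for k, v in tree.items()
            }
        # A leaf is selected if any pattern matches somewhere in its path,
        # so every leaf of a matching sub-tree is selected.
        return any(r.search(path) is not None for r in regexes)

    return mask_fn


params = {
    "layer0": {
        "lora": {"a": 0.0, "b": 0.0},
        "weights": 0.0,
        "bias": 0.0,
    }
}
print(select_sketch("lora")(params))
# {'layer0': {'lora': {'a': True, 'b': True}, 'weights': False, 'bias': False}}
```

In this sketch an unescaped . also matches any single character, so xx.yy would also select a key literally named xxAyy; escaping it as xx\.yy restricts the match to the nested {‘xx’: {‘yy’: …}} case.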

Example:

import jax.numpy as jnp
from kauldron import kd

mask_fn = kd.optim.select("lora")

# Only the leaves under the 'lora' sub-tree are selected (True).
mask_fn({
    'layer0': {
        'lora': {
            'a': jnp.zeros(()),
            'b': jnp.zeros(()),
        },
        'weights': jnp.zeros(()),
        'bias': jnp.zeros(()),
    }
}) == {
    'layer0': {
        'lora': {
            'a': True,
            'b': True,
        },
        'weights': False,
        'bias': False,
    }
}

Parameters:

pattern – The pattern (or sequence of patterns) selecting the sub-trees to include. Every other leaf is masked as False.

Returns:

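The optax mask factory: a callable that takes a pytree (e.g. the params) and returns a pytree of the same structure, with True for the selected leaves and False elsewhere.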