kd.optim.select
- kauldron.optim.select(pattern: str | collections.abc.Sequence[str])
Create a mask which selects only the sub-pytree matching the pattern.
- `xx` matches every `{'xx': ...}` dict entry anywhere inside the tree. The match is strict, so `xx` does NOT match `{'xxyy': ...}`.
- `xx.yy` matches `{'xx': {'yy': ...}}`.
- Regexes are supported. When using a regex, make sure to escape `.`, since an unescaped `.` is a regex wildcard that matches any character (e.g. `xx\.yy[0-9]+`).
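To make these matching rules concrete, here is a minimal pure-Python sketch of one way they could be implemented. It is an assumption, not Kauldron's actual code: `select_sketch` and `_path_matches` are hypothetical names, only nested dicts are handled, every contiguous run of keys on a leaf's path is joined with `.` and tested with `re.fullmatch`, and a sequence of patterns is treated as the union of its matches.

```python
import re
from typing import Any, Sequence, Union


def _path_matches(path: tuple, patterns: Sequence[str]) -> bool:
    """True if some contiguous run of keys in `path`, joined by '.', fully matches a pattern."""
    for i in range(len(path)):
        for j in range(i + 1, len(path) + 1):
            candidate = ".".join(path[i:j])
            # fullmatch gives the "strict" behaviour: 'xx' never matches 'xxyy'.
            if any(re.fullmatch(p, candidate) for p in patterns):
                return True
    return False


def select_sketch(pattern: Union[str, Sequence[str]]):
    """Hypothetical stand-in for kd.optim.select that only handles nested dicts."""
    # ASSUMPTION: a sequence of patterns selects the union of their matches.
    patterns = [pattern] if isinstance(pattern, str) else list(pattern)

    def walk(tree: Any, path: tuple) -> Any:
        if isinstance(tree, dict):
            return {k: walk(v, path + (str(k),)) for k, v in tree.items()}
        # Leaf: selected if any node on its path (including its own key) matched,
        # i.e. the whole sub-pytree under a matching key is selected.
        return _path_matches(path, patterns)

    def mask_fn(tree: Any) -> Any:
        return walk(tree, ())

    return mask_fn


params = {"xx": {"w": 0}, "xxyy": {"w": 0}, "a": {"xx": {"yy1": 0, "zz": 0}}}

# Strict match: both 'xx' sub-trees are selected, 'xxyy' is not.
print(select_sketch("xx")(params))
# {'xx': {'w': True}, 'xxyy': {'w': False}, 'a': {'xx': {'yy1': True, 'zz': True}}}

# Escaped regex: only the nested 'xx' -> 'yy1' leaf matches.
print(select_sketch(r"xx\.yy[0-9]+")(params))
# {'xx': {'w': False}, 'xxyy': {'w': False}, 'a': {'xx': {'yy1': True, 'zz': False}}}

# Unescaped '.' is a wildcard, so the single key 'xxayy1' also matches.
print(select_sketch("xx.yy[0-9]+")({"xxayy1": 0}))
# {'xxayy1': True}
```

The last call shows why the escaping advice matters: without `\.`, the pattern intended for a nested `xx` -> `yy1` path also selects an unrelated flat key.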
Example:
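
    import jax.numpy as jnp
    from kauldron import kd

    mask_fn = kd.optim.select("lora")
    mask = mask_fn({
        'layer0': {
            'lora': {
                'a': jnp.zeros(()),
                'b': jnp.zeros(()),
            },
            'weights': jnp.zeros(()),
            'bias': jnp.zeros(()),
        }
    })
    # Only the leaves under the 'lora' key are selected.
    assert mask == {
        'layer0': {
            'lora': {
                'a': True,
                'b': True,
            },
            'weights': False,
            'bias': False,
        }
    }
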
- Parameters:
pattern – The pattern (or sequence of patterns) to select. Every leaf outside the matched sub-pytrees is mapped to False.
- Returns:
The optax mask function: a callable that maps a params pytree to a pytree of booleans (True for selected leaves).