kd.knn.Module
- class kauldron.klinen.Module(
    *,
    _kd_state: kauldron.klinen.module._ModuleState | None = None,
    parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
    name: str | None = None,
  )
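Bases: flax.linen.module.ModuleBase
Module class: a Flax linen module with the binding, RNG, train/eval, intermediates and pytree helpers documented below.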
- init_bind(
    rng: int | jax.Array | dict[str, jax.Array],
    *args,
    streams: tuple[str, ...] = ('dropout',),
    **kwargs,
  )
Initialize the module and return a bound version of it.
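As a rough, framework-free illustration of "initialize, then bind", the toy below creates its weights from a seed and a sample input in one call and returns a copy that carries them, so the result can be called directly without passing params around. `ToyDense`, its `init_bind`, and the seeded uniform initialization are hypothetical and are not kauldron's implementation.

```python
from __future__ import annotations

import dataclasses
import random
from typing import Optional


@dataclasses.dataclass(frozen=True)
class ToyDense:
    """Hypothetical stand-in for a module; NOT the kauldron implementation."""

    features: int
    # Weights stay None until the module is initialized and bound.
    weights: Optional[tuple[tuple[float, ...], ...]] = None

    def init_bind(self, rng: int, x: list[float]) -> ToyDense:
        # Infer the input size from the sample input.
        gen = random.Random(rng)
        weights = tuple(
            tuple(gen.uniform(-1.0, 1.0) for _ in range(len(x)))
            for _ in range(self.features)
        )
        # Return a new "bound" copy; the original module stays unbound.
        return dataclasses.replace(self, weights=weights)

    def __call__(self, x: list[float]) -> list[float]:
        if self.weights is None:
            raise RuntimeError("Module is not bound; call init_bind first.")
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.weights]


model = ToyDense(features=2).init_bind(0, [1.0, 2.0, 3.0])
y = model([1.0, 2.0, 3.0])  # Called directly, no separate params argument.
```

Returning a copy rather than mutating keeps the unbound module reusable, which fits Flax's frozen-dataclass modules.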
- with_rng(
    rng: int | jax.Array | dict[str, jax.Array] | None = None,
  )
Replace the RNG keys.
Can be called as:
  model = model.with_rng(): replace the current key with the next key.
  model = model.with_rng(0): create a key from the seed.
  model = model.with_rng(key): distribute the key among the streams.
  model = model.with_rng({'dropout': key}): define the streams explicitly.
- Parameters:
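rng – Random key: an int seed, a jax.Array key, a dict mapping stream names to keys, or None to advance the current keys.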
- Returns:
The updated model with the next key.
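The four call forms of with_rng can be pictured as a dispatch on the type of `rng`. The sketch below is hypothetical: `ToyKey` stands in for a real PRNG key (it is not `jax.random`), and `resolve_rngs`, the hash-based key derivation, and reusing `init_bind`'s `('dropout',)` default as the stream list are assumptions for illustration only.

```python
from __future__ import annotations

import dataclasses
import hashlib
from collections.abc import Mapping, Sequence
from typing import Union


@dataclasses.dataclass(frozen=True)
class ToyKey:
    """Hypothetical stand-in for a PRNG key (NOT jax.random)."""

    value: int

    def derive(self, label: str) -> ToyKey:
        # Deterministically derive a new key from this key and a label.
        digest = hashlib.sha256(f"{self.value}:{label}".encode()).digest()
        return ToyKey(int.from_bytes(digest[:8], "big"))


RngLike = Union[None, int, ToyKey, Mapping[str, ToyKey]]


def resolve_rngs(
    rng: RngLike,
    current: Mapping[str, ToyKey],
    streams: Sequence[str] = ("dropout",),
) -> dict[str, ToyKey]:
    """Maps each accepted `rng` form to a {stream_name: key} dict."""
    if rng is None:
        # with_rng(): advance every existing stream to its next key.
        return {name: key.derive("next") for name, key in current.items()}
    if isinstance(rng, int):
        # with_rng(0): build a key from the integer seed.
        rng = ToyKey(rng)
    if isinstance(rng, ToyKey):
        # with_rng(key): give each stream its own key derived from `key`.
        return {name: rng.derive(name) for name in streams}
    if isinstance(rng, Mapping):
        # with_rng({'dropout': key}): streams are given explicitly.
        return dict(rng)
    raise TypeError(f"Unsupported rng: {rng!r}")


seeded = resolve_rngs(0, current={})
advanced = resolve_rngs(None, current=seeded)
```

Deriving one key per stream (instead of sharing a single key) keeps, for example, dropout noise independent of any other random stream.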
- property rngs: dict[str, kauldron.random.random.PRNGKey]
Returns a dict[str, PRNGKey] mapping each stream name to its key.
- train() → kauldron.klinen.module._SelfT
Switch mode to training.
- eval() → kauldron.klinen.module._SelfT
Switch mode to evaluation (e.g. disables dropout).
- property training: bool
True if the module is in training mode.
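Because train() and eval() are annotated to return `_SelfT`, they appear to return a module of the same type rather than switching modes in place. Assuming that, the pattern resembles the toy frozen dataclass below. `ToyModule`, its default of `training=True`, and the simplified inverted-dropout rule are hypothetical.

```python
from __future__ import annotations

import dataclasses
import random


@dataclasses.dataclass(frozen=True)
class ToyModule:
    """Hypothetical module with a training flag (NOT kauldron's code)."""

    rate: float = 0.5
    training: bool = True

    def train(self) -> ToyModule:
        # Return a copy in training mode; `self` is left unchanged.
        return dataclasses.replace(self, training=True)

    def eval(self) -> ToyModule:
        # Return a copy in evaluation mode, where dropout is disabled.
        return dataclasses.replace(self, training=False)

    def __call__(self, x: list[float], seed: int = 0) -> list[float]:
        if not self.training:
            return list(x)  # Evaluation: identity, no dropout.
        gen = random.Random(seed)
        keep = 1.0 - self.rate
        # Inverted dropout: zero some values and rescale the survivors.
        return [xi / keep if gen.random() < keep else 0.0 for xi in x]


model = ToyModule()
eval_model = model.eval()
```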
- property params: flax.core.frozen_dict.FrozenDict[str, collections.abc.Mapping[str, Any]]
Model weights.
- param_tree_on() → kauldron.klinen.module._SelfT
Make jax.tree_util functions act only on the params.
- param_tree_off() → kauldron.klinen.module._SelfT
Make jax.tree_util functions act on everything (not only the params).
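One way to picture the param_tree_on / param_tree_off toggle is a flag that decides which fields the flattening step exposes as leaves. In this stdlib sketch, `ToyState`, its `param_tree` flag, the `batch_stats` field, and `leaves()` are hypothetical; the documented methods hook into jax.tree_util instead.

```python
from __future__ import annotations

import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical container holding params plus other variables."""

    params: dict[str, float]
    batch_stats: dict[str, float]
    param_tree: bool = False  # When True, flattening only sees params.

    def param_tree_on(self) -> ToyState:
        return dataclasses.replace(self, param_tree=True)

    def param_tree_off(self) -> ToyState:
        return dataclasses.replace(self, param_tree=False)

    def leaves(self) -> list[Any]:
        # Mimics tree flattening: the values a tree_map would visit.
        groups = [self.params]
        if not self.param_tree:
            groups.append(self.batch_stats)
        return [v for group in groups for _, v in sorted(group.items())]


state = ToyState(params={"w": 1.0, "b": 0.5}, batch_stats={"mean": 0.1})
```

Restricting flattening to params is handy when, for instance, an optimizer update should touch only trainable weights.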
- call_with_intermediates(
    *args: Any,
    **kwargs: Any,
  )
Call the module with intermediates.
Wrapper around __call__ that also returns the intermediate values:
    y = model(x)
    y, intermediates = model.call_with_intermediates(x)
The intermediate values have the same structure as the model.
- Parameters:
*args – Arguments forwarded to module.__call__
**kwargs – Arguments forwarded to module.__call__
- Returns:
A pair (y, intermediates): the module.__call__ output and the intermediate values.
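To make "the same structure as the model" concrete, here is a toy two-layer model whose intermediates come back as a nested dict keyed by submodule name. `ToyLayer`, `ToyModel`, and the `record` dict are hypothetical; the documented wrapper captures values through capture_intermediates() rather than an explicit argument.

```python
from __future__ import annotations

from typing import Any


class ToyLayer:
    """Hypothetical layer that scales its input and records the result."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def __call__(self, x: float, record: dict[str, Any]) -> float:
        y = x * self.scale
        record["out"] = y  # Intermediate value stored under this layer.
        return y


class ToyModel:
    """Hypothetical model made of two named sublayers."""

    def __init__(self) -> None:
        self.layers = {"dense1": ToyLayer(2.0), "dense2": ToyLayer(3.0)}

    def _forward(self, x: float, record: dict[str, Any]) -> float:
        for name, layer in self.layers.items():
            # Each sublayer writes into its own sub-dict, mirroring the model tree.
            x = layer(x, record.setdefault(name, {}))
        return x

    def __call__(self, x: float) -> float:
        return self._forward(x, record={})  # Intermediates are discarded.

    def call_with_intermediates(self, x: float) -> tuple[float, dict[str, Any]]:
        record: dict[str, Any] = {}
        y = self._forward(x, record)
        return y, record


model = ToyModel()
y, intermediates = model.call_with_intermediates(1.0)
```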
- capture_intermediates() → Iterator[kauldron.klinen.module._SelfT]
Track the intermediate values.
Note that this function isn't meant to be called directly, but instead through y, intermediates = model.call_with_intermediates(x).
Usage:
    with model.capture_intermediates() as intermediates:
      y = model(x)  # Model sets `model.xxx`
    # After the context manager ends, `intermediates` contains the captured
    # intermediate values.
    intermediates.xxx
- Yields:
The module proxy containing the intermediate values
- Raises:
RuntimeError – If the context managers are nested.
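The context-manager contract above, including the RuntimeError on nesting and the reset once the block exits, can be sketched with `contextlib.contextmanager`. `Capturer`, `sow`, and the `SimpleNamespace` container are hypothetical; the documented method yields a module proxy instead.

```python
from __future__ import annotations

import contextlib
import types
from collections.abc import Iterator


class Capturer:
    """Hypothetical object that records values while a capture is active."""

    def __init__(self) -> None:
        self._active: types.SimpleNamespace | None = None

    @contextlib.contextmanager
    def capture_intermediates(self) -> Iterator[types.SimpleNamespace]:
        if self._active is not None:
            raise RuntimeError("capture_intermediates() cannot be nested.")
        self._active = types.SimpleNamespace()
        try:
            yield self._active  # Filled in while the `with` body runs.
        finally:
            self._active = None  # Always reset, even if the body raises.

    def sow(self, name: str, value: object) -> None:
        # Outside a capture, values are simply dropped.
        if self._active is not None:
            setattr(self._active, name, value)

    def __call__(self, x: float) -> float:
        hidden = x + 1.0
        self.sow("hidden", hidden)
        return hidden * 2.0


model = Capturer()
with model.capture_intermediates() as intermediates:
    y = model(1.0)
# After the block ends, `intermediates` still holds the captured values.

nested_error = ""
try:
    with model.capture_intermediates():
        with model.capture_intermediates():
            pass
except RuntimeError as err:
    nested_error = str(err)
```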
- tree_flatten() → tuple[list[flax.core.frozen_dict.FrozenDict[str, collections.abc.Mapping[str, Any]]], kauldron.klinen.module.Module]
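jax.tree_util support: flattens the module into a list of variable dicts (the children) and the module itself (the metadata).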
- classmethod tree_unflatten(
    metadata: kauldron.klinen.module.Module,
    array_field_values: list[flax.core.frozen_dict.FrozenDict[str, collections.abc.Mapping[str, Any]]],
  )
- attribute name: str | None = None
- attribute parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- attribute scope: Scope | None = None
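Judging from the signatures, tree_flatten returns a list of variable dicts as the pytree children and the module itself as metadata, and tree_unflatten(metadata, array_field_values) rebuilds a module from them. This matches the `(children, aux_data)` convention used by `jax.tree_util.register_pytree_node_class`. The stdlib-only sketch below imitates that round trip without JAX; `ToyModule`, `tree_map_params`, and the dict-of-floats params are hypothetical.

```python
from __future__ import annotations

import dataclasses
from collections.abc import Callable


@dataclasses.dataclass(frozen=True)
class ToyModule:
    """Hypothetical module: static config plus a params dict."""

    features: int
    params: dict[str, float] = dataclasses.field(default_factory=dict)

    def tree_flatten(self) -> tuple[list[dict[str, float]], ToyModule]:
        # Children: the params (the only array-like field).
        # Metadata: the module itself, carrying the static config.
        return [self.params], self

    @classmethod
    def tree_unflatten(
        cls, metadata: ToyModule, array_field_values: list[dict[str, float]]
    ) -> ToyModule:
        (params,) = array_field_values
        return dataclasses.replace(metadata, params=params)


def tree_map_params(fn: Callable[[float], float], module: ToyModule) -> ToyModule:
    # Flatten, transform every leaf, then unflatten with the same metadata.
    children, metadata = module.tree_flatten()
    new_children = [{k: fn(v) for k, v in child.items()} for child in children]
    return type(module).tree_unflatten(metadata, new_children)


model = ToyModule(features=4, params={"w": 1.5, "b": -0.5})
doubled = tree_map_params(lambda v: 2 * v, model)
```

Using the module itself as metadata keeps configuration fields out of the leaves, so tree transformations only ever see the weights.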