kd.knn.Module

class kauldron.klinen.Module(
*,
_kd_state: kauldron.klinen.module._ModuleState | None = None,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None,
)

Bases: flax.linen.module.Module

Base Module class.

init_bind(
rng: int | jax.Array | dict[str, jax.Array],
*args,
streams: tuple[str, ...] = ('dropout',),
**kwargs,
) → kauldron.klinen.module._SelfT

Initialize the module, returning a bound version.

with_rng(
rng: int | jax.Array | dict[str, jax.Array] | None = None,
) → kauldron.klinen.module._SelfT

Replace the RNG keys of the module.

Can be called:

  • model = model.with_rng(): Replace the current key with the next key.

  • model = model.with_rng(0): Create a key from the seed.

  • model = model.with_rng(key): The key is distributed among the streams.

  • model = model.with_rng({'dropout': key}): The streams are defined explicitly.

Parameters:

rng – Random key.

Returns:

The updated model with the next key.
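
The four call forms listed above can be illustrated with a small normalization helper. This is a minimal sketch assuming that `with_rng` ends up with one key per stream; the `normalize_rng` helper, its `current` argument, and the use of `jax.random.split(key)[0]` as the "next key" are hypothetical choices, not kauldron's code.

```python
import jax


def normalize_rng(rng=None, streams=("dropout",), current=None):
  """Hypothetical helper mapping the documented call forms to {stream: key}."""
  if rng is None:
    # Form 1: derive the next key of each stream from its current key.
    if current is None:
      raise ValueError("No current keys to advance.")
    return {name: jax.random.split(key)[0] for name, key in current.items()}
  if isinstance(rng, dict):
    # Form 4: the streams are given explicitly.
    return dict(rng)
  if isinstance(rng, int):
    # Form 2: build a key from an integer seed, then distribute it below.
    rng = jax.random.PRNGKey(rng)
  # Form 3: split a single key into one key per stream.
  keys = jax.random.split(rng, len(streams))
  return dict(zip(streams, keys))


rngs = normalize_rng(0, streams=("dropout", "noise"))
next_rngs = normalize_rng(current=rngs)
explicit = normalize_rng({"dropout": jax.random.PRNGKey(1)})
```

Keeping an explicit per-stream dict makes it clear which randomness each call consumes, which is also what the `rngs` property exposes.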

property rngs: dict[str, kauldron.random.random.PRNGKey]

Returns a dict[str, PRNGKey] mapping each RNG stream name to its key.

train() → kauldron.klinen.module._SelfT

Switch mode to training.

eval() → kauldron.klinen.module._SelfT

Switch mode to evaluation (disable dropout,…).

property training: bool

Returns True if mode is training.

property params: flax.core.frozen_dict.FrozenDict[str, collections.abc.Mapping[str, Any]]

Model weights.

param_tree_on() → kauldron.klinen.module._SelfT

Makes jax.tree_util functions act only on the params.

param_tree_off() → kauldron.klinen.module._SelfT

Makes jax.tree_util functions act on everything.
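
The difference between the two modes can be pictured as which leaves a tree utility visits. The sketch below assumes a Flax-style variables dict with a `params` collection plus one other collection; the `variables` contents and the `count_leaves` helper are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical model variables: trainable params plus another collection.
variables = {
    "params": {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))},
    "batch_stats": {"mean": jnp.zeros((2,))},
}


def count_leaves(tree, params_only):
  """Counts the array leaves a tree utility would visit in each mode."""
  target = tree["params"] if params_only else tree
  return len(jax.tree_util.tree_leaves(target))


n_on = count_leaves(variables, params_only=True)    # Like param_tree_on().
n_off = count_leaves(variables, params_only=False)  # Like param_tree_off().
```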

call_with_intermediates(
*args: Any,
**kwargs: Any,
) → tuple[Any, kauldron.klinen.module._SelfT]

Call the module with intermediates.

Wrapper around __call__ that also returns the intermediate values:

y = model(x)

y, intermediates = model.call_with_intermediates(x)

The intermediate values have the same structure as the model.

Parameters:
  • *args – Arguments forwarded to module.__call__

  • **kwargs – Arguments forwarded to module.__call__

Returns:

A tuple with the module.__call__ output and the intermediate values.

capture_intermediates() → Iterator[kauldron.klinen.module._SelfT]

Track the intermediate values.

Note that this function isn't meant to be called directly, but rather through y, intermediates = model.call_with_intermediates(x).

Usage:

with model.capture_intermediates() as intermediates:
  y = model(x)  # The model sets `model.xxx`.

# After the context manager exits, `intermediates` contains the captured
# intermediate values.
intermediates.xxx

Yields:

The module proxy containing the intermediate values.

Raises:

RuntimeError – If the context managers are nested.

tree_flatten() → tuple[list[flax.core.frozen_dict.FrozenDict[str, collections.abc.Mapping[str, Any]]], kauldron.klinen.module.Module]

jax.tree_util support.

classmethod tree_unflatten(
metadata: kauldron.klinen.module.Module,
array_field_values: list[flax.core.frozen_dict.FrozenDict[str, collections.abc.Mapping[str, Any]]],
) → kauldron.klinen.module._SelfT
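
`tree_flatten` and `tree_unflatten` follow JAX's custom pytree protocol: flattening returns the array leaves plus static metadata, and unflattening rebuilds the object from both. The sketch below registers a hypothetical `TinyModel` class the same way, assuming the params are the only leaves and `name` is the metadata; the signatures above suggest kauldron instead passes the module itself as metadata.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class TinyModel:
  """Hypothetical container: params are leaves, `name` is static metadata."""

  def __init__(self, params, name):
    self.params = params
    self.name = name

  def tree_flatten(self):
    # Children: a list holding the params. Aux data: hashable metadata.
    return [self.params], self.name

  @classmethod
  def tree_unflatten(cls, aux, children):
    (params,) = children
    return cls(params, aux)


model = TinyModel({"w": jnp.ones((2,)), "b": jnp.zeros((2,))}, name="tiny")
# tree_map only touches the arrays; the metadata is carried through unchanged.
doubled = jax.tree_util.tree_map(lambda p: p * 2, model)
```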

name: str | None = None

parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None

scope: Scope | None = None