kd.knn
Wrapper around flax.linen.Module to add a torch-like API.
Symbols
Class
Dense(features: int, use_bias: bool = True, dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any, NoneType] = None, param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any] = <class 'jax.numpy.float32'>, precision: Union[NoneType, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]] = None, kernel_init: Union[jax.nn.initializers.Initializer, collections.abc.Callable[..., Any]] = <function variance_scaling.
Dropout(rate: float, broadcast_dims: collections.abc.Sequence[int] = (), deterministic: bool |
Base Module class.
Sequential(layers: collections.abc.Sequence[collections.abc.Callable[..., typing.Any]], parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None, *, _kd_state: 'Optional[_ModuleState]' = None)
Function
Decorator that converts a Flax class into a klinen module.