kd.knn

Wrapper around flax.linen.Module to add a torch-like API.
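In plain `flax.linen`, a module object holds no parameters: `init` creates them and they must be passed back to `apply` on every call. A "torch-like" API instead lets the module object carry its parameters so it can be called as `model(x)`. The pure-Python sketch below illustrates only that difference in calling style; `FunctionalScale`, `BoundModule`, and their methods are hypothetical names, not the kd.knn API.

```python
import random
from typing import Dict


class FunctionalScale:
    """Flax-style (functional) module: parameters live outside the object."""

    def init(self, seed: int) -> Dict[str, float]:
        # Create the parameters from a seed and hand them back to the caller.
        rng = random.Random(seed)
        return {"w": rng.uniform(-1.0, 1.0)}

    def apply(self, params: Dict[str, float], x: float) -> float:
        # Parameters must be passed in explicitly on every call.
        return params["w"] * x


class BoundModule:
    """Hypothetical torch-like wrapper: stores the parameters with the module."""

    def __init__(self, module: FunctionalScale, params: Dict[str, float]):
        self.module = module
        self.params = params

    def __call__(self, x: float) -> float:
        # Forward the stored parameters, so callers only pass the input.
        return self.module.apply(self.params, x)


module = FunctionalScale()
params = module.init(seed=0)
functional_out = module.apply(params, 3.0)  # functional style
model = BoundModule(module, params)
torch_like_out = model(3.0)                 # torch-like style
```

Both calls compute the same value; the wrapper only changes where the parameters are kept.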

Symbols

Classes

kd.knn.Dense

Dense(features: int, use_bias: bool = True, dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any, NoneType] = None, param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any] = <class 'jax.numpy.float32'>, precision: Union[NoneType, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]] = None, kernel_init: Union[jax.nn.initializers.Initializer, collections.abc.Callable[..., Any]] = <function variance_scaling.<locals>.init>, bias_init: Union[jax.nn.initializers.Initializer, collections.abc.Callable[..., Any]] = <function zeros>, promote_dtype: flax.linen.linear.PromoteDtypeFn = <function promote_dtype>, dot_general: collections.abc.Callable[..., typing.Union[jax.Array, typing.Any]]
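The parameters above appear to mirror `flax.linen.Dense`, which computes `y = x @ kernel + bias` with a kernel of shape `(in_features, features)` and a bias of shape `(features,)` that is added only when `use_bias=True`. The dependency-free sketch below reproduces that arithmetic for a single input vector using nested lists; `dense_forward` is a hypothetical helper, and dtype promotion, `precision`, `dot_general`, and initialization are deliberately left out.

```python
from typing import List, Optional


def dense_forward(
    x: List[float],
    kernel: List[List[float]],
    bias: Optional[List[float]] = None,
) -> List[float]:
    """Compute y[j] = sum_i x[i] * kernel[i][j], plus bias[j] if a bias is given."""
    in_features = len(kernel)
    features = len(kernel[0])
    if len(x) != in_features:
        raise ValueError("input size must match the kernel's first dimension")
    y = [sum(x[i] * kernel[i][j] for i in range(in_features)) for j in range(features)]
    if bias is not None:  # corresponds to use_bias=True
        y = [y_j + b_j for y_j, b_j in zip(y, bias)]
    return y


# in_features=2, features=3
kernel = [[1.0, 0.0, 2.0],
          [0.0, 1.0, -1.0]]
bias = [0.5, 0.5, 0.5]
print(dense_forward([3.0, 4.0], kernel, bias))  # [3.5, 4.5, 2.5]
```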

kd.knn.Dropout

Dropout(rate: float, broadcast_dims: collections.abc.Sequence[int] = (), deterministic: bool
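These arguments match `flax.linen.Dropout`, which implements inverted dropout: when `deterministic` is true the input is returned unchanged; otherwise each element is kept with probability `1 - rate` and the survivors are scaled by `1 / (1 - rate)` so the expected value is preserved. The sketch below shows that rule on a flat list, with Python's `random` module standing in for a JAX PRNG key; `dropout` is a hypothetical helper, and `broadcast_dims` (sharing one mask along chosen axes) is omitted.

```python
import random
from typing import List


def dropout(x: List[float], rate: float, deterministic: bool, seed: int = 0) -> List[float]:
    """Inverted dropout on a flat list of values."""
    if deterministic or rate == 0.0:
        return list(x)  # evaluation mode (or nothing to drop): identity
    if rate == 1.0:
        return [0.0] * len(x)  # every element is dropped
    keep_prob = 1.0 - rate
    rng = random.Random(seed)
    # Keep each element with probability keep_prob and rescale the survivors.
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in x]


x = [1.0, 2.0, 3.0, 4.0]
print(dropout(x, rate=0.5, deterministic=True))  # [1.0, 2.0, 3.0, 4.0]
train_out = dropout(x, rate=0.5, deterministic=False, seed=42)
```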

kd.knn.Module

Base Module class.

kd.knn.Sequential

Sequential(layers: collections.abc.Sequence[collections.abc.Callable[..., typing.Any]], parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None, *, _kd_state: 'Optional[_ModuleState]' = None)
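`Sequential` takes a sequence of callables. Assuming it follows `flax.linen.Sequential`, it calls them in order and feeds each output into the next; Flax's version also unpacks a tuple output as positional arguments and a dict output as keyword arguments. The sketch below reproduces that chaining rule with plain functions; `sequential_apply` is a hypothetical helper, not part of kd.knn.

```python
from typing import Any, Callable, Sequence


def sequential_apply(layers: Sequence[Callable[..., Any]], *args: Any, **kwargs: Any) -> Any:
    """Apply layers in order, passing each output on to the next layer."""
    if not layers:
        raise ValueError("layers must not be empty")
    outputs = layers[0](*args, **kwargs)
    for layer in layers[1:]:
        if isinstance(outputs, tuple):
            outputs = layer(*outputs)   # tuple -> positional arguments
        elif isinstance(outputs, dict):
            outputs = layer(**outputs)  # dict -> keyword arguments
        else:
            outputs = layer(outputs)
    return outputs


def double(x):
    return 2 * x


def split(x):
    return x, x + 1  # returns a tuple


def add(a, b):
    return a + b


print(sequential_apply([double, split, add], 5))  # 21
```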

Functions

kd.knn.convert

Decorator that converts a Flax module class into a klinen module.
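One common way to write a class-converting decorator is to return a subclass that mixes in the extra behavior. The sketch below shows only that general pattern; how kd.knn.convert is actually implemented is not described on this page, and `TorchLikeMixin`, `bind_params`, and `convert_sketch` are hypothetical names.

```python
from typing import Any, Dict, Type


class TorchLikeMixin:
    """Hypothetical mixin that adds a bound-parameter call style."""

    def bind_params(self, params: Dict[str, Any]) -> "TorchLikeMixin":
        # Store the parameters on the instance and return it for chaining.
        self._bound_params = params
        return self

    def __call__(self, x: Any) -> Any:
        # Use the stored parameters instead of requiring them on each call.
        return self.apply(self._bound_params, x)


def convert_sketch(cls: Type[Any]) -> Type[Any]:
    """Hypothetical decorator: return a subclass of `cls` with the mixin added."""
    return type(cls.__name__, (TorchLikeMixin, cls), {})


@convert_sketch
class Scale:
    def apply(self, params: Dict[str, float], x: float) -> float:
        return params["w"] * x


model = Scale().bind_params({"w": 3.0})
print(model(2.0))  # 6.0
```

Because the decorator returns a subclass, the original `apply` logic is reused unchanged and only the calling convention is added on top.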

Typing