kd.nn.TransformerMLP

class kauldron.modules.TransformerMLP(hidden_size: int | None = None, activation_fn: typing.Callable[[jaxtyping.Float[Array, '*any'] | jaxtyping.Float[ndarray, '*any']], jaxtyping.Float[Array, '*any'] | jaxtyping.Float[ndarray, '*any']] = <function gelu>, kernel_init: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = <function variance_scaling.<locals>.init>, bias_init: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = <function zeros>, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: flax.linen.module.Module

Simple MLP with a single hidden layer for use in Transformer blocks.
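To make "a single hidden layer" concrete, the following is a minimal sketch in plain JAX of the computation such a block typically performs: a dense projection, an activation, and a dense projection back to the input width. It is not Kauldron's implementation. The helper names (init_mlp_params, mlp_apply), the hidden width of 32, and the variance_scaling arguments are assumptions chosen for illustration; the source only states that the default kernel initializer comes from variance_scaling and the default bias initializer is zeros.

```python
import jax
import jax.numpy as jnp


def init_mlp_params(key, d_model, hidden_size):
    """Hypothetical helper: weights for a one-hidden-layer MLP."""
    k1, k2 = jax.random.split(key)
    # Assumed settings; the page does not state the scale, mode, or distribution.
    kernel_init = jax.nn.initializers.variance_scaling(1.0, "fan_avg", "uniform")
    bias_init = jax.nn.initializers.zeros  # ignores its key argument
    return {
        "w1": kernel_init(k1, (d_model, hidden_size), jnp.float32),
        "b1": bias_init(k1, (hidden_size,), jnp.float32),
        "w2": kernel_init(k2, (hidden_size, d_model), jnp.float32),
        "b2": bias_init(k2, (d_model,), jnp.float32),
    }


def mlp_apply(params, x):
    """Dense -> GELU -> Dense, applied over the last axis of x."""
    h = jax.nn.gelu(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


params = init_mlp_params(jax.random.key(0), d_model=8, hidden_size=32)
x = jnp.ones((2, 5, 8))  # (batch, tokens, features)
y = mlp_apply(params, x)  # same shape as x
```

Because the second projection maps back to the input width, the output can be added to a residual stream without reshaping, which is the usual reason for this layout in Transformer blocks.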

hidden_size: int | None = None
activation_fn(
x: ArrayLike,
approximate: bool = True,
) -> jax.Array

Gaussian error linear unit activation function.

If approximate=False, computes the element-wise function:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left( \frac{-x}{\sqrt{2}} \right) \right)\]

If approximate=True, uses the approximate formulation of GELU:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]

For more information, see Gaussian Error Linear Units (GELUs), section 2.

Parameters:
  • x – input array

  • approximate – whether to use the approximate or exact formulation.
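As a quick numerical check of the two formulations above, the sketch below evaluates both with Python's standard math module only. The function names gelu_exact and gelu_tanh are illustrative and are not part of JAX or Kauldron.

```python
import math


def gelu_exact(x: float) -> float:
    """Exact GELU: (x / 2) * erfc(-x / sqrt(2))."""
    return 0.5 * x * math.erfc(-x / math.sqrt(2.0))


def gelu_tanh(x: float) -> float:
    """Tanh approximation: (x / 2) * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 x^3)))."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Largest absolute gap between the two forms on a grid over [-5, 5].
grid = [i / 10 for i in range(-50, 51)]
max_gap = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in grid)
```

On this grid the two forms stay close, which is why the approximate version is commonly used as the default.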

kernel_init(
key: jax.Array,
shape: collections.abc.Sequence[int | Any],
dtype: Any | None = None,
out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None,
) -> jax.Array

bias_init(
key: jax.Array,
shape: collections.abc.Sequence[int | Any],
dtype: Any | None = None,
out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None,
) -> jax.Array

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None