kd.nn.TransformerMLP
- class kauldron.modules.TransformerMLP(hidden_size: int | None = None, activation_fn: typing.Callable[[jaxtyping.Float[Array, '*any'] | jaxtyping.Float[ndarray, '*any']], jaxtyping.Float[Array, '*any'] | jaxtyping.Float[ndarray, '*any']] = <function gelu>, kernel_init: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = <function variance_scaling.<locals>.init>, bias_init: typing.Callable[[jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array, kauldron.typing.shape_spec.Shape, str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType], kauldron.typing.array_types.Array] = <function zeros>, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: flax.linen.module.Module

Simple MLP with a single hidden layer for use in Transformer blocks.
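The description above only says "a single hidden layer". The NumPy sketch below shows the usual shape of such a block: a dense projection to `hidden_size`, the activation, then a dense projection back to the input width. The helper names (`transformer_mlp_forward`, `gelu_tanh`), the parameter layout, and the choice to project back to the input width are assumptions for illustration; they are not read from the kauldron source.

```python
import numpy as np


def gelu_tanh(x: np.ndarray) -> np.ndarray:
    """Tanh approximation of GELU (the formula used when approximate=True)."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def transformer_mlp_forward(x, w1, b1, w2, b2, activation_fn=gelu_tanh):
    """Hypothetical forward pass: Dense(hidden) -> activation -> Dense(out).

    x:  (..., d_model) input features
    w1: (d_model, hidden_size), b1: (hidden_size,)
    w2: (hidden_size, d_model), b2: (d_model,)
    """
    h = activation_fn(x @ w1 + b1)  # expand to the hidden width
    return h @ w2 + b2  # project back to the input width (assumed)


rng = np.random.default_rng(0)
d_model, hidden_size = 8, 32
x = rng.normal(size=(2, 5, d_model))  # (batch, tokens, features)
w1 = rng.normal(size=(d_model, hidden_size)) / np.sqrt(d_model)
b1 = np.zeros(hidden_size)  # zeros, like the default bias_init
w2 = rng.normal(size=(hidden_size, d_model)) / np.sqrt(hidden_size)
b2 = np.zeros(d_model)

y = transformer_mlp_forward(x, w1, b1, w2, b2)
print(y.shape)  # (2, 5, 8)
```

Keeping `activation_fn` as an argument mirrors the module field of the same name, whose default is GELU.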
- hidden_size: int | None = None
- activation_fn(x, approximate: bool = True)
Gaussian error linear unit activation function.
If approximate=False, computes the element-wise function:
\[\mathrm{gelu}(x) = \frac{x}{2} \, \mathrm{erfc}\left( \frac{-x}{\sqrt{2}} \right)\]
If approximate=True, uses the approximate formulation of GELU:
\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \tanh\left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]
For more information, see Gaussian Error Linear Units (GELUs), section 2.
- Parameters:
x – input array
approximate – whether to use the approximate or exact formulation.
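Both formulas above can be checked numerically with the Python standard library. The sketch evaluates the exact (`erfc`) and tanh-approximate GELU pointwise and reports the largest gap on a grid over [-5, 5]; the helper names `gelu_exact` and `gelu_tanh` are illustrative and are not part of JAX or kauldron.

```python
import math


def gelu_exact(x: float) -> float:
    """approximate=False: x/2 * erfc(-x / sqrt(2)), i.e. x times the standard normal CDF."""
    return 0.5 * x * math.erfc(-x / math.sqrt(2.0))


def gelu_tanh(x: float) -> float:
    """approximate=True: x/2 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Compare the two formulations on a grid of points in [-5, 5].
grid = [i / 10 for i in range(-50, 51)]
max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in grid)
print(f"max |exact - tanh| on [-5, 5]: {max_err:.1e}")
```

On this grid the gap stays well below 1e-3, which is why the tanh form is commonly used in place of the exact one; `jax.nn.gelu` selects between them through its `approximate` argument.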
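- kernel_init(key, shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>, out_sharding=None) -> jax.Array
Initializer for the kernel weights. The default is the init function returned by variance_scaling (see the class signature).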
- bias_init(key, shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>) -> jax.Array
An initializer that returns a constant array full of zeros.
The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
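The two default initializers share the calling convention `(key, shape, dtype) -> array`. The NumPy sketch below imitates that convention, with a NumPy `Generator` standing in for the JAX PRNG key. The `variance_scaling` arguments kauldron actually uses (scale, mode, distribution) are not shown on this page, so the sketch assumes `scale=1.0`, fan-in mode, and an untruncated normal distribution; both function names are hypothetical.

```python
import numpy as np


def variance_scaling_normal(key, shape, dtype=np.float32, scale=1.0):
    """Hypothetical fan-in variance-scaling init: samples N(0, scale / fan_in).

    `key` is a numpy Generator standing in for a JAX PRNG key.
    Only handles 2-D kernels of shape (fan_in, fan_out), so fan_in = shape[-2].
    """
    fan_in = shape[-2]
    stddev = np.sqrt(scale / fan_in)
    return (stddev * key.standard_normal(shape)).astype(dtype)


def zeros(key, shape, dtype=np.float32):
    """Mimics the default bias_init: the key is ignored and zeros are returned."""
    del key  # unused, as documented for jax.nn.initializers.zeros
    return np.zeros(shape, dtype)


rng = np.random.default_rng(42)
kernel = variance_scaling_normal(rng, (512, 2048))
bias = zeros(rng, (2048,))
print(kernel.shape, kernel.dtype, round(float(kernel.std()), 4))
print(bias.shape, bias.dtype)
```

Per the type hints in the class signature, any callable with this `(key, shape, dtype)` shape is what the `kernel_init` and `bias_init` fields accept.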
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: Scope | None = None