normalize

Layer normalization utilities.

prxteinmpnn.utils.normalize

prxteinmpnn.utils.normalize.layer_normalization(x, layer_parameters, axis=-1, eps=1e-05)

Apply layer normalization to an input tensor in a functional manner.

Parameters:
  • x (Array) – The input tensor.

  • layer_parameters (PyTree[str, 'P']) – The layer parameters containing ‘scale’ and ‘offset’.

  • axis (int | Sequence[int] | None) – The axis or axes to normalize over. Defaults to the last axis.

  • eps (float) – A small epsilon value to prevent division by zero.

Return type:

Array

Returns:

The normalized tensor.
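
The parameter list above can be illustrated with a minimal sketch. It is not the library's implementation: `layer_norm_sketch` is a hypothetical name, the parameters are assumed to be a plain dict keyed by 'scale' and 'offset', and eps is assumed to be added to the variance before the square root, which is a common convention but is not confirmed by this page.

```python
import jax.numpy as jnp


def layer_norm_sketch(x, layer_parameters, axis=-1, eps=1e-5):
    """Hypothetical reference version; not the library's implementation."""
    # Statistics are taken over `axis` and kept broadcastable against x.
    mean = jnp.mean(x, axis=axis, keepdims=True)
    variance = jnp.var(x, axis=axis, keepdims=True)
    # Assumption: eps is added to the variance before the square root.
    normalized = (x - mean) / jnp.sqrt(variance + eps)
    # 'scale' and 'offset' are the two keys named in the parameter list.
    return normalized * layer_parameters["scale"] + layer_parameters["offset"]


x = jnp.array([[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]])
params = {"scale": jnp.ones(4), "offset": jnp.zeros(4)}
y = layer_norm_sketch(x, params)
```

With a unit scale and zero offset, each row of `y` should have approximately zero mean and unit standard deviation along the last axis, regardless of the row's original magnitude.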

prxteinmpnn.utils.normalize.normalize(x, scale, offset, axis=-1, eps=1e-05)

Apply layer normalization to an input tensor in a functional manner, taking the scale and offset as explicit arrays.

Parameters:
  • x (Array) – The input tensor.

  • scale (Float[Array, 'C']) – The learnable ‘gamma’ scaling factor.

  • offset (Float[Array, 'C']) – The learnable ‘beta’ offset factor.

  • axis (int | Sequence[int] | None) – The axis or axes to normalize over. Defaults to the last axis.

  • eps (float) – A small epsilon value to prevent division by zero.

Return type:

Array

Returns:

The normalized tensor.
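
As a rough illustration of the `axis` and `eps` arguments, the sketch below uses a hypothetical stand-in, `normalize_sketch`, with the same argument order as the signature above. The formula and the placement of eps inside the square root are assumptions, not taken from the library source.

```python
import jax.numpy as jnp


def normalize_sketch(x, scale, offset, axis=-1, eps=1e-5):
    """Hypothetical stand-in; the real function may differ in detail."""
    mean = jnp.mean(x, axis=axis, keepdims=True)
    variance = jnp.var(x, axis=axis, keepdims=True)
    # Assumption: eps sits inside the square root, so zero variance stays finite.
    return (x - mean) / jnp.sqrt(variance + eps) * scale + offset


scale = jnp.full((3,), 2.0)
offset = jnp.array([0.5, -0.5, 1.0])

# A constant input has zero variance; eps keeps the result finite,
# so only the offset remains.
constant = jnp.full((4, 3), 7.0)
flat = normalize_sketch(constant, scale, offset)

# A sequence of axes pools the statistics over all listed axes at once.
x = jnp.arange(24.0).reshape(2, 4, 3)
pooled = normalize_sketch(x, scale, offset, axis=(1, 2))
```

In this sketch `scale` and `offset` have shape `(C,)` and broadcast against the last axis of `x`, which matches the `Float[Array, 'C']` annotation in the parameter list.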