"""Layer normalization utilities.

prxteinmpnn.utils.normalize
"""

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

import jax
from jax import numpy as jnp
from jaxtyping import Array, Float

if TYPE_CHECKING:
  from collections.abc import Sequence

  from prxteinmpnn.utils.types import ModelParameters

STANDARD_EPSILON = 1e-5
ScaleConstant = Float[Array, "C"]  # Scale parameter for normalization
OffsetConstant = Float[Array, "C"]  # Offset parameter for normalization


@partial(jax.jit, static_argnames=("axis", "eps"))
def layer_normalization(
  x: Array,
  layer_parameters: ModelParameters,
  axis: int | Sequence[int] | None = -1,
  eps: float = STANDARD_EPSILON,
) -> Array:
  """Apply layer normalization to an input tensor in a functional manner.

  Args:
    x: The input tensor.
    layer_parameters: The layer parameters containing 'scale' and 'offset'.
    axis: The axis or axes to normalize over. Defaults to the last axis.
    eps: A small epsilon value to prevent division by zero.

  Returns:
    The normalized tensor.

  """
  scale = layer_parameters["scale"]
  offset = layer_parameters["offset"]
  return normalize(
    x,
    scale,
    offset,
    axis=axis,
    eps=eps,
  )
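
# Usage sketch for `layer_normalization` (illustrative only; the 128-feature
# shape and the all-ones/all-zeros parameter values are assumptions, not
# taken from the library):
#
#   params = {"scale": jnp.ones((128,)), "offset": jnp.zeros((128,))}
#   x = jax.random.normal(jax.random.PRNGKey(0), (4, 128))
#   y = layer_normalization(x, params)  # each row: ~zero mean, ~unit variance
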
@partial(jax.jit, static_argnames=("axis", "eps"))
def normalize(
  x: Array,
  scale: ScaleConstant,
  offset: OffsetConstant,
  axis: int | Sequence[int] | None = -1,
  eps: float = STANDARD_EPSILON,
) -> Array:
  """Apply layer normalization to an input tensor in a functional manner.

  Args:
    x: The input tensor.
    scale: The learnable 'gamma' scaling factor.
    offset: The learnable 'beta' offset factor.
    axis: The axis or axes to normalize over. Defaults to the last axis.
    eps: A small epsilon value to prevent division by zero.

  Returns:
    The normalized tensor.

  """
  mean = jnp.mean(x, axis=axis, keepdims=True)
  variance = jnp.var(x, axis=axis, keepdims=True)
  x_normalized = (x - mean) / jnp.sqrt(variance + eps)
  return x_normalized * scale + offset
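

# Minimal self-check sketch for `normalize` (illustrative only; the input
# shape and identity parameters below are assumptions, not part of the
# library). With scale=1 and offset=0, the output should have approximately
# zero mean and unit variance along the normalized axis.
if __name__ == "__main__":
  x = jax.random.normal(jax.random.PRNGKey(0), (4, 128))
  scale = jnp.ones((128,))
  offset = jnp.zeros((128,))
  y = normalize(x, scale, offset)
  print(jnp.mean(y, axis=-1))  # expected: values near 0
  print(jnp.var(y, axis=-1))  # expected: values near 1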