# Source code for prxteinmpnn.utils.radial_basis

"""Radial basis functions for distance encoding.

prxteinmpnn.utils.radial_basis
"""

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from .types import AtomIndexPair, BackboneCoordinates, NeighborIndices

# jaxtyping alias: floating-point array annotated with four named axes "L K R D".
# NOTE(review): compute_radial_basis is annotated as returning this alias, yet it
# reshapes its result to three axes — confirm which shape is intended.
AllAtomRBF = Float[Array, "L K R D"]
# jaxtyping alias: floating-point array annotated with three named axes "K R D".
AtomPairRBF = Float[Array, "K R D"]

# Number of Gaussian basis functions used to encode a single distance.
RADIAL_BASES = 16
# Smallest and largest basis center (distance units presumably Angstrom — confirm).
RADIAL_BASE_MINIMUM = 2.0
RADIAL_BASE_MAXIMUM = 22.0
# Basis centers, evenly spaced over [RADIAL_BASE_MINIMUM, RADIAL_BASE_MAXIMUM].
RBF_CENTERS = jnp.linspace(
  RADIAL_BASE_MINIMUM,
  RADIAL_BASE_MAXIMUM,
  num=RADIAL_BASES,
)
# Width shared by every basis function: covered range divided by the basis count.
RBF_SIGMA = (RADIAL_BASE_MAXIMUM - RADIAL_BASE_MINIMUM) / RADIAL_BASES

# Ordered (first atom index, second atom index) pairs, one per row; each index selects
# along the atom axis of the backbone coordinates.
# Presumably 0=N, 1=CA, 2=C, 3=O, 4=CB — confirm against the coordinate layout.
# compute_radial_basis maps over these rows in order, so the row order determines the
# layout of its flattened feature axis; do not reorder.
BACKBONE_PAIRS = jnp.array(
  (
    # Rows 1-5: each atom paired with itself.
    (1, 1), (0, 0), (2, 2), (3, 3), (4, 4),
    # Rows 6-15: the ten distinct cross-atom pairs.
    (1, 0), (1, 2), (1, 3), (1, 4),
    (0, 2), (0, 3), (0, 4),
    (4, 2), (4, 3), (3, 2),
    # Rows 16-25: rows 6-15 with the two indices swapped.
    (0, 1), (2, 1), (3, 1), (4, 1),
    (2, 0), (3, 0), (4, 0),
    (2, 4), (3, 4), (2, 3),
  ),
)

# Added to the squared distance before the square root in compute_radial_basis.
DISTANCE_EPSILON = 1e-6


@jax.jit
def compute_radial_basis(
  backbone_coordinates: BackboneCoordinates,
  neighbor_indices: NeighborIndices,
) -> AllAtomRBF:
  """Compute Gaussian radial basis features for neighboring backbone atom pairs.

  For every atom-index pair in ``BACKBONE_PAIRS`` the all-against-all distance matrix
  between the two selected atoms is built, restricted to each residue's listed
  neighbors, and expanded with Gaussians centered at ``RBF_CENTERS`` with width
  ``RBF_SIGMA``.

  Args:
    backbone_coordinates: Coordinates indexed as ``[residue, atom, xyz]``.
    neighbor_indices: Integer array indexed as ``[residue, neighbor]`` whose entries
      are residue indices used to gather columns of the distance matrix.

  Returns:
    Array of shape ``(residues, neighbors, pairs * bases)``: the per-pair basis
    expansions laid out pair-major along the last axis, in ``BACKBONE_PAIRS`` order.

  NOTE(review): the ``AllAtomRBF`` annotation names four axes, but the final reshape
  yields three — confirm which is intended before enabling runtime shape checking.

  """

  def _rbf(pair: AtomIndexPair) -> AtomPairRBF:
    """Expand neighbor distances for one atom-index pair in the radial basis."""
    first_atom = backbone_coordinates[:, pair[0], :]
    second_atom = backbone_coordinates[:, pair[1], :]
    # Broadcast to an all-against-all (residue, residue, xyz) difference tensor.
    offsets = first_atom[:, None, :] - second_atom[None, :, :]
    squared_distance = jnp.sum(jnp.square(offsets), axis=-1)
    # The epsilon keeps the square root differentiable where the distance is zero
    # (e.g. an atom paired with itself on the same residue).
    distance = jnp.sqrt(DISTANCE_EPSILON + squared_distance)
    # Keep only the columns belonging to each residue's neighbors.
    neighbor_distances = jnp.take_along_axis(distance, neighbor_indices, axis=1)
    scaled_offsets = (neighbor_distances[..., None] - RBF_CENTERS) / RBF_SIGMA
    return jnp.exp(-jnp.square(scaled_offsets))

  # vmap stacks the per-pair results on a new leading axis: (pair, residue, neighbor, basis).
  per_pair = jax.vmap(_rbf)(BACKBONE_PAIRS)
  # Move the pair axis next to the basis axis, then flatten the two into one feature axis.
  residue_major = per_pair.transpose((1, 2, 0, 3))
  return residue_major.reshape(
    backbone_coordinates.shape[0],
    neighbor_indices.shape[1],
    -1,
  )