# Source code for prxteinmpnn.utils.graph

"""Graph utilities for the PrxteinMPNN model.

prxteinmpnn.utils.graph
"""

import jax
import jax.numpy as jnp
from jaxtyping import Array, Int

from prxteinmpnn.utils.types import (
  AtomResidueIndex,
  NeighborIndices,
)

# Shape-annotated integer array type for per-neighbor offsets: one row per
# residue and one column per neighbor (jaxtyping dimension names below).
NeighborOffsets = Int[Array, "num_residues num_neighbors"]


@jax.jit
def compute_neighbor_offsets(
  residue_indices: AtomResidueIndex,
  neighbor_indices: NeighborIndices,
) -> NeighborOffsets:
  """Compute residue-index offsets between each node and its neighbors.

  For every row ``i`` and neighbor slot ``k`` the result holds::

    residue_indices[i] - residue_indices[neighbor_indices[i, k]]

  An offset is positive when the neighbor has a smaller residue index than
  the node itself and negative when it has a larger one. It is zero when
  both share the same residue index, e.g. a node listed as its own neighbor.

  Args:
    residue_indices: 1-D integer array giving the residue index of each
      node. Row ``i`` of ``neighbor_indices`` refers to entry ``i``.
    neighbor_indices: Integer array of shape ``(num_nodes, num_neighbors)``.
      Entry ``[i, k]`` is the position in ``residue_indices`` of the
      ``k``-th neighbor of node ``i``. Entries are assumed to be valid
      (in-range) positions.

  Returns:
    Array with the same shape as ``neighbor_indices`` and the dtype of
    ``residue_indices``. Each entry is the residue-index difference between
    a node and one of its neighbors.

  """
  # Gather only the residue indices that are actually needed, which is
  # O(N * K). The previous approach built the full N x N pairwise-offset
  # matrix and then selected K columns per row. Values are identical for
  # in-range neighbor indices.
  # NOTE: JAX clamps out-of-range gather indices, so callers must pass
  # valid neighbor positions.
  neighbor_residue_indices = residue_indices[neighbor_indices]
  return residue_indices[:, None] - neighbor_residue_indices