# Source code for prxteinmpnn.sampling.initialize

"""Sample sequences from a structure using the ProteinMPNN model.

prxteinmpnn.sampling.initialize
"""

from collections.abc import Callable
from functools import partial

import jax
import jax.numpy as jnp
from jaxtyping import PRNGKeyArray

from prxteinmpnn.model.features import extract_features, project_features
from prxteinmpnn.utils.autoregression import generate_ar_mask
from prxteinmpnn.utils.decoding_order import DecodingOrderFn
from prxteinmpnn.utils.types import (
  AtomMask,
  AutoRegressiveMask,
  ChainIndex,
  DecodingOrder,
  EdgeFeatures,
  ModelParameters,
  NeighborIndices,
  NodeFeatures,
  ResidueIndex,
  StructureAtomicCoordinates,
)

# Positional argument types of the jitted sampling pass, listed in the same
# order as the parameters of `sample_model_pass` (the closure built by
# `sampling_encode`): the model parameters come directly after the PRNG key.
SamplingModelPassInput = tuple[
  PRNGKeyArray,  # prng_key
  ModelParameters,  # model_parameters
  StructureAtomicCoordinates,  # structure_coordinates
  AtomMask,  # mask
  ResidueIndex,  # residue_index
  ChainIndex,  # chain_index
  int,  # k_neighbors
  float,  # augment_eps
]

# Return type of the jitted sampling pass; the element order mirrors the tuple
# that `sample_model_pass` returns.
SamplingModelPassOutput = tuple[
  NodeFeatures,  # node_features returned by the encoder
  EdgeFeatures,  # edge_features returned by the encoder
  NeighborIndices,  # neighbor_indices returned by extract_features
  DecodingOrder,  # decoding_order returned by decoding_order_fn
  AutoRegressiveMask,  # autoregressive_mask built by generate_ar_mask
  PRNGKeyArray,  # PRNG key handed back by extract_features
]

# Signature of the function produced by `sampling_encode`: the positional
# parameter types are the elements of SamplingModelPassInput (star-unpacked
# into the Callable's argument list) and the result is SamplingModelPassOutput.
SamplingModelPassFn = Callable[
  [*SamplingModelPassInput],
  SamplingModelPassOutput,
]


def sampling_encode(
  encoder: Callable[..., tuple[NodeFeatures, EdgeFeatures]],
  decoding_order_fn: DecodingOrderFn,
) -> SamplingModelPassFn:
  """Create a jitted function that runs the encoder pass preceding sampling.

  Args:
    encoder: Callable invoked as
      ``encoder(edge_features, neighbor_indices, mask, attention_mask)`` and
      returning ``(node_features, edge_features)``.
    decoding_order_fn: Callable invoked as
      ``decoding_order_fn(prng_key, num_positions)`` and returning
      ``(decoding_order, next_prng_key)``.

  Returns:
    The ``sample_model_pass`` closure. It is wrapped in ``jax.jit`` with
    ``k_neighbors`` and ``augment_eps`` declared static, so those two
    arguments must be hashable and each new value triggers a re-trace.

  """

  @partial(jax.jit, static_argnames=("k_neighbors", "augment_eps"))
  def sample_model_pass(
    prng_key: PRNGKeyArray,
    model_parameters: ModelParameters,
    structure_coordinates: StructureAtomicCoordinates,
    mask: AtomMask,
    residue_index: ResidueIndex,
    chain_index: ChainIndex,
    k_neighbors: int = 48,
    augment_eps: float = 0.0,
  ) -> SamplingModelPassOutput:
    """Run a single pass through the encoder to prepare for sampling.

    Only the encoder is evaluated here; no decoding step is performed.

    Args:
      prng_key: Key passed to ``decoding_order_fn``; the key it returns is
        then passed on to ``extract_features``.
      model_parameters: Parameters forwarded to ``extract_features`` and
        ``project_features``.
      structure_coordinates: Coordinates forwarded to ``extract_features``;
        the size of its leading axis is the length given to
        ``decoding_order_fn``.
      mask: Mask forwarded to ``extract_features`` and ``encoder``, and used
        to build the attention mask.
      residue_index: Forwarded unchanged to ``extract_features``.
      chain_index: Forwarded unchanged to ``extract_features``.
      k_neighbors: Forwarded to ``extract_features`` (static under jit).
      augment_eps: Forwarded to ``extract_features`` (static under jit).

    Returns:
      ``(node_features, edge_features, neighbor_indices, decoding_order,
      autoregressive_mask, next_rng_key)`` where ``next_rng_key`` is the key
      returned by ``extract_features``.

    """
    num_positions = structure_coordinates.shape[0]
    decoding_order, order_rng_key = decoding_order_fn(prng_key, num_positions)
    autoregressive_mask = generate_ar_mask(decoding_order)

    # Thread the key returned by the decoding-order step into feature
    # extraction; the key it hands back is what the caller receives.
    raw_edge_features, neighbor_indices, next_rng_key = extract_features(
      order_rng_key,
      model_parameters,
      structure_coordinates,
      mask,
      residue_index,
      chain_index,
      k_neighbors=k_neighbors,
      augment_eps=augment_eps,
    )
    projected_edge_features = project_features(model_parameters, raw_edge_features)

    # Outer product of the mask with itself, then gathered along axis 1 with
    # the neighbor indices so each kept entry is mask[i] * mask[neighbor].
    # NOTE(review): assumes `mask` is 1-D (one value per position) - confirm.
    pairwise_mask = mask[:, None] * mask[None, :]
    attention_mask = jnp.take_along_axis(pairwise_mask, neighbor_indices, axis=1)

    node_features, encoded_edge_features = encoder(
      projected_edge_features,
      neighbor_indices,
      mask,
      attention_mask,
    )
    return (
      node_features,
      encoded_edge_features,
      neighbor_indices,
      decoding_order,
      autoregressive_mask,
      next_rng_key,
    )

  return sample_model_pass