"""Score a given sequence on a structure using the ProteinMPNN model."""
from collections.abc import Callable
from functools import partial
import jax
from jaxtyping import Float, PRNGKeyArray
from prxteinmpnn.model import PrxteinMPNN
from prxteinmpnn.utils.autoregression import generate_ar_mask
from prxteinmpnn.utils.decoding_order import DecodingOrderFn, random_decoding_order
from prxteinmpnn.utils.types import (
AlphaCarbonMask,
AutoRegressiveMask,
BackboneNoise,
ChainIndex,
DecodingOrder,
Logits,
OneHotProteinSequence,
ProteinSequence,
ResidueIndex,
StructureAtomicCoordinates,
)
# Signature of the scoring callable returned by `make_score_sequence`.
# The positional slots line up with the parameters of the nested `score_sequence`.
ScoringFn = Callable[
[
PRNGKeyArray,  # prng_key
ProteinSequence,  # sequence
StructureAtomicCoordinates,  # structure_coordinates
AlphaCarbonMask,  # mask
ResidueIndex,  # residue_index
ChainIndex,  # chain_index
int,  # _k_neighbors
BackboneNoise | None,  # backbone_noise
AutoRegressiveMask | None,  # ar_mask
],
# Returned as (masked average score, logits, decoding order).
tuple[Float, Logits, DecodingOrder],
]
# Added to the mask sum used as the score denominator, so the division stays
# defined when the mask sums to zero.
SCORE_EPS = 1e-8
def make_score_sequence(
    model: PrxteinMPNN,
    decoding_order_fn: DecodingOrderFn = random_decoding_order,
    _num_encoder_layers: int = 3,
    _num_decoder_layers: int = 3,
) -> ScoringFn:
    """Create a function to score a sequence on a structure using PrxteinMPNN.

    The returned function is JIT-compiled and closes over ``model`` and
    ``decoding_order_fn``.

    Args:
        model: A PrxteinMPNN Equinox model instance.
        decoding_order_fn: Function to generate decoding order (default: random).
        _num_encoder_layers: Deprecated, ignored (kept for API compatibility).
        _num_decoder_layers: Deprecated, ignored (kept for API compatibility).

    Returns:
        A function that scores sequences on structures.

    Example:
        >>> from prxteinmpnn.io.weights import load_model
        >>> model = load_model()
        >>> score_fn = make_score_sequence(model)
        >>> score, logits, order = score_fn(key, seq, coords, mask, res_idx, chain_idx)

    """
    # Integer sequences are one-hot encoded over 21 classes, but only the first
    # 20 classes contribute to the score.
    # NOTE(review): the 21st class is presumably an unknown/non-canonical token — confirm.
    num_classes = 21
    num_scored_classes = 20

    @partial(jax.jit, static_argnames=("_k_neighbors",))
    def score_sequence(
        prng_key: PRNGKeyArray,
        sequence: ProteinSequence | OneHotProteinSequence,
        structure_coordinates: StructureAtomicCoordinates,
        mask: AlphaCarbonMask,
        residue_index: ResidueIndex,
        chain_index: ChainIndex,
        _k_neighbors: int = 48,
        backbone_noise: BackboneNoise | None = None,
        ar_mask: AutoRegressiveMask | None = None,
    ) -> tuple[Float, Logits, DecodingOrder]:
        """Score a sequence on a structure using the ProteinMPNN model.

        The score is the negative log-probability of ``sequence`` under the
        model's conditional logits, restricted to the first 20 classes and
        averaged over the positions selected by ``mask``.

        Args:
            prng_key: JAX random key.
            sequence: Protein sequence (integer indices or one-hot).
            structure_coordinates: Atomic coordinates (N, 4, 3).
            mask: Alpha carbon mask indicating valid residues.
            residue_index: Residue indices.
            chain_index: Chain indices.
            _k_neighbors: Deprecated, model handles internally (kept for API compatibility).
            backbone_noise: Optional noise for backbone coordinates.
            ar_mask: Optional custom autoregressive mask. When given, it is used
                instead of the mask derived from the generated decoding order.

        Returns:
            Tuple of (average score, logits, decoding order).

        Example:
            >>> score, logits, order = score_sequence(
            ...     key, seq, coords, mask, res_idx, chain_idx
            ... )

        """
        # The decoding order is always generated (and returned), even when a
        # custom ``ar_mask`` overrides the mask derived from it.
        decoding_order, prng_key = decoding_order_fn(prng_key, sequence.shape[0], None, None)
        autoregressive_mask = generate_ar_mask(decoding_order) if ar_mask is None else ar_mask

        # ``ndim`` is static under jit, so this Python branch is resolved at trace time.
        if sequence.ndim == 1:
            sequence = jax.nn.one_hot(sequence, num_classes=num_classes)

        # Run model in conditional mode (scoring a given sequence).
        _, logits = model(
            structure_coordinates,
            mask,
            residue_index,
            chain_index,
            decoding_approach="conditional",
            prng_key=prng_key,
            ar_mask=autoregressive_mask,
            one_hot_sequence=sequence,
            temperature=0.0,  # Not used in conditional mode
            bias=None,  # No bias in scoring
            backbone_noise=backbone_noise,
        )

        # Normalise over every class first, then keep only the scored classes.
        log_probability = jax.nn.log_softmax(logits, axis=-1)[..., :num_scored_classes]
        per_residue_score = -(sequence[..., :num_scored_classes] * log_probability).sum(-1)

        # Reduce numerator and denominator over the same (last) axis so the
        # normalisation always matches the positions that were summed.
        masked_score_sum = (per_residue_score * mask).sum(-1)
        mask_sum = mask.sum(-1) + SCORE_EPS
        return masked_score_sum / mask_sum, logits, decoding_order

    return score_sequence