"""Score a given sequence on a structure using the ProteinMPNN model."""
from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING, cast
import jax
import jax.numpy as jnp
from jaxtyping import Float, PRNGKeyArray
if TYPE_CHECKING:
from prxteinmpnn.model.decoding_signatures import RunConditionalDecoderFn
from prxteinmpnn.model.decoder import make_decoder
from prxteinmpnn.model.encoder import make_encoder
from prxteinmpnn.model.features import extract_features, project_features
from prxteinmpnn.model.projection import final_projection
from prxteinmpnn.utils.autoregression import generate_ar_mask
from prxteinmpnn.utils.decoding_order import DecodingOrderFn, random_decoding_order
from prxteinmpnn.utils.residue_constants import atom_order
from prxteinmpnn.utils.types import (
AtomMask,
AutoRegressiveMask,
BackboneNoise,
ChainIndex,
DecodingOrder,
Logits,
ModelParameters,
OneHotProteinSequence,
ProteinSequence,
ResidueIndex,
StructureAtomicCoordinates,
)
# Call signature of the closure produced by ``make_score_sequence``.
# Positional arguments, in order: PRNG key, sequence, atomic coordinates,
# atom mask, residue indices, chain indices, neighbour count, optional
# backbone noise, optional autoregressive-mask override.  The result is the
# masked score, the logits, and the decoding order that was drawn.
ScoringFn = Callable[
    [
        PRNGKeyArray,  # random key
        ProteinSequence,  # sequence being scored
        StructureAtomicCoordinates,  # atomic coordinates of the structure
        AtomMask,  # atom presence mask
        ResidueIndex,  # residue indices
        ChainIndex,  # chain indices
        int,  # k_neighbors
        BackboneNoise | None,  # optional backbone noise
        AutoRegressiveMask | None,  # optional autoregressive-mask override
    ],
    tuple[Float, Logits, DecodingOrder],
]

# Added to the mask sum in the score's denominator so that an all-zero
# residue mask does not cause a division by zero.
SCORE_EPS = 1.0e-8
def make_score_sequence(
    model_parameters: ModelParameters,
    decoding_order_fn: DecodingOrderFn = random_decoding_order,
    num_encoder_layers: int = 3,
    num_decoder_layers: int = 3,
) -> ScoringFn:
    """Create a function to score a sequence on a structure.

    The encoder and the conditional decoder are built once here and captured
    by the returned closure, so repeated scoring calls reuse them.

    Args:
      model_parameters: ProteinMPNN weights used by feature extraction, the
        encoder, the decoder and the final projection.
      decoding_order_fn: Called as ``decoding_order_fn(prng_key, num_residues)``
        and expected to return ``(decoding_order, prng_key)``.
      num_encoder_layers: Number of encoder layers passed to ``make_encoder``.
      num_decoder_layers: Number of decoder layers passed to ``make_decoder``.

    Returns:
      ``score_sequence``, wrapped in ``jax.jit``.  ``k_neighbors`` is a static
      argument, so every distinct value triggers a fresh compilation.

    """
    encoder = make_encoder(
        model_parameters=model_parameters,
        attention_mask_type="cross",
        num_encoder_layers=num_encoder_layers,
    )
    # ``cast`` only informs the type checker which decoder signature is in
    # use; it has no runtime effect.
    decoder: RunConditionalDecoderFn = cast(
        "RunConditionalDecoderFn",
        make_decoder(
            model_parameters=model_parameters,
            attention_mask_type=None,
            decoding_approach="conditional",
            num_decoder_layers=num_decoder_layers,
        ),
    )

    @partial(jax.jit, static_argnames=("k_neighbors",))
    def score_sequence(
        prng_key: PRNGKeyArray,
        sequence: ProteinSequence | OneHotProteinSequence,
        structure_coordinates: StructureAtomicCoordinates,
        mask: AtomMask,
        residue_index: ResidueIndex,
        chain_index: ChainIndex,
        k_neighbors: int = 48,
        backbone_noise: BackboneNoise | None = None,
        ar_mask: AutoRegressiveMask | None = None,
    ) -> tuple[Float, Logits, DecodingOrder]:
        """Score a sequence on a structure using the ProteinMPNN model.

        Args:
          prng_key: JAX PRNG key.  It is first consumed by ``decoding_order_fn``
            and the key it returns is forwarded to ``extract_features``.
          sequence: Integer-encoded (1-D) or already one-hot encoded sequence.
            1-D input is one-hot encoded with 21 classes.
          structure_coordinates: Atomic coordinates of the structure.
          mask: Atom mask.  Its CA-atom column is used as the per-residue mask.
          residue_index: Residue indices, forwarded to ``extract_features``.
          chain_index: Chain indices, forwarded to ``extract_features``.
          k_neighbors: Neighbour count forwarded to ``extract_features``.
            It is static under ``jax.jit``.
          backbone_noise: Optional value forwarded to ``extract_features``.
          ar_mask: Optional autoregressive mask.  When given, it is used
            instead of the mask generated from the drawn decoding order.

        Returns:
          A tuple ``(score, logits, decoding_order)``:

          - ``score`` is the per-residue negative log-likelihood of
            ``sequence`` over its first 20 classes, summed with the CA
            residue mask as weights and divided by
            ``residue_mask.sum() + SCORE_EPS``.
          - ``logits`` is the raw output of ``final_projection``.
          - ``decoding_order`` is the order drawn by ``decoding_order_fn``.
            It is returned even when ``ar_mask`` overrides the mask derived
            from it.

        """
        # Normalise to one-hot once, up front: both the decoder input and the
        # score below consume the one-hot form.  ``ndim`` is static under
        # tracing, so this branch is resolved at compile time.
        if sequence.ndim == 1:
            sequence = jax.nn.one_hot(sequence, num_classes=21)

        decoding_order, prng_key = decoding_order_fn(prng_key, sequence.shape[0])
        autoregressive_mask = (
            generate_ar_mask(decoding_order) if ar_mask is None else ar_mask
        )

        residue_mask = mask[:, atom_order["CA"]]
        # The key returned by feature extraction is not needed afterwards.
        edge_features, neighbor_indices, _ = extract_features(
            prng_key,
            model_parameters,
            structure_coordinates,
            residue_mask,
            residue_index,
            chain_index,
            k_neighbors=k_neighbors,
            backbone_noise=backbone_noise,
        )
        edge_features = project_features(model_parameters, edge_features)

        # Pairwise residue mask, gathered down to each residue's neighbour list.
        attention_mask = jnp.take_along_axis(
            residue_mask[:, None] * residue_mask[None, :],
            neighbor_indices,
            axis=1,
        )
        node_features, edge_features = encoder(
            edge_features,
            neighbor_indices,
            residue_mask,
            attention_mask,
        )
        node_features = decoder(
            node_features,
            edge_features,
            neighbor_indices,
            residue_mask,
            autoregressive_mask,
            sequence,
        )
        logits = final_projection(model_parameters, node_features)

        # Normalise over all logits, then keep only the first 20 classes.
        # The 21st one-hot class (index 20) is excluded from the score;
        # presumably it is the unknown/X token (TODO confirm).
        log_probability = jax.nn.log_softmax(logits, axis=-1)[..., :20]
        per_residue_nll = -(sequence[..., :20] * log_probability).sum(-1)
        masked_nll_sum = (per_residue_nll * residue_mask).sum(-1)
        # SCORE_EPS keeps an all-zero residue mask from dividing by zero.
        mask_sum = residue_mask.sum() + SCORE_EPS
        return masked_nll_sum / mask_sum, logits, decoding_order

    return score_sequence