# Source code for prxteinmpnn.sampling.conditional_logits

"""Factory for creating conditional logits functions.

Conditional logits are computed given a specific sequence input,
allowing the model to evaluate how well a sequence fits a structure.

This is used for:
- Jacobian computation (sensitivity analysis)
- Sequence scoring and validation
- Conformational inference
"""

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

import equinox as eqx
import jax


# Equinox's default hashing would hash module fields; when those fields hold
# JAX arrays (e.g. during tracing) that raises ``TypeError: unhashable type``.
# The shim below swaps in an identity-based hash so modules can serve as keys
# in tracing/cache contexts.
def _eqx_module_hash(self: object) -> int:  # pragma: no cover - safe shim
  """Return a hash derived solely from the object's identity.

  NOTE(review): the assignment below patches ``eqx.Module`` itself, so it
  affects every Equinox module in the process, not only PrxteinMPNN. If
  ``__eq__`` still compares fields, two equal modules may now hash
  differently. Confirm this is acceptable for all callers.
  """
  identity_hash = id(self)
  return identity_hash


# Import-time side effect: installed once when this module is first imported.
eqx.Module.__hash__ = _eqx_module_hash

# ``ConditionalLogitsFn`` has to exist at runtime so other modules (and tests)
# can import the name. Its precise signature, however, refers to typing-only
# imports (jaxtyping and the project's type aliases). The interpreter gets a
# permissive alias and never imports those modules; only static type checkers
# see the detailed alias.
if not TYPE_CHECKING:
  from collections.abc import Callable
  from typing import Any

  # Runtime stand-in: any callable, any return value.
  ConditionalLogitsFn = Callable[..., Any]
else:
  from collections.abc import Callable

  from jaxtyping import PRNGKeyArray

  from prxteinmpnn.model import PrxteinMPNN
  from prxteinmpnn.utils.types import (
    AlphaCarbonMask,
    AutoRegressiveMask,
    BackboneNoise,
    ChainIndex,
    Logits,
    ProteinSequence,
    ResidueIndex,
    StructureAtomicCoordinates,
  )

  # Positional parameters, in call order, of the function produced by
  # ``make_conditional_logits_fn``; the final two may be ``None``.
  ConditionalLogitsFn = Callable[
    [
      PRNGKeyArray,
      StructureAtomicCoordinates,
      AlphaCarbonMask,
      ResidueIndex,
      ChainIndex,
      ProteinSequence,
      AutoRegressiveMask | None,
      BackboneNoise | None,
    ],
    Logits,
  ]


def make_conditional_logits_fn(
  model: PrxteinMPNN,
) -> ConditionalLogitsFn:
  """Create a function to compute conditional logits for a given sequence.

  Conditional logits evaluate how well a sequence fits a structure by running
  the model with the sequence as input.

  Args:
    model: A PrxteinMPNN Equinox model instance. It is captured by closure
      rather than passed to the compiled function as an argument.

  Returns:
    A ``jax.jit``-compiled function that computes conditional logits for
    sequence-structure pairs.

  Example:
    >>> from prxteinmpnn.io.weights import load_model
    >>> model = load_model()
    >>> logits_fn = make_conditional_logits_fn(model)
    >>> logits = logits_fn(key, coords, mask, res_idx, chain_idx, sequence)

  """

  @jax.jit
  def conditional_logits(
    prng_key: PRNGKeyArray,
    structure_coordinates: StructureAtomicCoordinates,
    mask: AlphaCarbonMask,
    residue_index: ResidueIndex,
    chain_index: ChainIndex,
    sequence: ProteinSequence,
    ar_mask: AutoRegressiveMask | None = None,
    backbone_noise: BackboneNoise | None = None,
  ) -> Logits:
    """Compute conditional logits for a sequence-structure pair.

    Args:
      prng_key: JAX random key. Forwarded to ``model.features`` and to the
        model's conditional decoding path.
      structure_coordinates: Atomic coordinates (N, 4, 3).
      mask: Alpha carbon mask indicating valid residues.
      residue_index: Residue indices.
      chain_index: Chain indices.
      sequence: Protein sequence as integer array (N,) or one-hot (N, 21).
      ar_mask: Optional autoregressive mask (N, N). ``None`` is replaced by
        an all-zeros int32 (N, N) mask.
      backbone_noise: Optional noise for backbone coordinates, forwarded
        unchanged (including ``None``) to ``model.features``.

    Returns:
      Logits of shape (N, 21) for each residue position.

    Example:
      >>> logits = conditional_logits(
      ...   key, coords, mask, res_idx, chain_idx, sequence
      ... )

    """
    num_residues = mask.shape[0]

    # Run feature extraction and the model's conditional path explicitly
    # instead of dispatching through ``jax.lax.switch``, which can trigger the
    # other branches under tracing and hit dynamic-indexing issues there.
    edge_features, neighbor_indices, _ = model.features(
      prng_key,
      structure_coordinates,
      mask,
      residue_index,
      chain_index,
      backbone_noise,
    )

    # ``None`` carries no array leaves, so under jax.jit this check is
    # resolved at trace time.
    if ar_mask is None:
      ar_mask = jax.numpy.zeros((num_residues, num_residues), dtype=jax.numpy.int32)

    # Default multi-state parameters for conditional logit computation.
    multi_state_strategy_idx = jax.numpy.array(0, dtype=jax.numpy.int32)  # 0 = "mean"
    multi_state_alpha = 0.5

    # NOTE(review): this (N, 21) zero array fills the argument slot right
    # after the temperature; presumably a per-residue logit bias. Confirm
    # against ``_call_conditional``'s signature.
    zero_bias = jax.numpy.zeros((num_residues, 21), dtype=jax.numpy.float32)

    _, logits = model._call_conditional(  # noqa: SLF001
      edge_features,
      neighbor_indices,
      mask,
      ar_mask,
      sequence,
      prng_key,
      0.0,  # temperature unused in conditional path
      zero_bias,
      None,  # tie_group_map not used in jacobian computation
      multi_state_strategy_idx,
      multi_state_alpha,
    )
    return logits

  return conditional_logits
def make_encoding_conditional_logits_split_fn(
  model: PrxteinMPNN,
) -> tuple[Callable, Callable]:
  """Create separate encoding and decoding functions for averaged encodings.

  This splits the model into two parts:
  1. Encoding: Structure -> Encoder features (node_features, edge_features,
     neighbor_indices)
  2. Decoding: (Encoder features, Sequence) -> Logits

  This separation allows:
  - Averaging encoder features across multiple noise levels
  - Efficient jacobian computation by caching encoder output
  - Reusing encoder output for multiple sequence evaluations

  Neither returned function is wrapped in ``jax.jit`` here; callers may jit
  them if needed.

  Args:
    model: A PrxteinMPNN Equinox model instance.

  Returns:
    Tuple of (encode_fn, decode_fn) where:
    - encode_fn: Computes encoder features from structure
    - decode_fn: Computes logits from cached features and sequence

  Example:
    >>> encode_fn, decode_fn = make_encoding_conditional_logits_split_fn(model)
    >>> # Encode once
    >>> key = jax.random.key(0)
    >>> encoding = encode_fn(
    ...   coords, mask, res_idx, chain_idx, backbone_noise=0.1, prng_key=key
    ... )
    >>> # Decode multiple sequences using same encoding
    >>> logits1 = decode_fn(encoding, sequence1)
    >>> logits2 = decode_fn(encoding, sequence2)

  """

  def encode_fn(
    structure_coordinates: StructureAtomicCoordinates,
    mask: AlphaCarbonMask,
    residue_index: ResidueIndex,
    chain_index: ChainIndex,
    backbone_noise: BackboneNoise | None = None,
    prng_key: PRNGKeyArray | None = None,
  ) -> tuple:
    """Encode a structure into reusable encoder features.

    Args:
      structure_coordinates: Atomic coordinates (N, 4, 3).
      mask: Alpha carbon mask indicating valid residues.
      residue_index: Residue indices.
      chain_index: Chain indices.
      backbone_noise: Optional noise for backbone coordinates. ``None`` is
        replaced by a float32 scalar ``0.0`` before feature extraction.
      prng_key: JAX random key for feature extraction. ``None`` is replaced
        by the fixed key ``jax.random.PRNGKey(0)``, keeping results
        deterministic when no key is supplied.

    Returns:
      Tuple of (node_features, edge_features, neighbor_indices, mask,
      ar_mask_placeholder), where ar_mask_placeholder is an int32 (N, N)
      zeros array kept for shape consistency.

    """
    if backbone_noise is None:
      backbone_noise = jax.numpy.array(0.0, dtype=jax.numpy.float32)
    if prng_key is None:
      # Fixed key keeps behavior deterministic in contexts that don't supply
      # a PRNGKey.
      prng_key = jax.random.PRNGKey(0)

    edge_features, neighbor_indices, _ = model.features(
      prng_key,
      structure_coordinates,
      mask,
      residue_index,
      chain_index,
      backbone_noise,
    )
    node_features, processed_edge_features = model.encoder(
      edge_features,
      neighbor_indices,
      mask,
    )

    # Placeholder autoregressive mask so the encoding tuple always has the
    # same structure.
    num_residues = mask.shape[0]
    ar_mask_placeholder = jax.numpy.zeros((num_residues, num_residues), dtype=jax.numpy.int32)
    return (node_features, processed_edge_features, neighbor_indices, mask, ar_mask_placeholder)

  def decode_fn(
    encoding: tuple,
    sequence: ProteinSequence,
    ar_mask: AutoRegressiveMask | None = None,
  ) -> Logits:
    """Decode cached encoder features to logits for a given sequence.

    Args:
      encoding: Tuple of (node_features, edge_features, neighbor_indices,
        mask, _) as returned by ``encode_fn``. The placeholder mask in the
        last slot is ignored.
      sequence: Protein sequence as integer array (N,) or one-hot (N, 21).
      ar_mask: Optional autoregressive mask (N, N). If None, uses zeros.

    Returns:
      Logits of shape (N, 21) for each residue position.

    """
    node_features, processed_edge_features, neighbor_indices, mask, _ = encoding
    if ar_mask is None:
      ar_mask = jax.numpy.zeros((mask.shape[0], mask.shape[0]), dtype=jax.numpy.int32)

    # ``ndim`` is static, so this branch is also safe under jax.jit.
    if sequence.ndim == 1:
      one_hot_sequence = jax.nn.one_hot(sequence, model.w_s_embed.num_embeddings)
    else:
      one_hot_sequence = sequence

    decoded_node_features = model.decoder.call_conditional(
      node_features,
      processed_edge_features,
      neighbor_indices,
      mask,
      ar_mask,
      one_hot_sequence,
      model.w_s_embed.weight,
    )
    # Project each residue's decoded features to logits.
    return jax.vmap(model.w_out)(decoded_node_features)

  return encode_fn, decode_fn