Source code for prxteinmpnn.model.mpnn

"""Main ProteinMPNN model implementation.

This module contains the top-level PrxteinMPNN model that combines all components.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import equinox as eqx
import jax
import jax.numpy as jnp

from prxteinmpnn.model.decoder import Decoder
from prxteinmpnn.model.encoder import Encoder
from prxteinmpnn.model.features import ProteinFeatures
from prxteinmpnn.model.multi_state_sampling import (
  max_min_over_group_logits,
  min_over_group_logits,
  product_of_probabilities_logits,
)
from prxteinmpnn.utils.concatenate import concatenate_neighbor_nodes
from prxteinmpnn.utils.ste import straight_through_estimator

if TYPE_CHECKING:
  from prxteinmpnn.utils.types import (
    AlphaCarbonMask,
    AutoRegressiveMask,
    BackboneNoise,
    ChainIndex,
    EdgeFeatures,
    Float,
    Int,
    Logits,
    NeighborIndices,
    NodeFeatures,
    OneHotProteinSequence,
    PRNGKeyArray,
    ResidueIndex,
    StructureAtomicCoordinates,
  )

# Define decoding approach type
DecodingApproach = Literal["unconditional", "conditional", "autoregressive"]


[docs] class PrxteinMPNN(eqx.Module): """The complete end-to-end ProteinMPNN model.""" features: ProteinFeatures encoder: Encoder decoder: Decoder # Feature embedding layers w_s_embed: eqx.nn.Embedding # For sequence # Final projection w_out: eqx.nn.Linear # Store dimensions as static metadata node_features_dim: int = eqx.field(static=True) edge_features_dim: int = eqx.field(static=True) num_decoder_layers: int = eqx.field(static=True) def __init__( self, node_features: int, edge_features: int, hidden_features: int, num_encoder_layers: int, num_decoder_layers: int, k_neighbors: int, num_amino_acids: int = 21, vocab_size: int = 21, # for w_s *, key: PRNGKeyArray, ) -> None: """Initialize the complete model. Args: node_features: Dimension of node features (e.g., 128). edge_features: Dimension of edge features (e.g., 128). hidden_features: Dimension of hidden layer in encoder/decoder. num_encoder_layers: Number of encoder layers. num_decoder_layers: Number of decoder layers. k_neighbors: Number of nearest neighbors for graph construction. num_amino_acids: Number of amino acid types (default: 21). vocab_size: Size of sequence vocabulary (default: 21). key: PRNG key for initialization. Returns: None Raises: None Example: >>> key = jax.random.PRNGKey(0) >>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key) """ self.node_features_dim = node_features self.edge_features_dim = edge_features self.num_decoder_layers = num_decoder_layers keys = jax.random.split(key, 5) # 1 for features, 4 for main model self.features = ProteinFeatures( node_features, edge_features, k_neighbors, key=keys[0], ) self.encoder = Encoder( node_features, edge_features, hidden_features, num_encoder_layers, key=keys[1], ) self.decoder = Decoder( node_features, edge_features, hidden_features, num_decoder_layers, key=keys[2], ) self.w_s_embed = eqx.nn.Embedding( num_embeddings=vocab_size, embedding_size=node_features, key=keys[3], ) self.w_out = eqx.nn.Linear(node_features, num_amino_acids, key=keys[4])
[docs] def _call_unconditional( self, edge_features: EdgeFeatures, neighbor_indices: NeighborIndices, mask: AlphaCarbonMask, _ar_mask: AutoRegressiveMask, _one_hot_sequence: OneHotProteinSequence, _prng_key: PRNGKeyArray, _temperature: Float, _bias: Logits, _tie_group_map: jnp.ndarray | None, _multi_state_strategy_idx: Int, _multi_state_alpha: float, ) -> tuple[OneHotProteinSequence, Logits]: """Run the unconditional (scoring) path. Args: edge_features: Edge features from feature extraction. neighbor_indices: Indices of neighbors for each node. mask: Alpha carbon mask. _ar_mask: Unused, required for jax.lax.switch signature. _one_hot_sequence: Unused, required for jax.lax.switch signature. prng_key: Unused, required for jax.lax.switch signature. _temperature: Unused, required for jax.lax.switch signature. _bias: Unused, required for jax.lax.switch signature. _tie_group_map: Unused, required for jax.lax.switch signature. _multi_state_strategy_idx: Unused, required for jax.lax.switch signature. _multi_state_alpha: Unused, required for jax.lax.switch signature. Returns: Tuple of (dummy sequence, logits). Raises: None Example: >>> key = jax.random.PRNGKey(0) >>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key) >>> edge_feats = jnp.ones((10, 30, 128)) >>> neighbor_idx = jnp.arange(300).reshape(10, 30) >>> mask = jnp.ones((10,)) >>> seq, logits = model._call_unconditional(edge_feats, neighbor_idx, mask) """ node_features, processed_edge_features = self.encoder( edge_features, neighbor_indices, mask, ) decoded_node_features = self.decoder( node_features, processed_edge_features, neighbor_indices, # Pass neighbor indices for correct context mask, ) logits = jax.vmap(self.w_out)(decoded_node_features) # Return dummy sequence to match PyTree shape dummy_seq = jnp.zeros( (logits.shape[0], self.w_s_embed.num_embeddings), dtype=logits.dtype, ) return dummy_seq, logits
[docs] def _call_conditional( self, edge_features: EdgeFeatures, neighbor_indices: NeighborIndices, mask: AlphaCarbonMask, _ar_mask: AutoRegressiveMask, one_hot_sequence: OneHotProteinSequence, _prng_key: PRNGKeyArray, _temperature: Float, _bias: Logits, _tie_group_map: jnp.ndarray | None, _multi_state_strategy_idx: Int, _multi_state_alpha: float, ) -> tuple[OneHotProteinSequence, Logits]: """Run the conditional (scoring) path. Args: edge_features: Edge features from feature extraction. neighbor_indices: Indices of neighbors for each node. mask: Alpha carbon mask. _ar_mask: Autoregressive mask for conditional decoding. one_hot_sequence: One-hot encoded protein sequence. prng_key: Unused, required for jax.lax.switch signature. _temperature: Unused, required for jax.lax.switch signature. _bias: Unused, required for jax.lax.switch signature. _tie_group_map: Unused, required for jax.lax.switch signature. _multi_state_strategy_idx: Unused, required for jax.lax.switch signature. _multi_state_alpha: Unused, required for jax.lax.switch signature. Returns: Tuple of (input sequence, logits). Raises: None Example: >>> key = jax.random.PRNGKey(0) >>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key) >>> edge_feats = jnp.ones((10, 30, 128)) >>> neighbor_idx = jnp.arange(300).reshape(10, 30) >>> mask = jnp.ones((10,)) >>> ar_mask = jnp.ones((10, 10)) >>> seq = jax.nn.one_hot(jnp.arange(10), 21) >>> out_seq, logits = model._call_conditional( ... edge_feats, neighbor_idx, mask, ar_mask, seq ... ) """ node_features, processed_edge_features = self.encoder( edge_features, neighbor_indices, mask, ) decoded_node_features = self.decoder.call_conditional( node_features, processed_edge_features, neighbor_indices, mask, _ar_mask, one_hot_sequence, self.w_s_embed.weight, ) logits = jax.vmap(self.w_out)(decoded_node_features) # Return input sequence to match PyTree shape return one_hot_sequence, logits
[docs] def _call_autoregressive( self, edge_features: EdgeFeatures, neighbor_indices: NeighborIndices, mask: AlphaCarbonMask, ar_mask: AutoRegressiveMask, _one_hot_sequence: OneHotProteinSequence, prng_key: PRNGKeyArray, temperature: Float, bias: Logits, tie_group_map: jnp.ndarray | None, multi_state_strategy_idx: Int, multi_state_alpha: float = 0.5, ) -> tuple[OneHotProteinSequence, Logits]: """Run the autoregressive (sampling) path. Args: edge_features: Edge features from feature extraction. neighbor_indices: Indices of neighbors for each node. mask: Alpha carbon mask. ar_mask: Autoregressive mask for sampling. _one_hot_sequence: Unused, required for jax.lax.switch signature. prng_key: PRNG key for sampling. temperature: Temperature for Gumbel-max sampling. bias: Bias to add to logits before sampling (N, 21). tie_group_map: Optional (N,) array mapping each position to a group ID. When provided, positions in the same group sample identical amino acids. multi_state_strategy_idx: Integer index for strategy (0=mean, 1=min, 2=product, 3=max_min). multi_state_alpha: Weight for min component when strategy="max_min". Returns: Tuple of (sampled sequence, logits). Raises: None Example: >>> key = jax.random.PRNGKey(0) >>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key) >>> edge_feats = jnp.ones((10, 30, 128)) >>> neighbor_idx = jnp.arange(300).reshape(10, 30) >>> mask = jnp.ones((10,)) >>> ar_mask = jnp.ones((10, 10)) >>> temp = jnp.array(1.0) >>> bias = jnp.zeros((10, 21)) >>> seq, logits = model._call_autoregressive( ... edge_feats, neighbor_idx, mask, ar_mask, None, key, temp, bias, None ... ) """ # Convert strategy index back to string node_features, processed_edge_features = self.encoder( edge_features, neighbor_indices, mask, ) seq, logits = self._run_autoregressive_scan( prng_key, node_features, processed_edge_features, neighbor_indices, mask, ar_mask, temperature, bias, tie_group_map, multi_state_strategy_idx, multi_state_alpha, ) return seq, logits
[docs] @staticmethod def _average_logits_over_group( logits: Logits, group_mask: jnp.ndarray, ) -> jnp.ndarray: """Average logits across positions in a tie group using log-sum-exp. This implements numerically stable logit averaging for tied positions. Given logits of shape (N, 21) and a boolean mask indicating which positions belong to the current group, returns averaged logits of shape (1, 21). Args: logits: Logits array of shape (N, 21). group_mask: Boolean mask of shape (N,) indicating group membership. Returns: Averaged logits of shape (1, 21). Raises: None Example: >>> logits = jnp.array([[0.1, 0.9], [0.3, 0.7]]) >>> group_mask = jnp.array([True, True]) >>> avg_logits = PrxteinMPNN._average_logits_over_group(logits, group_mask) """ # Find max logits within the group for numerical stability max_logits = jnp.max( logits, where=group_mask[:, None], initial=-1e9, axis=0, keepdims=True, ) # (1, 21) # Shift logits and compute exp shifted_logits = logits - max_logits # (N, 21) exp_logits = jnp.exp(shifted_logits) # (N, 21) # Mask out non-group positions and sum masked_exp_logits = jnp.where(group_mask[:, None], exp_logits, 0.0) # (N, 21) sum_exp_logits = jnp.sum(masked_exp_logits, axis=0, keepdims=True) # (1, 21) # Compute average and convert back to log space num_in_group = jnp.sum(group_mask) avg_exp_logits = sum_exp_logits / num_in_group # (1, 21) return jnp.log(avg_exp_logits) + max_logits # (1, 21)
[docs] @staticmethod def _combine_logits_multistate( logits: Logits, group_mask: jnp.ndarray, strategy: Literal["mean", "min", "product", "max_min"] = "mean", alpha: float = 0.5, ) -> jnp.ndarray: """Combine logits across tied positions using different multi-state strategies. Args: logits: Logits array of shape (N, 21). group_mask: Boolean mask of shape (N,) indicating group membership. strategy: Strategy for combining logits: - "mean": Average logits (consensus prediction, default) - "min": Minimum logits (worst-case robust design) - "product": Sum of logits (multiply probabilities) - "max_min": Weighted combination of min and mean (alpha controls weight) alpha: Weight for min component when strategy="max_min" (0=pure mean, 1=pure min). Returns: Combined logits of shape (1, 21). Example: >>> logits = jnp.array([[10.0, -5.0], [8.0, -3.0]]) >>> group_mask = jnp.array([True, True]) >>> # Average strategy (compromise) >>> avg = PrxteinMPNN._combine_logits_multistate(logits, group_mask, "mean") >>> # Min strategy (robust to worst case) >>> robust = PrxteinMPNN._combine_logits_multistate(logits, group_mask, "min") """ if strategy == "mean": return PrxteinMPNN._average_logits_over_group(logits, group_mask) if strategy == "min": return min_over_group_logits(logits, group_mask) if strategy == "product": return product_of_probabilities_logits(logits, group_mask) if strategy == "max_min": return max_min_over_group_logits(logits, group_mask, alpha) msg = f"Unknown multi-state strategy: {strategy}" raise ValueError(msg)
[docs] @staticmethod def _combine_logits_multistate_idx( logits: Logits, group_mask: jnp.ndarray, strategy_idx: Int, alpha: float = 0.5, ) -> jnp.ndarray: """Combine logits using strategy index (JAX-traceable version). This is a JAX-traceable wrapper around _combine_logits_multistate that accepts an integer strategy index instead of a string. Used internally when the function needs to be JIT-compiled. Args: logits: Logits array of shape (N, 21). group_mask: Boolean mask of shape (N,) indicating group membership. strategy_idx: Integer strategy index (0=mean, 1=min, 2=product, 3=max_min). alpha: Weight for min component when strategy_idx=3 (0=pure mean, 1=pure min). Returns: Combined logits of shape (1, 21). """ def mean_fn(_: tuple) -> jnp.ndarray: return PrxteinMPNN._average_logits_over_group(logits, group_mask) def min_fn(_: tuple) -> jnp.ndarray: return min_over_group_logits(logits, group_mask) def product_fn(_: tuple) -> jnp.ndarray: return product_of_probabilities_logits(logits, group_mask) def max_min_fn(_: tuple) -> jnp.ndarray: return max_min_over_group_logits(logits, group_mask, alpha) branches = [mean_fn, min_fn, product_fn, max_min_fn] return jax.lax.switch(strategy_idx, branches, ())
[docs] def _process_group_positions( self, group_mask: jnp.ndarray, all_layers_h: NodeFeatures, s_embed: NodeFeatures, encoder_context: jnp.ndarray, edge_features: EdgeFeatures, neighbor_indices: NeighborIndices, mask: AlphaCarbonMask, mask_bw: jnp.ndarray, ) -> tuple[NodeFeatures, jnp.ndarray]: """Process all positions in a group through decoder and collect logits. Args: group_mask: Boolean mask (N,) for positions in current group. all_layers_h: Hidden states (num_layers+1, N, C). s_embed: Sequence embeddings (N, C). encoder_context: Precomputed encoder context (N, K, features). edge_features: Edge features (N, K, C). neighbor_indices: Neighbor indices (N, K). mask: Alpha carbon mask (N,). mask_bw: Backward mask (N, K). Returns: Tuple of (updated all_layers_h, computed logits (N, 21)). """ num_residues = all_layers_h.shape[1] computed_logits = jnp.zeros((num_residues, 21)) def process_one_position(idx: Int, state: tuple) -> tuple: """Process one position through decoder layers.""" position_all_layers_h, position_logits = state is_in_group = group_mask[idx] encoder_context_pos = encoder_context[idx] neighbor_indices_pos = neighbor_indices[idx] mask_pos = mask[idx] mask_bw_pos = mask_bw[idx] edge_sequence_features = concatenate_neighbor_nodes( s_embed, edge_features[idx], neighbor_indices_pos, ) for layer_idx, layer in enumerate(self.decoder.layers): h_in_pos = position_all_layers_h[layer_idx, idx] decoder_context_pos = concatenate_neighbor_nodes( position_all_layers_h[layer_idx], edge_sequence_features, neighbor_indices_pos, ) decoding_context = mask_bw_pos[..., None] * decoder_context_pos + encoder_context_pos h_in_expanded = jnp.expand_dims(h_in_pos, axis=0) decoding_context_expanded = jnp.expand_dims(decoding_context, axis=0) h_out_pos = layer(h_in_expanded, decoding_context_expanded, mask=mask_pos) position_all_layers_h = position_all_layers_h.at[layer_idx + 1, idx].set( jnp.squeeze(h_out_pos), ) final_h_pos = position_all_layers_h[-1, idx] logits_pos = self.w_out(final_h_pos) position_logits = jnp.where( is_in_group, position_logits.at[idx].set(logits_pos), position_logits, ) return position_all_layers_h, position_logits return jax.lax.fori_loop( 0, num_residues, process_one_position, (all_layers_h, computed_logits), )
[docs] def _run_tied_position_scan( self, prng_key: PRNGKeyArray, node_features: NodeFeatures, edge_features: EdgeFeatures, neighbor_indices: NeighborIndices, mask: AlphaCarbonMask, encoder_context: jnp.ndarray, mask_bw: jnp.ndarray, temperature: Float, bias: Logits, tie_group_map: jnp.ndarray, decoding_order: jnp.ndarray, multi_state_strategy_idx: Int = 0, multi_state_alpha: float = 0.5, ) -> tuple[OneHotProteinSequence, Logits]: """Run group-based autoregressive scan with logit combining. Args: prng_key: PRNG key. node_features: Node features (N, C). edge_features: Edge features (N, K, C). neighbor_indices: Neighbor indices (N, K). mask: Alpha carbon mask (N,). encoder_context: Precomputed encoder context (N, K, features). mask_bw: Backward mask (N, K). temperature: Sampling temperature. bias: Logits array (N, 21). tie_group_map: Group mapping (N,). decoding_order: Position decoding order (N,). multi_state_strategy_idx: Integer strategy index (0=mean, 1=min, 2=product, 3=max_min). multi_state_alpha: Weight for min component when strategy_idx=3. Returns: Tuple of (final sequence, final logits). """ num_residues = node_features.shape[0] # Find group decoding order based on first appearance in decoding_order # Get the group ID at each position in decoding_order groups_in_order = tie_group_map[decoding_order] # Find unique groups preserving order of first appearance using cumulative comparison # For each position, check if this group appeared earlier in the sequence # Create a lower triangular mask to check all previous positions position_indices = jnp.arange(num_residues) # Create mask: mask[i,j] = True if i > j (i.e., j is before i) is_before_mask = position_indices[:, None] > position_indices[None, :] # Check if current group (at each position) matches any previous group # groups_in_order[:, None] has shape (N, 1), groups_in_order[None, :] has shape (1, N) group_matches = groups_in_order[:, None] == groups_in_order[None, :] # For each position i, check if its group appears in any position j < i appeared_before = jnp.any(group_matches & is_before_mask, axis=1) # Mark first occurrences (positions where group hasn't appeared before) is_first_occurrence = ~appeared_before # Extract the group IDs at their first occurrences # Use compress with static size to avoid boolean indexing issues # Maximum possible size is num_residues (all positions in different groups) group_decoding_order = jnp.compress( is_first_occurrence, groups_in_order, size=num_residues, fill_value=-1, ) # Note: group_decoding_order now contains the group IDs followed by -1 padding # The scan will iterate over this, and steps with group_id=-1 will be no-ops def group_autoregressive_step( carry: tuple[NodeFeatures, NodeFeatures, Logits, OneHotProteinSequence], scan_inputs: tuple[Int, PRNGKeyArray], ) -> tuple[ tuple[NodeFeatures, NodeFeatures, Logits, OneHotProteinSequence], None, ]: """Process one group at a time with logit averaging.""" all_layers_h, s_embed, all_logits, sequence = carry group_id, key = scan_inputs group_mask = tie_group_map == group_id # Process positions and collect logits all_layers_h, computed_logits = self._process_group_positions( group_mask, all_layers_h, s_embed, encoder_context, edge_features, neighbor_indices, mask, mask_bw, ) # Combine logits using multi-state strategy and sample combined_logits = self._combine_logits_multistate_idx( computed_logits, group_mask, multi_state_strategy_idx, multi_state_alpha, ) all_logits, s_embed, sequence = self._sample_and_broadcast_to_group( combined_logits, group_mask, bias, temperature, key, all_logits, s_embed, sequence, ) return (all_layers_h, s_embed, all_logits, sequence), None # Initialize scan initial_all_layers_h = jnp.zeros( (self.num_decoder_layers + 1, num_residues, self.node_features_dim), ) initial_all_layers_h = initial_all_layers_h.at[0].set(node_features) initial_s_embed = jnp.zeros_like(node_features) initial_all_logits = jnp.zeros((num_residues, self.w_out.out_features)) initial_sequence = jnp.zeros((num_residues, self.w_s_embed.num_embeddings)) initial_carry = ( initial_all_layers_h, initial_s_embed, initial_all_logits, initial_sequence, ) # Use the length of group_decoding_order to determine how many groups we have actual_num_groups = group_decoding_order.shape[0] scan_inputs = (group_decoding_order, jax.random.split(prng_key, actual_num_groups)) final_carry, _ = jax.lax.scan( group_autoregressive_step, initial_carry, scan_inputs, ) return final_carry[3], final_carry[2]
[docs] def _sample_and_broadcast_to_group( self, avg_logits: jnp.ndarray, group_mask: jnp.ndarray, bias: Logits, temperature: Float, key: PRNGKeyArray, all_logits: Logits, s_embed: NodeFeatures, sequence: OneHotProteinSequence, ) -> tuple[Logits, NodeFeatures, OneHotProteinSequence]: """Sample once and broadcast token to all positions in a group. Args: avg_logits: Averaged logits (1, 21). group_mask: Boolean mask (N,) for group positions. bias: Bias array (N, 21). temperature: Sampling temperature. key: PRNG key. all_logits: Current logits array (N, 21). s_embed: Current sequence embeddings (N, C). sequence: Current sequence (N, 21). Returns: Tuple of (updated all_logits, updated s_embed, updated sequence). """ # Average bias across group group_bias = jnp.sum( jnp.where(group_mask[:, None], bias, 0.0), axis=0, keepdims=True, ) / jnp.sum(group_mask) logits_with_bias = avg_logits + group_bias # Sample once for the entire group sampled_logits = (logits_with_bias / temperature) + jax.random.gumbel( key, logits_with_bias.shape, ) sampled_logits_no_pad = sampled_logits[..., :20] one_hot_sample = straight_through_estimator(sampled_logits_no_pad) padding = jnp.zeros_like(one_hot_sample[..., :1]) one_hot_seq = jnp.concatenate([one_hot_sample, padding], axis=-1) # Broadcast to all group positions s_embed_new = one_hot_seq @ self.w_s_embed.weight all_logits = jnp.where(group_mask[:, None], jnp.squeeze(avg_logits), all_logits) s_embed = jnp.where(group_mask[:, None], jnp.squeeze(s_embed_new), s_embed) sequence = jnp.where(group_mask[:, None], jnp.squeeze(one_hot_seq), sequence) return all_logits, s_embed, sequence
[docs] def _run_autoregressive_scan( self, prng_key: PRNGKeyArray, node_features: NodeFeatures, edge_features: EdgeFeatures, neighbor_indices: NeighborIndices, mask: AlphaCarbonMask, autoregressive_mask: AutoRegressiveMask, temperature: Float, bias: Logits, tie_group_map: jnp.ndarray | None = None, multi_state_strategy_idx: Int = 0, multi_state_alpha: float = 0.5, ) -> tuple[OneHotProteinSequence, Logits]: """Run JAX scan loop for autoregressive sampling with optional tied positions. When tie_group_map is provided, the scan iterates over groups instead of individual positions. For each group: 1. Decoder processes all positions in the group 2. Logits are computed for all group members 3. Logits are averaged across the group (log-sum-exp) 4. A single token is sampled from the averaged logits 5. The token is broadcast to all positions in the group Args: prng_key: PRNG key for sampling. node_features: Node features from encoder. edge_features: Edge features from encoder. neighbor_indices: Indices of neighbors for each node. mask: Alpha carbon mask. autoregressive_mask: Mask defining decoding order. temperature: Temperature for Gumbel-max sampling. bias: Bias to add to logits before sampling (N, 21). tie_group_map: Optional (N,) array mapping each position to a group ID. When provided, positions in the same group are sampled together using combined logits. multi_state_strategy_idx: Integer strategy index (0=mean, 1=min, 2=product, 3=max_min). multi_state_alpha: Weight for min component when strategy_idx=3. Returns: Tuple of (sampled sequence, final logits). Raises: None Example: >>> key = jax.random.PRNGKey(0) >>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key) >>> node_feats = jnp.ones((10, 128)) >>> edge_feats = jnp.ones((10, 30, 128)) >>> neighbor_idx = jnp.arange(300).reshape(10, 30) >>> mask = jnp.ones((10,)) >>> ar_mask = jnp.ones((10, 10)) >>> temp = jnp.array(1.0) >>> bias = jnp.zeros((10, 21)) >>> seq, logits = model._run_autoregressive_scan( ... key, node_feats, edge_feats, neighbor_idx, mask, ar_mask, temp, bias ... ) """ num_residues = node_features.shape[0] attention_mask = jnp.take_along_axis( autoregressive_mask, neighbor_indices, axis=1, ) mask_1d = mask[:, None] mask_bw = mask_1d * attention_mask mask_fw = mask_1d * (1 - attention_mask) decoding_order = jnp.argsort(jnp.sum(autoregressive_mask, axis=1)) # Precompute encoder context: [e_ij, 0_j, h_j] # This matches the unconditional decoder structure encoder_edge_neighbors = concatenate_neighbor_nodes( jnp.zeros_like(node_features), edge_features, neighbor_indices, ) # [e_ij, 0_j] encoder_context = concatenate_neighbor_nodes( node_features, encoder_edge_neighbors, neighbor_indices, ) # [[e_ij, 0_j], h_j] = [e_ij, 0_j, h_j] encoder_context = encoder_context * mask_fw[..., None] def autoregressive_step( carry: tuple[NodeFeatures, NodeFeatures, Logits, OneHotProteinSequence], scan_inputs: tuple[Int, PRNGKeyArray], ) -> tuple[ tuple[NodeFeatures, NodeFeatures, Logits, OneHotProteinSequence], None, ]: all_layers_h, s_embed, all_logits, sequence = carry position, key = scan_inputs # Direct indexing at current position encoder_context_pos = encoder_context[position] # (K, 384) neighbor_indices_pos = neighbor_indices[position] # (K,) mask_pos = mask[position] # scalar mask_bw_pos = mask_bw[position] # (K,) # Compute edge sequence features for this position edge_sequence_features = concatenate_neighbor_nodes( s_embed, edge_features[position], neighbor_indices_pos, ) # (K, 256) # Decoder Layer Loop for layer_idx, layer in enumerate(self.decoder.layers): # Get node features for this layer at current position h_in_pos = all_layers_h[layer_idx, position] # [C] # Compute decoder context for this position decoder_context_pos = concatenate_neighbor_nodes( all_layers_h[layer_idx], edge_sequence_features, neighbor_indices_pos, ) # (K, 384) # Combine with encoder context using backward mask decoding_context = ( mask_bw_pos[..., None] * decoder_context_pos + encoder_context_pos ) # (K, 384) # Expand dims for layer forward pass h_in_expanded = jnp.expand_dims(h_in_pos, axis=0) # [1, C] decoding_context_expanded = jnp.expand_dims(decoding_context, axis=0) # [1, K, 384] # Call DecoderLayer h_out_pos = layer( h_in_expanded, decoding_context_expanded, mask=mask_pos, ) # [1, C] # Update the state for next layer all_layers_h = all_layers_h.at[layer_idx + 1, position].set(jnp.squeeze(h_out_pos)) # Sampling Step # Get final layer output for this position final_h_pos = all_layers_h[-1, position] # [C] logits_pos_vec = self.w_out(final_h_pos) # [21] logits_pos = jnp.expand_dims(logits_pos_vec, axis=0) # [1, 21] next_all_logits = all_logits.at[position, :].set(jnp.squeeze(logits_pos)) # Apply bias before sampling bias_pos = jax.lax.dynamic_slice( bias, (position, 0), (1, bias.shape[-1]), ) logits_with_bias = logits_pos + bias_pos # Gumbel-max trick sampled_logits = (logits_with_bias / temperature) + jax.random.gumbel( key, logits_with_bias.shape, ) sampled_logits_no_pad = sampled_logits[..., :20] # Exclude padding one_hot_sample = straight_through_estimator(sampled_logits_no_pad) padding = jnp.zeros_like(one_hot_sample[..., :1]) one_hot_seq_pos = jnp.concatenate([one_hot_sample, padding], axis=-1) s_embed_pos = one_hot_seq_pos @ self.w_s_embed.weight # [1, C] next_s_embed = s_embed.at[position, :].set(jnp.squeeze(s_embed_pos)) next_sequence = sequence.at[position, :].set(jnp.squeeze(one_hot_seq_pos)) return ( all_layers_h, next_s_embed, next_all_logits, next_sequence, ), None # Branch based on whether we have tied positions if tie_group_map is None: # Standard position-by-position sampling # Initialize Scan initial_all_layers_h = jnp.zeros( (self.num_decoder_layers + 1, num_residues, self.node_features_dim), ) initial_all_layers_h = initial_all_layers_h.at[0].set(node_features) initial_s_embed = jnp.zeros_like(node_features) initial_all_logits = jnp.zeros((num_residues, self.w_out.out_features)) initial_sequence = jnp.zeros((num_residues, self.w_s_embed.num_embeddings)) initial_carry = ( initial_all_layers_h, initial_s_embed, initial_all_logits, initial_sequence, ) scan_inputs = (decoding_order, jax.random.split(prng_key, num_residues)) final_carry, _ = jax.lax.scan( autoregressive_step, initial_carry, scan_inputs, ) final_sequence = final_carry[3] final_all_logits = final_carry[2] return final_sequence, final_all_logits # Group-based sampling with logit combining return self._run_tied_position_scan( prng_key, node_features, edge_features, neighbor_indices, mask, encoder_context, mask_bw, temperature, bias, tie_group_map, decoding_order, multi_state_strategy_idx, multi_state_alpha, )
def __call__( self, structure_coordinates: StructureAtomicCoordinates, mask: AlphaCarbonMask, residue_index: ResidueIndex, chain_index: ChainIndex, decoding_approach: DecodingApproach, *, prng_key: PRNGKeyArray | None = None, ar_mask: AutoRegressiveMask | None = None, one_hot_sequence: OneHotProteinSequence | None = None, temperature: Float | None = None, bias: Logits | None = None, backbone_noise: BackboneNoise | None = None, tie_group_map: jnp.ndarray | None = None, multi_state_strategy: Literal["mean", "min", "product", "max_min"] = "mean", multi_state_alpha: float = 0.5, ) -> tuple[OneHotProteinSequence, Logits]: """Forward pass for the complete model. Dispatches to one of three modes: 1. "unconditional": Scores all positions in parallel. 2. "conditional": Scores a given sequence. 3. "autoregressive": Samples a new sequence. Args: structure_coordinates: Raw atomic coordinates of protein structure. mask: Alpha carbon mask indicating valid residues. residue_index: Residue indices for each position. chain_index: Chain indices for each position. decoding_approach: One of "unconditional", "conditional", or "autoregressive". prng_key: PRNG key for random operations (optional). ar_mask: Autoregressive mask for decoding order (optional). one_hot_sequence: One-hot encoded sequence for conditional mode (optional). temperature: Temperature for autoregressive sampling (optional). bias: Optional bias to add to logits before sampling (N, 21) (optional). backbone_noise: Noise level for backbone coordinates (optional). tie_group_map: Optional (N,) array mapping each position to a group ID. When provided, positions in the same group sample identical amino acids using logit combining. Only used in "autoregressive" mode (optional). multi_state_strategy: Strategy for combining logits across tied positions. Options: "mean" (default, average), "min" (worst-case robust), "product" (multiply probabilities), "max_min" (weighted combination). Only used in "autoregressive" mode with tied positions (optional). multi_state_alpha: Weight for min component when multi_state_strategy="max_min". Range [0, 1] where 0=pure mean, 1=pure min (optional). Returns: A tuple of (OneHotProteinSequence, Logits). - For "unconditional", the sequence is a zero-tensor. - For "conditional", the sequence is the input sequence. - For "autoregressive", the sequence is the newly sampled one. Raises: None Example: >>> key = jax.random.PRNGKey(0) >>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key) >>> coords = jnp.ones((10, 4, 3)) >>> mask = jnp.ones((10,)) >>> residue_idx = jnp.arange(10) >>> chain_idx = jnp.zeros((10,), dtype=jnp.int32) >>> seq, logits = model( ... coords, mask, residue_idx, chain_idx, "unconditional", prng_key=key ... ) """ # 1. Prepare keys and noise if prng_key is None: prng_key = jax.random.PRNGKey(0) # Use a default key if none provided prng_key, feat_key = jax.random.split(prng_key) if backbone_noise is None: backbone_noise = jnp.array(0.0, dtype=jnp.float32) # 2. Run Feature Extraction edge_features, neighbor_indices, _ = self.features( feat_key, structure_coordinates, mask, residue_index, chain_index, backbone_noise, ) # 3. Prepare inputs for jax.lax.switch branch_indices = { "unconditional": 0, "conditional": 1, "autoregressive": 2, } branch_index = branch_indices[decoding_approach] # All branches must accept the same (super-set) of arguments. # We fill in defaults for modes that don't use them. if ar_mask is None: ar_mask = jnp.zeros((mask.shape[0], mask.shape[0]), dtype=jnp.int32) if one_hot_sequence is None: one_hot_sequence = jnp.zeros( (mask.shape[0], self.w_s_embed.num_embeddings), ) if temperature is None: temperature = jnp.array(1.0) if bias is None: bias = jnp.zeros((mask.shape[0], 21), dtype=jnp.float32) # Note: tie_group_map is only used in autoregressive mode, but all branches # must accept the same arguments for jax.lax.switch compatibility # Convert multi_state_strategy string to integer for JAX tracing strategy_map = {"mean": 0, "min": 1, "product": 2, "max_min": 3} multi_state_strategy_idx = jnp.array( strategy_map[multi_state_strategy], dtype=jnp.int32, ) # 4. Define the branches for jax.lax.switch branches = [ self._call_unconditional, self._call_conditional, self._call_autoregressive, ] # 5. Collect all operands operands = ( edge_features, neighbor_indices, mask, ar_mask, one_hot_sequence, prng_key, temperature, bias, tie_group_map, multi_state_strategy_idx, multi_state_alpha, ) # 6. Run the switch return jax.lax.switch(branch_index, branches, *operands)