model
Model module for PrxteinMPNN.
This module contains the core Equinox-based neural network components for ProteinMPNN.
- class prxteinmpnn.model.Decoder(node_features, edge_features, hidden_features, num_layers=3, *, key)
Bases: Module
The complete decoder module for ProteinMPNN.
- Parameters:
- call_conditional(node_features, edge_features, neighbor_indices, mask, ar_mask, one_hot_sequence, w_s_weight)
Forward pass for CONDITIONAL decoding (scoring).
- Parameters:
node_features (Int[Array, 'num_atoms num_features']) – Node features from encoder of shape (N, 128).
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features from encoder of shape (N, K, 128).
neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighbors for each node.
mask (Int[Array, 'num_residues 3']) – Alpha carbon mask of shape (N,).
ar_mask (Bool[Array, 'num_residues num_residues']) – Autoregressive mask for conditional decoding.
one_hot_sequence (Float[Array, 'num_residues num_classes']) – One-hot encoded protein sequence.
w_s_weight (Array) – Sequence embedding weight matrix.
- Return type:
Int[Array, 'num_atoms num_features']
- Returns:
Decoded node features of shape (N, 128).
- Raises:
None –
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import Decoder
>>> key = jax.random.PRNGKey(0)
>>> decoder = Decoder(128, 128, 128, num_layers=3, key=key)
>>> node_feats = jnp.ones((10, 128))
>>> edge_feats = jnp.ones((10, 30, 128))
>>> neighbor_indices = jnp.arange(300).reshape(10, 30)
>>> mask = jnp.ones((10,))
>>> ar_mask = jnp.ones((10, 10))
>>> seq = jax.nn.one_hot(jnp.arange(10), 21)
>>> w_s = jnp.ones((21, 128))
>>> output = decoder.call_conditional(
...     node_feats, edge_feats, neighbor_indices, mask, ar_mask, seq, w_s
... )
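The example above passes an all-ones ar_mask as a placeholder. As a rough sketch only, the block below builds a mask from a decoding order, assuming the convention that entry (i, j) is true when position j is decoded before position i. This page does not state the convention, so treat it as an assumption. The helper name ar_mask_from_order is hypothetical and not part of prxteinmpnn.

```python
import jax.numpy as jnp


def ar_mask_from_order(decoding_order):
    """Hypothetical helper: build an (N, N) boolean autoregressive mask."""
    n = decoding_order.shape[0]
    # rank[p] is the step at which position p is decoded.
    steps = jnp.arange(n, dtype=jnp.int32)
    rank = jnp.zeros(n, dtype=jnp.int32).at[decoding_order].set(steps)
    # Assumed convention: (i, j) is True when j is decoded strictly before i.
    return rank[:, None] > rank[None, :]


# Position 2 is decoded first, then 0, then 1.
order = jnp.array([2, 0, 1])
ar_mask = ar_mask_from_order(order)
```

Under this assumed convention, the first decoded position sees no other position, and the diagonal is always false.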
- layers: tuple[DecoderLayer, ...]
- class prxteinmpnn.model.DecoderLayer(node_features, edge_context_features, hidden_features, *, key)
Bases: Module
A single decoder layer for the ProteinMPNN model.
- Parameters:
- message_mlp: MLP
- norm1: LayerNorm
- dense: MLP
- norm2: LayerNorm
- class prxteinmpnn.model.Encoder(node_features, edge_features, hidden_features, num_layers=3, *, key)
Bases: Module
The complete encoder module for ProteinMPNN.
- Parameters:
- layers: tuple[EncoderLayer, ...]
- class prxteinmpnn.model.EncoderLayer(node_features, edge_features, hidden_features, *, key)
Bases: Module
A single encoder layer for the ProteinMPNN model.
- _get_mlp_input(h, e, neighbor_indices)
Return the input tensor [h_i, e_ij, h_j] for edge_message_mlp.
- Return type:
- Parameters:
h (Int[Array, 'num_atoms num_features'])
e (Float[Array, 'num_atoms num_neighbors num_features'])
neighbor_indices (Int[Array, 'num_atoms num_neighbors'])
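The [h_i, e_ij, h_j] input described for _get_mlp_input can be illustrated with a small gather-and-concatenate sketch. This is not the library's implementation. The function name gather_mlp_input and the toy shapes are assumptions for illustration.

```python
import jax.numpy as jnp


def gather_mlp_input(h, e, neighbor_indices):
    """Illustrative sketch: concatenate [h_i, e_ij, h_j] along the feature axis."""
    # h_j: features of each node's neighbors, shape (N, K, C).
    h_j = h[neighbor_indices]
    # h_i: the central node's features repeated for each neighbor, shape (N, K, C).
    h_i = jnp.broadcast_to(h[:, None, :], h_j.shape)
    # Result has shape (N, K, C + C_edge + C).
    return jnp.concatenate([h_i, e, h_j], axis=-1)


# Toy data: 4 nodes, 3 neighbors each, 2 node features, 5 edge features.
h = jnp.arange(8.0).reshape(4, 2)
e = jnp.zeros((4, 3, 5))
neighbor_indices = jnp.array([[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]])
mlp_input = gather_mlp_input(h, e, neighbor_indices)
```

The first C entries of each row hold the central node, the last C entries hold the neighbor selected by neighbor_indices.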
- edge_message_mlp: MLP
- norm1: LayerNorm
- dense: MLP
- norm2: LayerNorm
- edge_update_mlp: MLP
- norm3: LayerNorm
- class prxteinmpnn.model.ProteinFeatures(node_features, edge_features, k_neighbors, *, key)
Bases: Module
Extracts and projects features from raw protein coordinates.
This module encapsulates k-NN, RBF, positional encodings, and edge projections. Note: W_e projection is NOT here - it’s in the main model (matches ColabDesign).
- w_pos: Linear
- w_e: Linear
- norm_edges: LayerNorm
- w_e_proj: Linear
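To make the k-NN and RBF steps mentioned for ProteinFeatures concrete, here is a minimal sketch that operates on alpha-carbon coordinates only. It is not the library's code. The function name knn_and_rbf, the 16 RBF centers, and the 2 to 22 Å range are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def knn_and_rbf(ca_coords, k, num_rbf=16, d_min=2.0, d_max=22.0):
    """Illustrative sketch: k nearest neighbors plus RBF-expanded distances."""
    # Pairwise distances, shape (N, N); the epsilon keeps sqrt stable at zero.
    diff = ca_coords[:, None, :] - ca_coords[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-6)
    # top_k returns the largest values, so negate to get the smallest distances.
    neg_d, neighbor_indices = jax.lax.top_k(-dist, k)
    d_neighbors = -neg_d
    # Gaussian radial basis functions with evenly spaced centers (assumed values).
    centers = jnp.linspace(d_min, d_max, num_rbf)
    sigma = (d_max - d_min) / num_rbf
    rbf = jnp.exp(-(((d_neighbors[..., None] - centers) / sigma) ** 2))
    return neighbor_indices, rbf


# Toy chain: 6 residues spaced 3.8 Å apart along the x axis.
ca_coords = jnp.stack(
    [3.8 * jnp.arange(6.0), jnp.zeros(6), jnp.zeros(6)], axis=1
)
neighbor_indices, rbf = knn_and_rbf(ca_coords, k=3)
```

In this sketch each residue is its own nearest neighbor, so the first column of neighbor_indices is the residue's own index.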
- class prxteinmpnn.model.PrxteinMPNN(node_features, edge_features, hidden_features, num_encoder_layers, num_decoder_layers, k_neighbors, num_amino_acids=21, vocab_size=21, *, key)
Bases: Module
The complete end-to-end ProteinMPNN model.
- Parameters:
- static _average_logits_over_group(logits, group_mask)
Average logits across positions in a tie group using log-sum-exp.
This implements numerically stable logit averaging for tied positions. Given logits of shape (N, 21) and a boolean mask indicating which positions belong to the current group, returns averaged logits of shape (1, 21).
- Parameters:
logits (Float[Array, 'num_residues num_classes']) – Logits array of shape (N, 21).
group_mask (Array) – Boolean mask of shape (N,) indicating group membership.
- Return type:
- Returns:
Averaged logits of shape (1, 21).
- Raises:
None –
Example
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import PrxteinMPNN
>>> logits = jnp.array([[0.1, 0.9], [0.3, 0.7]])
>>> group_mask = jnp.array([True, True])
>>> avg_logits = PrxteinMPNN._average_logits_over_group(logits, group_mask)
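One plausible reading of "averaging with log-sum-exp" is the log of the mean of exp(logits) over the group members, computed stably. The sketch below follows that reading. It is an assumption, not the library's code, and the function name masked_logmeanexp is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def masked_logmeanexp(logits, group_mask):
    """Illustrative sketch: log(mean(exp(logits))) over rows where group_mask is True."""
    # Rows outside the group get -inf so they contribute exp(-inf) = 0.
    masked = jnp.where(group_mask[:, None], logits, -jnp.inf)
    # Guard against an empty group to avoid log(0).
    count = jnp.maximum(group_mask.sum(), 1)
    # logsumexp subtracts the maximum internally, which keeps exp() from overflowing.
    return logsumexp(masked, axis=0, keepdims=True) - jnp.log(count)


logits = jnp.array([[0.1, 0.9], [0.3, 0.7]])
avg_all = masked_logmeanexp(logits, jnp.array([True, True]))
avg_first = masked_logmeanexp(logits, jnp.array([True, False]))
```

With a single-member group the result equals that member's logits, which is a quick sanity check for the masking.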
- _call_autoregressive(edge_features, neighbor_indices, mask, ar_mask, _one_hot_sequence, prng_key, temperature, bias, tie_group_map, multi_state_strategy_idx, multi_state_alpha=0.5)
Run the autoregressive (sampling) path.
- Parameters:
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features from feature extraction.
neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighbors for each node.
mask (Int[Array, 'num_residues 3']) – Alpha carbon mask.
ar_mask (Bool[Array, 'num_residues num_residues']) – Autoregressive mask for sampling.
_one_hot_sequence (Float[Array, 'num_residues num_classes']) – Unused, required for jax.lax.switch signature.
prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – PRNG key for sampling.
temperature (Float) – Temperature for Gumbel-max sampling.
bias (Float[Array, 'num_residues num_classes']) – Bias to add to logits before sampling (N, 21).
tie_group_map (Array | None) – Optional (N,) array mapping each position to a group ID. When provided, positions in the same group sample identical amino acids.
multi_state_strategy_idx (Int) – Integer index for strategy (0=mean, 1=min, 2=product, 3=max_min).
multi_state_alpha (float) – Weight for min component when strategy="max_min".
- Return type:
tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]
- Returns:
Tuple of (sampled sequence, logits).
- Raises:
None –
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import PrxteinMPNN
>>> key = jax.random.PRNGKey(0)
>>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key)
>>> edge_feats = jnp.ones((10, 30, 128))
>>> neighbor_idx = jnp.arange(300).reshape(10, 30)
>>> mask = jnp.ones((10,))
>>> ar_mask = jnp.ones((10, 10))
>>> temp = jnp.array(1.0)
>>> bias = jnp.zeros((10, 21))
>>> strategy_idx = jnp.array(0)  # 0 = mean; required, the signature gives it no default
>>> seq, logits = model._call_autoregressive(
...     edge_feats, neighbor_idx, mask, ar_mask, None, key, temp, bias, None, strategy_idx
... )
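The temperature and bias parameters refer to Gumbel-max sampling. As a minimal sketch of that technique, the block below adds Gumbel noise to temperature-scaled logits and takes the argmax. The order in which bias and temperature are applied here is an assumption, and gumbel_max_sample is a hypothetical name, not a prxteinmpnn function.

```python
import jax
import jax.numpy as jnp


def gumbel_max_sample(key, logits, bias, temperature):
    """Illustrative sketch: draw one class per row via the Gumbel-max trick."""
    # Independent Gumbel(0, 1) noise for every logit.
    gumbel = jax.random.gumbel(key, logits.shape)
    # Assumed order: add bias first, then divide by temperature.
    scores = (logits + bias) / temperature + gumbel
    tokens = jnp.argmax(scores, axis=-1)
    # Return a one-hot sequence, matching the (N, num_classes) layout used above.
    return jax.nn.one_hot(tokens, logits.shape[-1])


key = jax.random.PRNGKey(0)
logits = jnp.zeros((5, 21))
# A very large bias on class 3 makes the outcome deterministic for this demo.
bias = jnp.zeros((5, 21)).at[:, 3].set(1e6)
sample = gumbel_max_sample(key, logits, bias, jnp.array(1.0))
```

Taking the argmax of logits plus Gumbel noise is equivalent to sampling from the softmax of the logits, which is why no explicit normalization appears.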
- _call_conditional(edge_features, neighbor_indices, mask, _ar_mask, one_hot_sequence, _prng_key, _temperature, _bias, _tie_group_map, _multi_state_strategy_idx, _multi_state_alpha)
Run the conditional (scoring) path.
- Parameters:
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features from feature extraction.
neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighbors for each node.
mask (Int[Array, 'num_residues 3']) – Alpha carbon mask.
_ar_mask (Bool[Array, 'num_residues num_residues']) – Autoregressive mask for conditional decoding.
one_hot_sequence (Float[Array, 'num_residues num_classes']) – One-hot encoded protein sequence.
_prng_key (Key[Array, ''] | UInt32[Array, '2']) – Unused, required for jax.lax.switch signature.
_temperature (Float) – Unused, required for jax.lax.switch signature.
_bias (Float[Array, 'num_residues num_classes']) – Unused, required for jax.lax.switch signature.
_tie_group_map (Array | None) – Unused, required for jax.lax.switch signature.
_multi_state_strategy_idx (Int) – Unused, required for jax.lax.switch signature.
_multi_state_alpha (float) – Unused, required for jax.lax.switch signature.
- Return type:
tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]
- Returns:
Tuple of (input sequence, logits).
- Raises:
None –
Example
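>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import PrxteinMPNN
>>> key = jax.random.PRNGKey(0)
>>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key)
>>> edge_feats = jnp.ones((10, 30, 128))
>>> neighbor_idx = jnp.arange(300).reshape(10, 30)
>>> mask = jnp.ones((10,))
>>> ar_mask = jnp.ones((10, 10))
>>> seq = jax.nn.one_hot(jnp.arange(10), 21)
>>> # The trailing arguments are unused placeholders; the signature gives them no defaults.
>>> out_seq, logits = model._call_conditional(
...     edge_feats, neighbor_idx, mask, ar_mask, seq,
...     key, jnp.array(1.0), jnp.zeros((10, 21)), None, jnp.array(0), 0.5,
... )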
- _call_unconditional(edge_features, neighbor_indices, mask, _ar_mask, _one_hot_sequence, _prng_key, _temperature, _bias, _tie_group_map, _multi_state_strategy_idx, _multi_state_alpha)
Run the unconditional (scoring) path.
- Parameters:
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features from feature extraction.
neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighbors for each node.
mask (Int[Array, 'num_residues 3']) – Alpha carbon mask.
_ar_mask (Bool[Array, 'num_residues num_residues']) – Unused, required for jax.lax.switch signature.
_one_hot_sequence (Float[Array, 'num_residues num_classes']) – Unused, required for jax.lax.switch signature.
_prng_key (Key[Array, ''] | UInt32[Array, '2']) – Unused, required for jax.lax.switch signature.
_temperature (Float) – Unused, required for jax.lax.switch signature.
_bias (Float[Array, 'num_residues num_classes']) – Unused, required for jax.lax.switch signature.
_tie_group_map (Array | None) – Unused, required for jax.lax.switch signature.
_multi_state_strategy_idx (Int) – Unused, required for jax.lax.switch signature.
_multi_state_alpha (float) – Unused, required for jax.lax.switch signature.
- Return type:
tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]
- Returns:
Tuple of (dummy sequence, logits).
- Raises:
None –
Example
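>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import PrxteinMPNN
>>> key = jax.random.PRNGKey(0)
>>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key)
>>> edge_feats = jnp.ones((10, 30, 128))
>>> neighbor_idx = jnp.arange(300).reshape(10, 30)
>>> mask = jnp.ones((10,))
>>> seq, logits = model._call_unconditional(
...     edge_feats, neighbor_idx, mask,  # remaining arguments are unused placeholders
...     jnp.ones((10, 10)), None, key, jnp.array(1.0), jnp.zeros((10, 21)), None, jnp.array(0), 0.5,
... )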
- static _combine_logits_multistate(logits, group_mask, strategy='mean', alpha=0.5)
Combine logits across tied positions using different multi-state strategies.
- Parameters:
logits (Float[Array, 'num_residues num_classes']) – Logits array of shape (N, 21).
group_mask (Array) – Boolean mask of shape (N,) indicating group membership.
strategy (Literal['mean','min','product','max_min']) – Strategy for combining logits:
- "mean": Average logits (consensus prediction, default)
- "min": Minimum logits (worst-case robust design)
- "product": Sum of logits (multiply probabilities)
- "max_min": Weighted combination of min and mean (alpha controls weight)
alpha (float) – Weight for min component when strategy="max_min" (0=pure mean, 1=pure min).
- Return type:
- Returns:
Combined logits of shape (1, 21).
Example
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import PrxteinMPNN
>>> logits = jnp.array([[10.0, -5.0], [8.0, -3.0]])
>>> group_mask = jnp.array([True, True])
>>> # Average strategy (compromise)
>>> avg = PrxteinMPNN._combine_logits_multistate(logits, group_mask, "mean")
>>> # Min strategy (robust to worst case)
>>> robust = PrxteinMPNN._combine_logits_multistate(logits, group_mask, "min")
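The four strategies can be written out in a few lines to show how they differ on the example logits. This is a sketch under stated assumptions rather than the library's implementation: "mean" is taken here as a plain arithmetic mean for simplicity (the library's averaging helper is described as using log-sum-exp), "max_min" is taken as alpha * min + (1 - alpha) * mean based on the alpha description, and the name combine_logits is hypothetical.

```python
import jax.numpy as jnp


def combine_logits(logits, group_mask, strategy="mean", alpha=0.5):
    """Illustrative sketch of the four combining strategies over a group."""
    member = group_mask[:, None]
    count = jnp.maximum(group_mask.sum(), 1)
    # Sum over members only; non-members contribute 0.
    total = jnp.where(member, logits, 0.0).sum(axis=0, keepdims=True)
    mean = total / count
    # Minimum over members only; non-members are set to +inf so they never win.
    minimum = jnp.where(member, logits, jnp.inf).min(axis=0, keepdims=True)
    if strategy == "mean":
        return mean
    if strategy == "min":
        return minimum
    if strategy == "product":
        # Summing logits corresponds to multiplying unnormalized probabilities.
        return total
    if strategy == "max_min":
        return alpha * minimum + (1.0 - alpha) * mean
    raise ValueError(f"unknown strategy: {strategy}")


logits = jnp.array([[10.0, -5.0], [8.0, -3.0]])
group_mask = jnp.array([True, True])
combined = {
    name: combine_logits(logits, group_mask, name)
    for name in ("mean", "min", "product", "max_min")
}
```

On these inputs the "min" strategy keeps the lower logit per class, so a residue type scores well only if every tied position supports it.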
- static _combine_logits_multistate_idx(logits, group_mask, strategy_idx, alpha=0.5)
Combine logits using strategy index (JAX-traceable version).
This is a JAX-traceable wrapper around _combine_logits_multistate that accepts an integer strategy index instead of a string. Used internally when the function needs to be JIT-compiled.
- Parameters:
logits (Float[Array, 'num_residues num_classes']) – Logits array of shape (N, 21).
group_mask (Array) – Boolean mask of shape (N,) indicating group membership.
strategy_idx (Int) – Integer strategy index (0=mean, 1=min, 2=product, 3=max_min).
alpha (float) – Weight for min component when strategy_idx=3 (0=pure mean, 1=pure min).
- Return type:
- Returns:
Combined logits of shape (1, 21).
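A string strategy cannot be branched on inside a JIT-compiled function, whereas an integer index can be dispatched with jax.lax.switch. The sketch below shows that pattern under assumptions: whether the library uses jax.lax.switch for this particular function is not stated here, "mean" is a plain arithmetic mean, and combine_by_index is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def combine_by_index(logits, group_mask, strategy_idx, alpha=0.5):
    """Illustrative sketch: pick a combining strategy with a traced integer."""
    member = group_mask[:, None]
    count = jnp.maximum(group_mask.sum(), 1)
    total = jnp.where(member, logits, 0.0).sum(axis=0, keepdims=True)
    mean = total / count
    minimum = jnp.where(member, logits, jnp.inf).min(axis=0, keepdims=True)
    # Branch order follows the documented indices: 0=mean, 1=min, 2=product, 3=max_min.
    branches = (
        lambda mean_, min_, total_: mean_,
        lambda mean_, min_, total_: min_,
        lambda mean_, min_, total_: total_,
        lambda mean_, min_, total_: alpha * min_ + (1.0 - alpha) * mean_,
    )
    # All branches return the same shape and dtype, as jax.lax.switch requires.
    return jax.lax.switch(strategy_idx, branches, mean, minimum, total)


logits = jnp.array([[10.0, -5.0], [8.0, -3.0]])
group_mask = jnp.array([True, True])
jitted_combine = jax.jit(combine_by_index)
results = [jitted_combine(logits, group_mask, jnp.array(i)) for i in range(4)]
```

Because the index is an ordinary array argument, one compiled function serves all four strategies without retracing.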
- _process_group_positions(group_mask, all_layers_h, s_embed, encoder_context, edge_features, neighbor_indices, mask, mask_bw)
Process all positions in a group through decoder and collect logits.
- Parameters:
group_mask (Array) – Boolean mask (N,) for positions in current group.
all_layers_h (Int[Array, 'num_atoms num_features']) – Hidden states (num_layers+1, N, C).
s_embed (Int[Array, 'num_atoms num_features']) – Sequence embeddings (N, C).
encoder_context (Array) – Precomputed encoder context (N, K, features).
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features (N, K, C).
neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Neighbor indices (N, K).
mask (Int[Array, 'num_residues 3']) – Alpha carbon mask (N,).
mask_bw (Array) – Backward mask (N, K).
- Return type:
- Returns:
Tuple of (updated all_layers_h, computed logits (N, 21)).
- _run_autoregressive_scan(prng_key, node_features, edge_features, neighbor_indices, mask, autoregressive_mask, temperature, bias, tie_group_map=None, multi_state_strategy_idx=0, multi_state_alpha=0.5)
Run JAX scan loop for autoregressive sampling with optional tied positions.
When tie_group_map is provided, the scan iterates over groups instead of individual positions. For each group:
1. Decoder processes all positions in the group
2. Logits are computed for all group members
3. Logits are averaged across the group (log-sum-exp)
4. A single token is sampled from the averaged logits
5. The token is broadcast to all positions in the group
- Parameters:
prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – PRNG key for sampling.
node_features (Int[Array, 'num_atoms num_features']) – Node features from encoder.
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features from encoder.
neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighbors for each node.
mask (Int[Array, 'num_residues 3']) – Alpha carbon mask.
autoregressive_mask (Bool[Array, 'num_residues num_residues']) – Mask defining decoding order.
temperature (Float) – Temperature for Gumbel-max sampling.
bias (Float[Array, 'num_residues num_classes']) – Bias to add to logits before sampling (N, 21).
tie_group_map (Array | None) – Optional (N,) array mapping each position to a group ID. When provided, positions in the same group are sampled together using combined logits.
multi_state_strategy_idx (Int) – Integer strategy index (0=mean, 1=min, 2=product, 3=max_min).
multi_state_alpha (float) – Weight for min component when strategy_idx=3.
- Return type:
tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]
- Returns:
Tuple of (sampled sequence, final logits).
- Raises:
None –
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import PrxteinMPNN
>>> key = jax.random.PRNGKey(0)
>>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key)
>>> node_feats = jnp.ones((10, 128))
>>> edge_feats = jnp.ones((10, 30, 128))
>>> neighbor_idx = jnp.arange(300).reshape(10, 30)
>>> mask = jnp.ones((10,))
>>> ar_mask = jnp.ones((10, 10))
>>> temp = jnp.array(1.0)
>>> bias = jnp.zeros((10, 21))
>>> seq, logits = model._run_autoregressive_scan(
...     key, node_feats, edge_feats, neighbor_idx, mask, ar_mask, temp, bias
... )
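The per-group loop described for tied positions can be sketched as a toy jax.lax.scan. To keep the example self-contained, the decoder is replaced by a fixed logits array, so steps 1 and 2 are stand-ins. The function name tied_scan, the num_groups argument, and the log-mean-exp averaging are assumptions for illustration, not the library's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def tied_scan(key, logits, tie_group_map, num_groups, temperature=1.0):
    """Toy sketch: sample one token per group and broadcast it to the group."""
    n, c = logits.shape

    def step(carry, group_id):
        key, sequence = carry
        key, subkey = jax.random.split(key)
        group_mask = tie_group_map == group_id
        # Step 3: average logits over the group (log-mean-exp, an assumed reading).
        masked = jnp.where(group_mask[:, None], logits, -jnp.inf)
        count = jnp.maximum(group_mask.sum(), 1)
        avg = logsumexp(masked, axis=0) - jnp.log(count)
        # Step 4: sample a single token with the Gumbel-max trick.
        token = jnp.argmax(avg / temperature + jax.random.gumbel(subkey, (c,)))
        # Step 5: write the same one-hot token to every position in the group.
        one_hot = jax.nn.one_hot(token, c)
        sequence = jnp.where(group_mask[:, None], one_hot, sequence)
        return (key, sequence), token

    init = (key, jnp.zeros((n, c)))
    (_, sequence), tokens = jax.lax.scan(step, init, jnp.arange(num_groups))
    return sequence, tokens


key = jax.random.PRNGKey(0)
logits = jax.random.normal(jax.random.PRNGKey(1), (4, 21))
tie_group_map = jnp.array([0, 1, 0, 1])
sequence, tokens = tied_scan(key, logits, tie_group_map, num_groups=2)
```

Positions 0 and 2 share a group here, so they end up with the same one-hot row, and likewise positions 1 and 3.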
- _run_tied_position_scan(prng_key, node_features, edge_features, neighbor_indices, mask, encoder_context, mask_bw, temperature, bias, tie_group_map, decoding_order, multi_state_strategy_idx=0, multi_state_alpha=0.5)
Run group-based autoregressive scan with logit combining.
- Parameters:
prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – PRNG key.
node_features (Int[Array, 'num_atoms num_features']) – Node features (N, C).
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features (N, K, C).
neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Neighbor indices (N, K).
mask (Int[Array, 'num_residues 3']) – Alpha carbon mask (N,).
encoder_context (Array) – Precomputed encoder context (N, K, features).
mask_bw (Array) – Backward mask (N, K).
temperature (Float) – Sampling temperature.
bias (Float[Array, 'num_residues num_classes']) – Bias array added to logits (N, 21).
tie_group_map (Array) – Group mapping (N,).
decoding_order (Array) – Position decoding order (N,).
multi_state_strategy_idx (Int) – Integer strategy index (0=mean, 1=min, 2=product, 3=max_min).
multi_state_alpha (float) – Weight for min component when strategy_idx=3.
- Return type:
tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]
- Returns:
Tuple of (final sequence, final logits).
- _sample_and_broadcast_to_group(avg_logits, group_mask, bias, temperature, key, all_logits, s_embed, sequence)
Sample once and broadcast token to all positions in a group.
- Parameters:
avg_logits (Array) – Averaged logits (1, 21).
group_mask (Array) – Boolean mask (N,) for group positions.
bias (Float[Array, 'num_residues num_classes']) – Bias array (N, 21).
temperature (Float) – Sampling temperature.
key (Union[Key[Array, ''], UInt32[Array, '2']]) – PRNG key.
all_logits (Float[Array, 'num_residues num_classes']) – Current logits array (N, 21).
s_embed (Int[Array, 'num_atoms num_features']) – Current sequence embeddings (N, C).
sequence (Float[Array, 'num_residues num_classes']) – Current sequence (N, 21).
- Return type:
tuple[Float[Array, 'num_residues num_classes'], Int[Array, 'num_atoms num_features'], Float[Array, 'num_residues num_classes']]
- Returns:
Tuple of (updated all_logits, updated s_embed, updated sequence).
- features: ProteinFeatures
- w_s_embed: Embedding
- w_out: Linear