model#

Model module for PrxteinMPNN.

This module contains the core Equinox-based neural network components for ProteinMPNN.

class prxteinmpnn.model.Decoder(node_features, edge_features, hidden_features, num_layers=3, *, key)[source]#

Bases: Module

The complete decoder module for ProteinMPNN.

Parameters:
  • node_features (int)

  • edge_features (int)

  • hidden_features (int)

  • num_layers (int)

  • key (PRNGKeyArray)

call_conditional(node_features, edge_features, neighbor_indices, mask, ar_mask, one_hot_sequence, w_s_weight)[source]#

Forward pass for CONDITIONAL decoding (scoring).

Parameters:
  • node_features (Int[Array, 'num_atoms num_features']) – Node features from encoder of shape (N, 128).

  • edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features from encoder of shape (N, K, 128).

  • neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighbors for each node.

  • mask (Int[Array, 'num_residues 3']) – Alpha carbon mask of shape (N,).

  • ar_mask (Bool[Array, 'num_residues num_residues']) – Autoregressive mask for conditional decoding.

  • one_hot_sequence (Float[Array, 'num_residues num_classes']) – One-hot encoded protein sequence.

  • w_s_weight (Array) – Sequence embedding weight matrix.

Return type:

Int[Array, 'num_atoms num_features']

Returns:

Decoded node features of shape (N, 128).

Raises:

None

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import Decoder
>>> key = jax.random.PRNGKey(0)
>>> decoder = Decoder(128, 128, 128, num_layers=3, key=key)
>>> node_feats = jnp.ones((10, 128))
>>> edge_feats = jnp.ones((10, 30, 128))
>>> # Wrap indices into [0, 10) so every neighbor refers to an existing node
>>> neighbor_indices = jnp.arange(300).reshape(10, 30) % 10
>>> mask = jnp.ones((10,))
>>> ar_mask = jnp.ones((10, 10))
>>> seq = jax.nn.one_hot(jnp.arange(10), 21)
>>> w_s = jnp.ones((21, 128))
>>> output = decoder.call_conditional(
...     node_feats, edge_feats, neighbor_indices, mask, ar_mask, seq, w_s
... )
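The ar_mask argument is documented only as an (N, N) autoregressive mask. The following is a minimal sketch of one common way to derive such a mask from a decoding order. It assumes the convention that entry (i, j) is True when residue j is decoded before residue i; the helper name make_ar_mask is hypothetical and the convention used by the library may differ.

```python
import jax
import jax.numpy as jnp


def make_ar_mask(decoding_order):
    """Hypothetical helper: mask[i, j] is True if position j is decoded before i."""
    # decoding_order[t] is the position decoded at step t, so its argsort
    # (the inverse permutation) gives rank[p] = step at which p is decoded.
    rank = jnp.argsort(decoding_order)
    return rank[None, :] < rank[:, None]


key = jax.random.PRNGKey(0)
decoding_order = jax.random.permutation(key, 5)  # random order over 5 positions
ar_mask = make_ar_mask(decoding_order)           # boolean array of shape (5, 5)
```

Under this convention the diagonal is always False, so a position never conditions on its own identity.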
layers: tuple[DecoderLayer, ...]#
node_features_dim: int#
edge_features_dim: int#
class prxteinmpnn.model.DecoderLayer(node_features, edge_context_features, hidden_features, *, key)[source]#

Bases: Module

A single decoder layer for the ProteinMPNN model.

Parameters:
  • node_features (int)

  • edge_context_features (int)

  • hidden_features (int)

  • key (PRNGKeyArray)

message_mlp: MLP#
norm1: LayerNorm#
dense: MLP#
norm2: LayerNorm#
class prxteinmpnn.model.Encoder(node_features, edge_features, hidden_features, num_layers=3, *, key)[source]#

Bases: Module

The complete encoder module for ProteinMPNN.

Parameters:
  • node_features (int)

  • edge_features (int)

  • hidden_features (int)

  • num_layers (int)

  • key (PRNGKeyArray)

layers: tuple[EncoderLayer, ...]#
node_feature_dim: int#
class prxteinmpnn.model.EncoderLayer(node_features, edge_features, hidden_features, *, key)[source]#

Bases: Module

A single encoder layer for the ProteinMPNN model.

Parameters:
  • node_features (int)

  • edge_features (int)

  • hidden_features (int)

  • key (PRNGKeyArray)

_get_mlp_input(h, e, neighbor_indices)[source]#

Return the input tensor [h_i, e_ij, h_j] for edge_message_mlp.

Return type:

Array

Parameters:
  • h (Int[Array, 'num_atoms num_features'])

  • e (Float[Array, 'num_atoms num_neighbors num_features'])

  • neighbor_indices (Int[Array, 'num_atoms num_neighbors'])
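The [h_i, e_ij, h_j] input can be illustrated with a gather followed by a concatenation. This is a sketch under the assumption that the three parts are joined along the last axis; build_mlp_input is a hypothetical stand-in, not the method itself.

```python
import jax.numpy as jnp


def build_mlp_input(h, e, neighbor_indices):
    """Hypothetical stand-in: concatenate [h_i, e_ij, h_j] along the last axis."""
    num_nodes, num_features = h.shape
    num_neighbors = neighbor_indices.shape[1]
    # h_j: features of each neighbor, gathered to shape (N, K, C)
    h_j = h[neighbor_indices]
    # h_i: each node's own features repeated once per neighbor, shape (N, K, C)
    h_i = jnp.broadcast_to(h[:, None, :], (num_nodes, num_neighbors, num_features))
    return jnp.concatenate([h_i, e, h_j], axis=-1)


h = jnp.arange(12.0).reshape(4, 3)   # N=4 nodes, C=3 node features
e = jnp.zeros((4, 2, 5))             # K=2 neighbors, 5 edge features
neighbor_indices = jnp.array([[1, 2], [0, 3], [3, 0], [2, 1]])
mlp_input = build_mlp_input(h, e, neighbor_indices)  # shape (4, 2, 3 + 5 + 3)
```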

edge_message_mlp: MLP#
norm1: LayerNorm#
dense: MLP#
norm2: LayerNorm#
edge_update_mlp: MLP#
norm3: LayerNorm#
node_features_dim: int#
edge_features_dim: int#
class prxteinmpnn.model.ProteinFeatures(node_features, edge_features, k_neighbors, *, key)[source]#

Bases: Module

Extracts and projects features from raw protein coordinates.

This module encapsulates k-NN, RBF, positional encodings, and edge projections. Note: the W_e projection is not part of this module; it lives in the main model (matching ColabDesign).

Parameters:
  • node_features (int)

  • edge_features (int)

  • k_neighbors (int)

  • key (PRNGKeyArray)
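To make the k-NN and RBF steps concrete, here is a minimal sketch of neighbor selection on alpha-carbon coordinates followed by a Gaussian radial basis expansion of the neighbor distances. The function name, the distance range (2 to 22), the number of basis functions (16), and the inclusion of each residue as its own nearest neighbor are all assumptions made for illustration, not values taken from this module.

```python
import jax
import jax.numpy as jnp


def knn_and_rbf(ca_coords, k_neighbors, rbf_dim=16, d_min=2.0, d_max=22.0):
    """Hypothetical sketch: k nearest neighbors plus Gaussian RBF distance features."""
    diff = ca_coords[:, None, :] - ca_coords[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-6)            # (N, N)
    # top_k returns the largest values, so negate to pick the smallest distances
    neg_dist, neighbor_indices = jax.lax.top_k(-dist, k_neighbors)
    neighbor_dist = -neg_dist                                     # (N, K)
    centers = jnp.linspace(d_min, d_max, rbf_dim)                 # (rbf_dim,)
    sigma = (d_max - d_min) / rbf_dim
    rbf = jnp.exp(-(((neighbor_dist[..., None] - centers) / sigma) ** 2))
    return neighbor_indices, rbf                                  # (N, K), (N, K, rbf_dim)


coords = jnp.array(
    [[0.0, 0.0, 0.0], [3.8, 0.0, 0.0], [7.6, 0.0, 0.0], [11.4, 0.0, 0.0]]
)
neighbor_indices, rbf = knn_and_rbf(coords, k_neighbors=2)
```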

w_pos: Linear#
w_e: Linear#
norm_edges: LayerNorm#
w_e_proj: Linear#
k_neighbors: int#
rbf_dim: int#
pos_embed_dim: int#
class prxteinmpnn.model.PrxteinMPNN(node_features, edge_features, hidden_features, num_encoder_layers, num_decoder_layers, k_neighbors, num_amino_acids=21, vocab_size=21, *, key)[source]#

Bases: Module

The complete end-to-end ProteinMPNN model.

Parameters:
  • node_features (int)

  • edge_features (int)

  • hidden_features (int)

  • num_encoder_layers (int)

  • num_decoder_layers (int)

  • k_neighbors (int)

  • num_amino_acids (int)

  • vocab_size (int)

  • key (PRNGKeyArray)

static _average_logits_over_group(logits, group_mask)[source]#

Average logits across positions in a tie group using log-sum-exp.

This implements numerically stable logit averaging for tied positions. Given logits of shape (N, 21) and a boolean mask indicating which positions belong to the current group, returns averaged logits of shape (1, 21).
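One plausible reading of log-sum-exp averaging is the log of the mean of exp(logits) over the group members. The sketch below follows that reading; it is an assumption about the intended formula, and average_logits_logsumexp is a hypothetical name rather than the static method itself.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def average_logits_logsumexp(logits, group_mask):
    """Hypothetical sketch: log of the mean of exp(logits) over group members."""
    # Send non-members to -inf so they contribute exp(-inf) = 0 to the sum
    masked = jnp.where(group_mask[:, None], logits, -jnp.inf)
    count = jnp.sum(group_mask)
    # log(mean(exp(x))) = logsumexp(x) - log(count)
    return logsumexp(masked, axis=0, keepdims=True) - jnp.log(count)


logits = jnp.array([[0.1, 0.9], [0.3, 0.7], [5.0, -5.0]])
group_mask = jnp.array([True, True, False])
avg_logits = average_logits_logsumexp(logits, group_mask)  # shape (1, 2)
```

The sketch assumes the group has at least one member; an empty mask would produce NaN.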

Parameters:
  • logits (Float[Array, 'num_residues num_classes']) – Logits array of shape (N, 21).

  • group_mask (Array) – Boolean mask of shape (N,) indicating group membership.

Return type:

Array

Returns:

Averaged logits of shape (1, 21).

Raises:

None

Example

>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import PrxteinMPNN
>>> logits = jnp.array([[0.1, 0.9], [0.3, 0.7]])
>>> group_mask = jnp.array([True, True])
>>> avg_logits = PrxteinMPNN._average_logits_over_group(logits, group_mask)
_call_autoregressive(edge_features, neighbor_indices, mask, ar_mask, _one_hot_sequence, prng_key, temperature, bias, tie_group_map, multi_state_strategy_idx, multi_state_alpha=0.5)[source]#

Run the autoregressive (sampling) path.

Parameters:
  • edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features from feature extraction.

  • neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighbors for each node.

  • mask (Int[Array, 'num_residues 3']) – Alpha carbon mask.

  • ar_mask (Bool[Array, 'num_residues num_residues']) – Autoregressive mask for sampling.

  • _one_hot_sequence (Float[Array, 'num_residues num_classes']) – Unused, required for jax.lax.switch signature.

  • prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – PRNG key for sampling.

  • temperature (Float) – Temperature for Gumbel-max sampling.

  • bias (Float[Array, 'num_residues num_classes']) – Bias to add to logits before sampling (N, 21).

  • tie_group_map (Array | None) – Optional (N,) array mapping each position to a group ID. When provided, positions in the same group sample identical amino acids.

  • multi_state_strategy_idx (Int) – Integer index for strategy (0=mean, 1=min, 2=product, 3=max_min).

  • multi_state_alpha (float) – Weight for min component when strategy="max_min".

Return type:

tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]

Returns:

Tuple of (sampled sequence, logits).

Raises:

None

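Example

>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import PrxteinMPNN
>>> key = jax.random.PRNGKey(0)
>>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key)
>>> edge_feats = jnp.ones((10, 30, 128))
>>> # Wrap indices into [0, 10) so every neighbor refers to an existing node
>>> neighbor_idx = jnp.arange(300).reshape(10, 30) % 10
>>> mask = jnp.ones((10,))
>>> ar_mask = jnp.ones((10, 10))
>>> temp = jnp.array(1.0)
>>> bias = jnp.zeros((10, 21))
>>> # The final 0 selects the "mean" strategy (multi_state_strategy_idx)
>>> seq, logits = model._call_autoregressive(
...     edge_feats, neighbor_idx, mask, ar_mask, None, key, temp, bias, None, 0
... )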
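The temperature and bias parameters refer to Gumbel-max sampling. As a minimal sketch of that technique, the function below adds the bias, divides by the temperature, perturbs the result with Gumbel noise, and takes the argmax. The function name and the exact order in which bias and temperature are applied are assumptions; the model's internals may differ.

```python
import jax
import jax.numpy as jnp


def gumbel_max_sample(key, logits, bias, temperature):
    """Hypothetical sketch: draw one class per row via the Gumbel-max trick."""
    scaled = (logits + bias) / temperature
    gumbel_noise = jax.random.gumbel(key, scaled.shape)
    # The argmax of (scaled logits + Gumbel noise) is a sample from softmax(scaled)
    index = jnp.argmax(scaled + gumbel_noise, axis=-1)
    return jax.nn.one_hot(index, logits.shape[-1])


key = jax.random.PRNGKey(0)
logits = jnp.zeros((10, 21))
# A very large bias on class 3 effectively forces that class at every position
bias = jnp.zeros((10, 21)).at[:, 3].set(1e4)
sample = gumbel_max_sample(key, logits, bias, temperature=1.0)  # shape (10, 21)
```

Lower temperatures sharpen the distribution toward the highest-scoring class, while higher temperatures flatten it.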
_call_conditional(edge_features, neighbor_indices, mask, _ar_mask, one_hot_sequence, _prng_key, _temperature, _bias, _tie_group_map, _multi_state_strategy_idx, _multi_state_alpha)[source]#

Run the conditional (scoring) path.

Parameters:
  • edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features from feature extraction.

  • neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighbors for each node.

  • mask (Int[Array, 'num_residues 3']) – Alpha carbon mask.

  • _ar_mask (Bool[Array, 'num_residues num_residues']) – Autoregressive mask for conditional decoding.

  • one_hot_sequence (Float[Array, 'num_residues num_classes']) – One-hot encoded protein sequence.

  • _prng_key (Key[Array, ''] | UInt32[Array, '2']) – Unused, required for jax.lax.switch signature.

  • _temperature (Float) – Unused, required for jax.lax.switch signature.

  • _bias (Float[Array, 'num_residues num_classes']) – Unused, required for jax.lax.switch signature.

  • _tie_group_map (Array | None) – Unused, required for jax.lax.switch signature.

  • _multi_state_strategy_idx (Int) – Unused, required for jax.lax.switch signature.

  • _multi_state_alpha (float) – Unused, required for jax.lax.switch signature.

Return type:

tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]

Returns:

Tuple of (input sequence, logits).

Raises:

None

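Example

>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import PrxteinMPNN
>>> key = jax.random.PRNGKey(0)
>>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key)
>>> edge_feats = jnp.ones((10, 30, 128))
>>> # Wrap indices into [0, 10) so every neighbor refers to an existing node
>>> neighbor_idx = jnp.arange(300).reshape(10, 30) % 10
>>> mask = jnp.ones((10,))
>>> ar_mask = jnp.ones((10, 10))
>>> seq = jax.nn.one_hot(jnp.arange(10), 21)
>>> # The six trailing arguments are unused on this path; None fills each slot
>>> out_seq, logits = model._call_conditional(
...     edge_feats, neighbor_idx, mask, ar_mask, seq, None, None, None, None, None, None
... )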
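Several parameters above are documented as unused but required for the jax.lax.switch signature. The reason is that jax.lax.switch passes the same operands to every branch and requires all branches to return outputs of matching structure. The sketch below shows the pattern with three toy branches; the branch names, the reduced argument list, and the index order are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def unconditional_branch(logits, _one_hot_sequence, _temperature):
    # Uses neither the sequence nor the temperature; returns a dummy sequence
    return jnp.zeros_like(logits), logits


def conditional_branch(logits, one_hot_sequence, _temperature):
    # Scoring path: passes the given sequence through and ignores the temperature
    return one_hot_sequence, logits


def tempered_branch(logits, _one_hot_sequence, temperature):
    # Stand-in for a sampling path: the only branch that reads the temperature
    scaled = logits / temperature
    return jax.nn.one_hot(jnp.argmax(scaled, axis=-1), logits.shape[-1]), scaled


@jax.jit
def dispatch(mode_index, logits, one_hot_sequence, temperature):
    branches = (unconditional_branch, conditional_branch, tempered_branch)
    # Every branch receives the same three operands, used or not
    return jax.lax.switch(mode_index, branches, logits, one_hot_sequence, temperature)


logits = jnp.array([[0.2, 1.5, -0.3], [2.0, 0.1, 0.4]])
seq = jax.nn.one_hot(jnp.array([0, 2]), 3)
out_seq, out_logits = dispatch(1, logits, seq, 1.0)  # selects conditional_branch
```

Prefixing unused parameters with an underscore, as the methods documented here do, keeps the shared signature explicit while signalling which inputs a given path ignores.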
_call_unconditional(edge_features, neighbor_indices, mask, _ar_mask, _one_hot_sequence, _prng_key, _temperature, _bias, _tie_group_map, _multi_state_strategy_idx, _multi_state_alpha)[source]#

Run the unconditional (scoring) path.

Parameters:
  • edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features from feature extraction.

  • neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighbors for each node.

  • mask (Int[Array, 'num_residues 3']) – Alpha carbon mask.

  • _ar_mask (Bool[Array, 'num_residues num_residues']) – Unused, required for jax.lax.switch signature.

  • _one_hot_sequence (Float[Array, 'num_residues num_classes']) – Unused, required for jax.lax.switch signature.

  • _prng_key (Key[Array, ''] | UInt32[Array, '2']) – Unused, required for jax.lax.switch signature.

  • _temperature (Float) – Unused, required for jax.lax.switch signature.

  • _bias (Float[Array, 'num_residues num_classes']) – Unused, required for jax.lax.switch signature.

  • _tie_group_map (Array | None) – Unused, required for jax.lax.switch signature.

  • _multi_state_strategy_idx (Int) – Unused, required for jax.lax.switch signature.

  • _multi_state_alpha (float) – Unused, required for jax.lax.switch signature.

Return type:

tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]

Returns:

Tuple of (dummy sequence, logits).

Raises:

None

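Example

>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import PrxteinMPNN
>>> key = jax.random.PRNGKey(0)
>>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key)
>>> edge_feats = jnp.ones((10, 30, 128))
>>> neighbor_idx = jnp.arange(300).reshape(10, 30) % 10
>>> mask = jnp.ones((10,))
>>> # The eight trailing arguments are unused on this path; None fills each slot
>>> seq, logits = model._call_unconditional(
...     edge_feats, neighbor_idx, mask, None, None, None, None, None, None, None, None
... )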
static _combine_logits_multistate(logits, group_mask, strategy='mean', alpha=0.5)[source]#

Combine logits across tied positions using different multi-state strategies.

Parameters:
  • logits (Float[Array, 'num_residues num_classes']) – Logits array of shape (N, 21).

  • group_mask (Array) – Boolean mask of shape (N,) indicating group membership.

  • strategy (Literal['mean', 'min', 'product', 'max_min']) – Strategy for combining logits:

      • "mean": Average logits (consensus prediction, default)
      • "min": Minimum logits (worst-case robust design)
      • "product": Sum of logits (multiply probabilities)
      • "max_min": Weighted combination of min and mean (alpha controls weight)

  • alpha (float) – Weight for min component when strategy="max_min" (0=pure mean, 1=pure min).

Return type:

Array

Returns:

Combined logits of shape (1, 21).
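The four strategies can be sketched with masked reductions over the group axis. This is an illustration only: combine_logits is a hypothetical function, and its "mean" is a plain arithmetic mean of the logits, which may differ from the log-sum-exp averaging described for _average_logits_over_group.

```python
import jax.numpy as jnp


def combine_logits(logits, group_mask, strategy="mean", alpha=0.5):
    """Hypothetical sketch of the four strategies using arithmetic reductions."""
    member = group_mask[:, None]
    count = jnp.sum(group_mask)
    total = jnp.sum(jnp.where(member, logits, 0.0), axis=0, keepdims=True)
    mean = total / count
    # Non-members are set to +inf so they can never be the minimum
    minimum = jnp.min(jnp.where(member, logits, jnp.inf), axis=0, keepdims=True)
    if strategy == "mean":
        return mean
    if strategy == "min":
        return minimum
    if strategy == "product":
        # Multiplying probabilities corresponds to summing logits
        return total
    if strategy == "max_min":
        # alpha=0 gives the pure mean, alpha=1 the pure min
        return alpha * minimum + (1.0 - alpha) * mean
    raise ValueError(f"unknown strategy: {strategy}")


logits = jnp.array([[10.0, -5.0], [8.0, -3.0]])
group_mask = jnp.array([True, True])
blended = combine_logits(logits, group_mask, "max_min", alpha=0.5)  # shape (1, 2)
```

Because the strategy is selected with Python string comparisons, this form cannot take a traced strategy value inside a JIT-compiled function.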

Example

>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import PrxteinMPNN
>>> logits = jnp.array([[10.0, -5.0], [8.0, -3.0]])
>>> group_mask = jnp.array([True, True])
>>> # Average strategy (compromise)
>>> avg = PrxteinMPNN._combine_logits_multistate(logits, group_mask, "mean")
>>> # Min strategy (robust to worst case)
>>> robust = PrxteinMPNN._combine_logits_multistate(logits, group_mask, "min")
static _combine_logits_multistate_idx(logits, group_mask, strategy_idx, alpha=0.5)[source]#

Combine logits using strategy index (JAX-traceable version).

This is a JAX-traceable wrapper around _combine_logits_multistate that accepts an integer strategy index instead of a string. Used internally when the function needs to be JIT-compiled.

Parameters:
  • logits (Float[Array, 'num_residues num_classes']) – Logits array of shape (N, 21).

  • group_mask (Array) – Boolean mask of shape (N,) indicating group membership.

  • strategy_idx (Int) – Integer strategy index (0=mean, 1=min, 2=product, 3=max_min).

  • alpha (float) – Weight for min component when strategy_idx=3 (0=pure mean, 1=pure min).

Return type:

Array

Returns:

Combined logits of shape (1, 21).
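A string cannot be passed as a traced value, which is why an integer index is used under JIT. The sketch below shows one way to map names to the documented indices (0=mean, 1=min, 2=product, 3=max_min) and select a reduction with jax.lax.switch. The helper names are hypothetical, the group mask is omitted for brevity (every row is treated as a group member), and the reductions are plain arithmetic ones.

```python
import jax
import jax.numpy as jnp

# Index order taken from the parameter description: 0=mean, 1=min, 2=product, 3=max_min
STRATEGY_NAMES = ("mean", "min", "product", "max_min")


def _mean(logits, alpha):
    return jnp.mean(logits, axis=0, keepdims=True)


def _min(logits, alpha):
    return jnp.min(logits, axis=0, keepdims=True)


def _product(logits, alpha):
    return jnp.sum(logits, axis=0, keepdims=True)


def _max_min(logits, alpha):
    return alpha * _min(logits, alpha) + (1.0 - alpha) * _mean(logits, alpha)


@jax.jit
def combine_by_index(logits, strategy_idx, alpha):
    # strategy_idx is a traced integer here, so the choice happens inside the
    # compiled computation instead of in Python control flow
    branches = (_mean, _min, _product, _max_min)
    return jax.lax.switch(strategy_idx, branches, logits, alpha)


logits = jnp.array([[10.0, -5.0], [8.0, -3.0]])
idx = STRATEGY_NAMES.index("max_min")        # resolve the string outside of JIT
combined = combine_by_index(logits, idx, 0.5)
```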

_process_group_positions(group_mask, all_layers_h, s_embed, encoder_context, edge_features, neighbor_indices, mask, mask_bw)[source]#

Process all positions in a group through decoder and collect logits.

Parameters:
  • group_mask (Array) – Boolean mask (N,) for positions in current group.

  • all_layers_h (Int[Array, 'num_atoms num_features']) – Hidden states (num_layers+1, N, C).

  • s_embed (Int[Array, 'num_atoms num_features']) – Sequence embeddings (N, C).

  • encoder_context (Array) – Precomputed encoder context (N, K, features).

  • edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features (N, K, C).

  • neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Neighbor indices (N, K).

  • mask (Int[Array, 'num_residues 3']) – Alpha carbon mask (N,).

  • mask_bw (Array) – Backward mask (N, K).

Return type:

tuple[Int[Array, 'num_atoms num_features'], Array]

Returns:

Tuple of (updated all_layers_h, computed logits (N, 21)).

_run_autoregressive_scan(prng_key, node_features, edge_features, neighbor_indices, mask, autoregressive_mask, temperature, bias, tie_group_map=None, multi_state_strategy_idx=0, multi_state_alpha=0.5)[source]#

Run JAX scan loop for autoregressive sampling with optional tied positions.

When tie_group_map is provided, the scan iterates over groups instead of individual positions. For each group:

  1. Decoder processes all positions in the group
  2. Logits are computed for all group members
  3. Logits are averaged across the group (log-sum-exp)
  4. A single token is sampled from the averaged logits
  5. The token is broadcast to all positions in the group

Parameters:
  • prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – PRNG key for sampling.

  • node_features (Int[Array, 'num_atoms num_features']) – Node features from encoder.

  • edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features from encoder.

  • neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighbors for each node.

  • mask (Int[Array, 'num_residues 3']) – Alpha carbon mask.

  • autoregressive_mask (Bool[Array, 'num_residues num_residues']) – Mask defining decoding order.

  • temperature (Float) – Temperature for Gumbel-max sampling.

  • bias (Float[Array, 'num_residues num_classes']) – Bias to add to logits before sampling (N, 21).

  • tie_group_map (Array | None) – Optional (N,) array mapping each position to a group ID. When provided, positions in the same group are sampled together using combined logits.

  • multi_state_strategy_idx (Int) – Integer strategy index (0=mean, 1=min, 2=product, 3=max_min).

  • multi_state_alpha (float) – Weight for min component when strategy_idx=3.

Return type:

tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]

Returns:

Tuple of (sampled sequence, final logits).

Raises:

None

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model import PrxteinMPNN
>>> key = jax.random.PRNGKey(0)
>>> model = PrxteinMPNN(128, 128, 128, 3, 3, 30, key=key)
>>> node_feats = jnp.ones((10, 128))
>>> edge_feats = jnp.ones((10, 30, 128))
>>> # Wrap indices into [0, 10) so every neighbor refers to an existing node
>>> neighbor_idx = jnp.arange(300).reshape(10, 30) % 10
>>> mask = jnp.ones((10,))
>>> ar_mask = jnp.ones((10, 10))
>>> temp = jnp.array(1.0)
>>> bias = jnp.zeros((10, 21))
>>> seq, logits = model._run_autoregressive_scan(
...     key, node_feats, edge_feats, neighbor_idx, mask, ar_mask, temp, bias
... )
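The group-wise procedure listed for tie_group_map can be sketched with jax.lax.scan over group IDs. This is a simplified illustration under stated assumptions: the decoder step is replaced by precomputed per-position logits (in the real model the logits depend on previously sampled tokens), every group ID in range(num_groups) has at least one member, and sample_tied is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sample_tied(key, logits, tie_group_map, num_groups, temperature=1.0):
    """Hypothetical sketch: one sampled token per tie group, shared by its members."""
    num_residues, num_classes = logits.shape

    def step(sequence, inputs):
        group_id, step_key = inputs
        group_mask = tie_group_map == group_id                     # (N,)
        masked = jnp.where(group_mask[:, None], logits, -jnp.inf)
        # Steps 2-3: log-sum-exp average of the group's logits, shape (C,)
        avg = logsumexp(masked, axis=0) - jnp.log(jnp.sum(group_mask))
        # Step 4: a single Gumbel-max draw for the whole group
        noise = jax.random.gumbel(step_key, (num_classes,))
        token = jax.nn.one_hot(jnp.argmax(avg / temperature + noise), num_classes)
        # Step 5: write the token to every member and leave other rows untouched
        sequence = jnp.where(group_mask[:, None], token[None, :], sequence)
        return sequence, None

    keys = jax.random.split(key, num_groups)
    init = jnp.zeros((num_residues, num_classes))
    sequence, _ = jax.lax.scan(step, init, (jnp.arange(num_groups), keys))
    return sequence


key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (6, 21))
tie_group_map = jnp.array([0, 1, 0, 2, 1, 2])
tied_sequence = sample_tied(key, logits, tie_group_map, num_groups=3)  # (6, 21)
```

Positions that share a group ID end up with identical one-hot rows, which is the property the tied sampling path is documented to guarantee.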
_run_tied_position_scan(prng_key, node_features, edge_features, neighbor_indices, mask, encoder_context, mask_bw, temperature, bias, tie_group_map, decoding_order, multi_state_strategy_idx=0, multi_state_alpha=0.5)[source]#

Run group-based autoregressive scan with logit combining.

Parameters:
  • prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – PRNG key.

  • node_features (Int[Array, 'num_atoms num_features']) – Node features (N, C).

  • edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features (N, K, C).

  • neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Neighbor indices (N, K).

  • mask (Int[Array, 'num_residues 3']) – Alpha carbon mask (N,).

  • encoder_context (Array) – Precomputed encoder context (N, K, features).

  • mask_bw (Array) – Backward mask (N, K).

  • temperature (Float) – Sampling temperature.

  • bias (Float[Array, 'num_residues num_classes']) – Bias to add to logits before sampling (N, 21).

  • tie_group_map (Array) – Group mapping (N,).

  • decoding_order (Array) – Position decoding order (N,).

  • multi_state_strategy_idx (Int) – Integer strategy index (0=mean, 1=min, 2=product, 3=max_min).

  • multi_state_alpha (float) – Weight for min component when strategy_idx=3.

Return type:

tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]

Returns:

Tuple of (final sequence, final logits).

_sample_and_broadcast_to_group(avg_logits, group_mask, bias, temperature, key, all_logits, s_embed, sequence)[source]#

Sample once and broadcast token to all positions in a group.

Parameters:
  • avg_logits (Array) – Averaged logits (1, 21).

  • group_mask (Array) – Boolean mask (N,) for group positions.

  • bias (Float[Array, 'num_residues num_classes']) – Bias array (N, 21).

  • temperature (Float) – Sampling temperature.

  • key (Union[Key[Array, ''], UInt32[Array, '2']]) – PRNG key.

  • all_logits (Float[Array, 'num_residues num_classes']) – Current logits array (N, 21).

  • s_embed (Int[Array, 'num_atoms num_features']) – Current sequence embeddings (N, C).

  • sequence (Float[Array, 'num_residues num_classes']) – Current sequence (N, 21).

Return type:

tuple[Float[Array, 'num_residues num_classes'], Int[Array, 'num_atoms num_features'], Float[Array, 'num_residues num_classes']]

Returns:

Tuple of (updated all_logits, updated s_embed, updated sequence).

features: ProteinFeatures#
encoder: Encoder#
decoder: Decoder#
w_s_embed: Embedding#
w_out: Linear#
node_features_dim: int#
edge_features_dim: int#
num_decoder_layers: int#