decoder

Decoder module for PrxteinMPNN.

This module contains the Equinox-based decoder implementation for ProteinMPNN.

class prxteinmpnn.model.decoder.DecoderLayer(node_features, edge_context_features, hidden_features, *, key)

Bases: Module

A single decoder layer for the ProteinMPNN model.

Parameters:
  • node_features (int)

  • edge_context_features (int)

  • hidden_features (int)

  • key (PRNGKeyArray)

message_mlp: MLP
norm1: LayerNorm
dense: MLP
norm2: LayerNorm
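The four attributes match the layer structure of the original ProteinMPNN decoder layer: a message MLP applied to each node's neighbor context, a residual update normalized by `norm1`, a position-wise feed-forward block (`dense`), and a second residual update normalized by `norm2`. The sketch below illustrates that pattern in plain NumPy, assuming PrxteinMPNN follows it. The function names, the GELU activation, the `scale=30.0` message normalization, and the parameter layout are illustrative assumptions, not taken from PrxteinMPNN's source.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over the feature axis (no learned scale/offset, for brevity).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def gelu(x):
    # Tanh approximation of GELU.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def mlp(x, weights):
    # Bias-free dense layers with GELU between them (none after the last one).
    for i, w in enumerate(weights):
        x = x @ w
        if i < len(weights) - 1:
            x = gelu(x)
    return x


def decoder_layer_sketch(h_v, edge_context, attend_mask, params, scale=30.0):
    """Hypothetical ProteinMPNN-style decoder update.

    h_v:          (N, C) node features
    edge_context: (N, K, C_ctx) per-neighbor context
    attend_mask:  (N, K) 1.0 where a neighbor message is used, else 0.0
    """
    n, k, c = edge_context.shape[0], edge_context.shape[1], h_v.shape[-1]
    # Repeat each node's own features across its K neighbors and concatenate.
    h_self = np.broadcast_to(h_v[:, None, :], (n, k, c))
    h_in = np.concatenate([h_self, edge_context], axis=-1)
    # Per-neighbor messages, masked, summed over neighbors, then scaled.
    messages = mlp(h_in, params["message"]) * attend_mask[..., None]
    h_v = layer_norm(h_v + messages.sum(axis=1) / scale)   # role of norm1
    h_v = layer_norm(h_v + mlp(h_v, params["dense"]))      # role of norm2
    return h_v


# Usage with small random (hypothetical) weights.
rng = np.random.default_rng(0)
N, K, C, C_CTX, H = 5, 4, 8, 16, 8
params = {
    "message": [rng.normal(size=(C + C_CTX, H)), rng.normal(size=(H, H)),
                rng.normal(size=(H, C))],
    "dense": [rng.normal(size=(C, 4 * C)), rng.normal(size=(4 * C, C))],
}
h_v = rng.normal(size=(N, C))
ctx = rng.normal(size=(N, K, C_CTX))
out = decoder_layer_sketch(h_v, ctx, np.ones((N, K)), params)  # shape (5, 8)
```

With an all-zero `attend_mask`, no neighbor messages reach the node and only the feed-forward branch changes the features.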
class prxteinmpnn.model.decoder.Decoder(node_features, edge_features, hidden_features, num_layers=3, *, key)

Bases: Module

The complete decoder module for ProteinMPNN.

Parameters:
  • node_features (int)

  • edge_features (int)

  • hidden_features (int)

  • num_layers (int)

  • key (PRNGKeyArray)

node_features_dim: int
edge_features_dim: int
layers: tuple[DecoderLayer, ...]
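`layers` holds `num_layers` decoder layers. In the original ProteinMPNN, the decoder rebuilds the neighbor context from the current node features before each layer, so later layers see updated neighbor features; PrxteinMPNN presumably threads features through `layers` in a similar way. The sketch below shows that loop with toy stand-in layers. `apply_layers`, `build_context`, and `toy_layer` are hypothetical names, not part of the PrxteinMPNN API.

```python
from typing import Callable, Sequence

import numpy as np

# A layer maps (node features (N, C), neighbor context (N, K, C_ctx)) -> (N, C).
Layer = Callable[[np.ndarray, np.ndarray], np.ndarray]


def apply_layers(h_v: np.ndarray, layers: Sequence[Layer],
                 build_context: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
    """Apply the layers in order, rebuilding the neighbor context before each one."""
    for layer in layers:
        h_v = layer(h_v, build_context(h_v))
    return h_v


# Toy stand-ins: the context is the neighbors' node features,
# and each "layer" adds the mean over its neighbors.
neighbor_indices = np.array([[1, 2], [0, 2], [0, 1]])  # (N=3, K=2)


def build_context(h: np.ndarray) -> np.ndarray:
    return h[neighbor_indices]  # (N, K, C)


def toy_layer(h: np.ndarray, ctx: np.ndarray) -> np.ndarray:
    return h + ctx.mean(axis=1)


h0 = np.array([[1.0], [2.0], [3.0]])
out = apply_layers(h0, (toy_layer,) * 3, build_context)
# out.ravel() -> values 15.875, 16.0, 16.125
```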
call_conditional(node_features, edge_features, neighbor_indices, mask, ar_mask, one_hot_sequence, w_s_weight)

Forward pass for conditional decoding, used to score a given sequence.
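In the original ProteinMPNN, conditional decoding lets each residue see the sequence identities only of neighbors decoded before it; for the remaining neighbors it falls back to sequence-free encoder context. A minimal NumPy sketch of that mixing step follows, assuming `ar_mask[i, j]` is true when residue `j` precedes residue `i` in the decoding order. The helper `mix_context` and the feature layout are hypothetical.

```python
import numpy as np


def mix_context(ar_mask, neighbor_indices, mask, ctx_with_seq, ctx_encoder_only):
    """Per-neighbor choice between sequence-aware and encoder-only context.

    ar_mask:          (N, N) bool, True if residue j is decoded before residue i
    neighbor_indices: (N, K) int
    mask:             (N,) 1.0 for valid residues
    ctx_*:            (N, K, C) context features
    """
    # Look up ar_mask[i, neighbor_indices[i, k]] for every node i and neighbor k.
    backward = np.take_along_axis(ar_mask, neighbor_indices, axis=1).astype(float)
    mask_bw = mask[:, None] * backward          # neighbors already decoded
    mask_fw = mask[:, None] * (1.0 - backward)  # neighbors not yet decoded
    return mask_bw[..., None] * ctx_with_seq + mask_fw[..., None] * ctx_encoder_only


# Usage: decoding order 0, 1, 2, so ar_mask is strictly lower triangular.
ar_mask = np.tril(np.ones((3, 3), dtype=bool), k=-1)
neighbors = np.array([[1, 2], [0, 2], [0, 1]])
mixed = mix_context(ar_mask, neighbors, np.ones(3),
                    np.ones((3, 2, 1)), np.zeros((3, 2, 1)))
# mixed[..., 0] is 1.0 exactly where the neighbor was decoded earlier.
```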

Parameters:
  • node_features (Float[Array, 'num_atoms num_features']) – Node features from the encoder, of shape (N, 128).

  • edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features from encoder of shape (N, K, 128).

  • neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighbors for each node.

  • mask (Int[Array, 'num_residues']) – Alpha-carbon mask of shape (N,).

  • ar_mask (Bool[Array, 'num_residues num_residues']) – Autoregressive mask of shape (N, N) for conditional decoding.

  • one_hot_sequence (Float[Array, 'num_residues num_classes']) – One-hot encoded protein sequence.

  • w_s_weight (Array) – Sequence embedding weight matrix.
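`one_hot_sequence` and `w_s_weight` suggest a bias-free sequence embedding: multiplying the (N, 21) one-hot matrix by a (21, C) weight matrix selects one weight row per residue, which is equivalent to indexing the weights by residue type. Per-neighbor features can then be gathered with `neighbor_indices`. This NumPy sketch shows both steps; it assumes the (num_classes, C) weight layout used in the example, and `embed_sequence` and `gather_neighbors` are hypothetical helper names. jax.numpy supports the same matrix product and integer-array indexing.

```python
import numpy as np


def embed_sequence(one_hot_sequence, w_s_weight):
    # (N, num_classes) @ (num_classes, C) -> (N, C); each row picks one weight row.
    return one_hot_sequence @ w_s_weight


def gather_neighbors(features, neighbor_indices):
    # features: (N, C), neighbor_indices: (N, K) -> (N, K, C)
    return features[neighbor_indices]


rng = np.random.default_rng(0)
residue_types = np.array([0, 3, 3, 20])              # N = 4, 21 classes
one_hot = np.eye(21)[residue_types]                  # (4, 21)
w_s = rng.normal(size=(21, 8))                       # (num_classes, C)
h_s = embed_sequence(one_hot, w_s)                   # (4, 8)
neighbor_indices = np.array([[1, 2], [0, 3], [1, 3], [0, 2]])
h_s_neighbors = gather_neighbors(h_s, neighbor_indices)  # (4, 2, 8)
```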

Return type:

Float[Array, 'num_atoms num_features']

Returns:

Decoded node features of shape (N, 128).
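The method returns features rather than probabilities. In the original ProteinMPNN, scoring projects decoded features to per-residue logits with a learned output layer, applies a log-softmax, and averages the negative log-likelihood of the observed residues over valid positions. A NumPy sketch under that assumption follows; `w_out` is a hypothetical projection standing in for whatever output head PrxteinMPNN uses, and `sequence_score` is not a library function.

```python
import numpy as np


def log_softmax(x):
    # Numerically stable log-softmax over the class axis.
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def sequence_score(decoded, w_out, one_hot_sequence, mask):
    """Mean negative log-likelihood of the given sequence (lower is better)."""
    log_probs = log_softmax(decoded @ w_out)              # (N, num_classes)
    nll = -(one_hot_sequence * log_probs).sum(axis=-1)    # (N,)
    return (nll * mask).sum() / mask.sum()


decoded = np.zeros((4, 8))  # stand-in for call_conditional output
score = sequence_score(decoded, np.ones((8, 21)), np.eye(21)[[0, 5, 5, 20]], np.ones(4))
# All-zero logits give a uniform distribution, so score == log(21) ≈ 3.0445.
```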

Raises:

None

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model.decoder import Decoder
>>> key = jax.random.PRNGKey(0)
>>> decoder = Decoder(128, 128, 128, num_layers=3, key=key)
>>> node_feats = jnp.ones((10, 128))
>>> edge_feats = jnp.ones((10, 30, 128))
>>> neighbor_indices = jnp.arange(300).reshape(10, 30) % 10  # valid node indices 0..9
>>> mask = jnp.ones((10,))
>>> ar_mask = jnp.ones((10, 10), dtype=bool)
>>> seq = jax.nn.one_hot(jnp.arange(10), 21)
>>> w_s = jnp.ones((21, 128))
>>> output = decoder.call_conditional(
...     node_feats, edge_feats, neighbor_indices, mask, ar_mask, seq, w_s
... )
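The example passes an all-ones `ar_mask`. To score under a specific decoding order, a mask can be derived from that order. The sketch below assumes the convention that `ar_mask[i, j]` is True when residue `j` is decoded strictly before residue `i`, which is how the original ProteinMPNN builds its backward mask; check PrxteinMPNN's own helpers before relying on this convention. `ar_mask_from_order` is a hypothetical name. It uses `argsort` to invert the permutation, which also works unchanged with `jnp.argsort` in JAX.

```python
import numpy as np


def ar_mask_from_order(decoding_order):
    """ar_mask[i, j] is True when residue j is decoded strictly before residue i.

    decoding_order: (N,) permutation; decoding_order[t] is the residue decoded at step t.
    """
    # Inverting the permutation gives the step at which each residue is decoded.
    step = np.argsort(decoding_order)
    return step[None, :] < step[:, None]


ar_mask = ar_mask_from_order(np.array([2, 0, 1]))
# Residue 2 is decoded first, so it conditions on nothing (last row all False).
```

With the identity order `0, 1, ..., N-1` this reduces to a strictly lower-triangular mask.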