decoder
Decoder module for PrxteinMPNN.
This module contains the Equinox-based decoder implementation for ProteinMPNN.
- class prxteinmpnn.model.decoder.DecoderLayer(node_features, edge_context_features, hidden_features, *, key)
Bases: Module
A single decoder layer for the ProteinMPNN model.
- Attributes:
  - message_mlp: MLP
  - norm1: LayerNorm
  - dense: MLP
  - norm2: LayerNorm
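The attribute names above mirror the decoder layer of the original ProteinMPNN: a message MLP applied to each node paired with its neighbor edge context, a neighbor sum with a residual connection and LayerNorm, then a position-wise feed-forward block with a second residual connection and LayerNorm. The sketch below shows that data flow, assuming PrxteinMPNN follows the same structure. `decoder_layer_sketch`, `init_params`, the neighbor-sum scale of 30 (the value used by the original ProteinMPNN), and the 384-wide (3 × 128) edge context are assumptions for illustration, not this module's API. It is written with NumPy so it runs without JAX, and it omits dropout and the learned LayerNorm scale and offset.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def layer_norm(x, eps=1e-5):
    # Normalise over the feature axis (learned scale and offset omitted)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def mlp(layers, x):
    # Affine maps with GELU between them; no activation after the last map
    for i, (w, b) in enumerate(layers):
        x = x @ w + b
        if i < len(layers) - 1:
            x = gelu(x)
    return x


def init_params(rng, node_features, edge_context_features, hidden_features):
    # Hypothetical small-random initialisation, for illustration only
    def linear(n_in, n_out):
        return rng.normal(0.0, 0.02, (n_in, n_out)), np.zeros(n_out)

    c, e, h = node_features, edge_context_features, hidden_features
    return {
        # message_mlp: [node features, edge context] -> message, three affine maps
        "message_mlp": [linear(c + e, h), linear(h, h), linear(h, c)],
        # dense: position-wise feed-forward block with 4x expansion
        "dense": [linear(c, 4 * c), linear(4 * c, c)],
    }


def decoder_layer_sketch(params, h_v, h_e, mask, scale=30.0):
    """h_v: (N, C) node features, h_e: (N, K, E) edge context, mask: (N,)."""
    n, k, _ = h_e.shape
    # Pair every node with each of its K neighbour edge contexts
    h_v_expand = np.broadcast_to(h_v[:, None, :], (n, k, h_v.shape[-1]))
    h_ev = np.concatenate([h_v_expand, h_e], axis=-1)
    # Per-edge messages, summed over neighbours and scaled down
    messages = mlp(params["message_mlp"], h_ev)
    h_v = layer_norm(h_v + messages.sum(axis=1) / scale)  # residual + norm1
    h_v = layer_norm(h_v + mlp(params["dense"], h_v))  # residual + norm2
    # Zero out masked (missing) residues
    return mask[:, None] * h_v


rng = np.random.default_rng(0)
params = init_params(rng, 128, 384, 128)  # 384 = 3 * 128 is an assumption
h_v = rng.normal(size=(10, 128))
h_e = rng.normal(size=(10, 30, 384))
mask = np.ones(10)
mask[-2:] = 0.0  # pretend the last two residues are missing
out = decoder_layer_sketch(params, h_v, h_e, mask)
```

The function bodies use only operations that `jax.numpy` also provides, so the same layer logic ports to JAX; only the in-place `mask[-2:] = 0.0` in the usage lines would need `.at[...].set(...)`.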
- class prxteinmpnn.model.decoder.Decoder(node_features, edge_features, hidden_features, num_layers=3, *, key)
Bases: Module
The complete decoder module for ProteinMPNN.
- Attributes:
  - layers: tuple[DecoderLayer, ...]
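Because `layers` is a plain tuple, the decoder's forward pass can be read as an ordinary Python loop in which each layer consumes the node features produced by the previous one. This is a minimal sketch, assuming (as in the original ProteinMPNN decoder) that the edge context is rebuilt from the current node features before every layer; `run_layers`, `build_context`, and the toy layers are hypothetical stand-ins, not part of PrxteinMPNN.

```python
from typing import Callable, Sequence

import numpy as np

# Hypothetical layer signature: (node_features, edge_context, mask) -> node_features
Layer = Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray]


def run_layers(
    layers: Sequence[Layer],
    h_v: np.ndarray,
    build_context: Callable[[np.ndarray], np.ndarray],
    mask: np.ndarray,
) -> np.ndarray:
    """Apply the layers in order; each one sees the previous layer's output."""
    for layer in layers:
        # Rebuild the (N, K, E) edge context from the current node features
        context = build_context(h_v)
        h_v = layer(h_v, context, mask)
    return h_v


# Toy usage: three identical layers that add the neighbour-averaged context
toy_layers = tuple(
    (lambda h, ctx, m: m[:, None] * (h + ctx.mean(axis=1))) for _ in range(3)
)
out = run_layers(
    toy_layers,
    np.zeros((4, 2)),
    lambda h: np.ones((4, 5, 2)) + h[:, None, :],
    np.ones(4),
)
```

With these toy layers the features go 0 → 1 → 3 → 7, which makes the dependence of each layer on the previous layer's output easy to see.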
- call_conditional(node_features, edge_features, neighbor_indices, mask, ar_mask, one_hot_sequence, w_s_weight)
Forward pass for conditional decoding (scoring), in which the decoder is conditioned on a given sequence.
- Parameters:
  - node_features (Int[Array, 'num_atoms num_features']) – Node features from the encoder, of shape (N, 128).
  - edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features from the encoder, of shape (N, K, 128).
  - neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of the K neighbors of each node, of shape (N, K); each index must lie in [0, N).
  - mask (Int[Array, 'num_residues 3']) – Alpha-carbon mask of shape (N,).
  - ar_mask (Bool[Array, 'num_residues num_residues']) – Autoregressive mask for conditional decoding, of shape (N, N).
  - one_hot_sequence (Float[Array, 'num_residues num_classes']) – One-hot encoded protein sequence, of shape (N, num_classes).
  - w_s_weight (Array) – Sequence embedding weight matrix.
- Return type:
  Int[Array, 'num_atoms num_features']
- Returns:
  Decoded node features of shape (N, 128).
- Raises:
  None.
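In the original ProteinMPNN, conditional decoding builds each layer's edge context by gating two variants of every neighbor with the autoregressive mask gathered onto the edges: neighbors that are already decoded contribute their sequence embedding and the current node features, while the rest contribute an empty sequence slot and the encoder's node features. The sketch below assumes `call_conditional` follows that scheme. `conditional_context` and its argument order are hypothetical; `h_v_encoder` stands for the unchanged `node_features` input, and `h_v` for the output of the previous decoder layer. It uses NumPy so it runs without JAX.

```python
import numpy as np


def conditional_context(h_v, h_v_encoder, h_e, neighbor_indices, mask, ar_mask,
                        one_hot_sequence, w_s_weight):
    """Edge context of shape (N, K, 3C) for one decoder layer (hypothetical helper).

    h_v: current node features (N, C); h_v_encoder: encoder node features (N, C);
    h_e: encoder edge features (N, K, C); neighbor_indices: (N, K) ints in [0, N).
    """
    # Sequence embedding (N, num_classes) @ (num_classes, C), gathered per neighbour
    h_s_nbr = (one_hot_sequence @ w_s_weight)[neighbor_indices]
    # attend[i, k] = ar_mask[i, neighbor_indices[i, k]]: was that neighbour decoded already?
    attend = np.take_along_axis(ar_mask, neighbor_indices, axis=1).astype(h_e.dtype)
    mask_bw = mask[:, None, None] * attend[..., None]
    mask_fw = mask[:, None, None] * (1.0 - attend[..., None])
    # Decoded neighbours: known residue identity plus current node features
    h_esv = np.concatenate([h_e, h_s_nbr, h_v[neighbor_indices]], axis=-1)
    # Undecoded neighbours: empty sequence slot plus encoder node features
    h_exv = np.concatenate(
        [h_e, np.zeros_like(h_s_nbr), h_v_encoder[neighbor_indices]], axis=-1
    )
    return mask_bw * h_esv + mask_fw * h_exv


rng = np.random.default_rng(0)
n, n_nbr, c = 10, 4, 8
h_v = rng.normal(size=(n, c))
h_e = rng.normal(size=(n, n_nbr, c))
neighbor_indices = rng.integers(0, n, size=(n, n_nbr))
mask = np.ones(n)
ar_mask = np.tril(np.ones((n, n), dtype=bool), k=-1)  # identity decoding order
seq = np.eye(21)[rng.integers(0, 21, size=n)]
w_s = rng.normal(size=(21, c))
ctx = conditional_context(h_v, h_v, h_e, neighbor_indices, mask, ar_mask, seq, w_s)
```

Under the identity order, residue 0 has no decoded neighbors, so its sequence slot (`ctx[0, :, c:2 * c]`) is all zeros, while the edge-feature slot always equals `h_e` because the two gates sum to the residue mask.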
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.model.decoder import Decoder
>>> key = jax.random.PRNGKey(0)
>>> decoder = Decoder(128, 128, 128, num_layers=3, key=key)
>>> node_feats = jnp.ones((10, 128))
>>> edge_feats = jnp.ones((10, 30, 128))
>>> # Neighbor indices must lie in [0, N); here each node lists all 10 nodes three times.
>>> neighbor_indices = jnp.tile(jnp.arange(10), (10, 3))
>>> mask = jnp.ones((10,))
>>> ar_mask = jnp.ones((10, 10), dtype=bool)
>>> seq = jax.nn.one_hot(jnp.arange(10), 21)
>>> w_s = jnp.ones((21, 128))
>>> output = decoder.call_conditional(
...     node_feats, edge_feats, neighbor_indices, mask, ar_mask, seq, w_s
... )
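The example passes an all-ones `ar_mask`, which (under the convention assumed here) treats every neighbor as already decoded. For a genuine autoregressive order, the mask can be derived from a decoding permutation. This sketch assumes the original ProteinMPNN convention, where `ar_mask[i, j]` is true when residue `j` is decoded strictly before residue `i`; `ar_mask_from_order` is a hypothetical helper, not part of PrxteinMPNN, so check the library's own conventions before relying on it.

```python
import numpy as np


def ar_mask_from_order(decoding_order):
    """Hypothetical helper: decoding_order[t] is the residue decoded at step t.

    Returns an (N, N) bool mask with mask[i, j] True iff residue j is decoded
    strictly before residue i.
    """
    # argsort inverts the permutation: step[i] is the step at which residue i is decoded
    step = np.argsort(decoding_order)
    return step[:, None] > step[None, :]


# Residue 2 is decoded first, then residue 0, then residue 1
m = ar_mask_from_order(np.array([2, 0, 1]))
```

Inverting the permutation with `argsort` avoids in-place writes, so the same two lines work unchanged with `jax.numpy`. The identity order `np.arange(N)` yields the strictly lower-triangular mask.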