# Source code for prxteinmpnn.model.decoder

"""Decoder module for PrxteinMPNN.

This module contains the Equinox-based decoder implementation for ProteinMPNN.
"""

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

import equinox as eqx
import jax
import jax.numpy as jnp

from prxteinmpnn.utils.concatenate import concatenate_neighbor_nodes

if TYPE_CHECKING:
  from prxteinmpnn.utils.types import (
    AlphaCarbonMask,
    Array,
    AutoRegressiveMask,
    EdgeFeatures,
    NeighborIndices,
    NodeFeatures,
    OneHotProteinSequence,
    PRNGKeyArray,
  )

# Alias for Equinox's layer normalization; instances are applied per-node via
# jax.vmap inside the decoder layers below.
LayerNorm = eqx.nn.LayerNorm
# GELU activation with approximate=False, i.e. the exact erf-based form rather
# than the tanh approximation.
_gelu = partial(jax.nn.gelu, approximate=False)


class DecoderLayer(eqx.Module):
  """One message-passing decoder layer of the ProteinMPNN model.

  Pairs each node's features with a pre-built per-neighbor context, runs the
  pairs through a message MLP, sum-aggregates the messages, and finishes with
  a position-wise feed-forward update. Layer normalization follows each of
  the two residual steps.
  """

  message_mlp: eqx.nn.MLP
  norm1: LayerNorm
  dense: eqx.nn.MLP  # position-wise feed-forward block
  norm2: LayerNorm

  def __init__(
    self,
    node_features: int,
    edge_context_features: int,
    hidden_features: int,
    *,
    key: PRNGKeyArray,
  ) -> None:
    """Initialize the decoder layer.

    Args:
      node_features: Dimension of node features (e.g., 128).
      edge_context_features: Dimension of the per-neighbor context (e.g., 384).
      hidden_features: Hidden width of the feed-forward (dense) MLP.
      key: PRNG key for parameter initialization.

    Returns:
      None

    Raises:
      None

    Example:
      >>> key = jax.random.PRNGKey(0)
      >>> layer = DecoderLayer(128, 384, 128, key=key)

    """
    msg_key, dense_key = jax.random.split(key, 2)

    # Message MLP consumes [h_i | context] (e.g. 128 + 384 = 512) and maps it
    # back to the node width through two hidden layers of width
    # `node_features` (not `hidden_features`), matching the reference weights.
    self.message_mlp = eqx.nn.MLP(
      in_size=node_features + edge_context_features,
      out_size=node_features,
      width_size=node_features,
      depth=2,
      activation=_gelu,
      key=msg_key,
    )
    self.norm1 = LayerNorm(node_features)

    # Single-hidden-layer feed-forward applied independently to each node.
    self.dense = eqx.nn.MLP(
      in_size=node_features,
      out_size=node_features,
      width_size=hidden_features,
      depth=1,
      activation=_gelu,
      key=dense_key,
    )
    self.norm2 = LayerNorm(node_features)

  def __call__(
    self,
    node_features: NodeFeatures,
    layer_edge_features: EdgeFeatures,
    mask: AlphaCarbonMask,
    scale: float = 30.0,
    attention_mask: Array | None = None,
  ) -> NodeFeatures:
    """Run the decoder layer forward pass.

    Works for both N-batch (N, C) and single-node (1, C) inputs.

    Args:
      node_features: Node features of shape (N, C).
      layer_edge_features: Per-neighbor context of shape (N, K, 384).
      mask: Alpha carbon mask of shape (N,), or a scalar.
      scale: Divisor applied to the summed neighbor messages (default: 30.0).
      attention_mask: Optional (N, K) mask zeroing out messages from
        disallowed neighbors (used during conditional decoding).

    Returns:
      Updated node features of shape (N, C).

    Raises:
      None

    Example:
      >>> key = jax.random.PRNGKey(0)
      >>> layer = DecoderLayer(128, 384, 128, key=key)
      >>> node_feats = jnp.ones((10, 128))
      >>> edge_feats = jnp.ones((10, 30, 384))
      >>> mask = jnp.ones((10,))
      >>> output = layer(node_feats, edge_feats, mask)

    """
    num_neighbors = layer_edge_features.shape[1]

    # Repeat h_i along the neighbor axis and pair it with each neighbor's
    # context: (N, K, C + 384).
    central = jnp.tile(
      jnp.expand_dims(node_features, -2),
      [1, num_neighbors, 1],
    )
    pair_input = jnp.concatenate([central, layer_edge_features], -1)

    # Evaluate the message MLP on every (node, neighbor) pair: the outer vmap
    # runs over nodes, the inner over neighbors.
    messages = jax.vmap(jax.vmap(self.message_mlp))(pair_input)

    if attention_mask is not None:
      # Silence messages from neighbors excluded by the autoregressive order.
      messages = messages * jnp.expand_dims(attention_mask, -1)

    # First residual step: scaled sum-aggregation over neighbors, then norm.
    updated = node_features + jnp.sum(messages, -2) / scale
    normed = jax.vmap(self.norm1)(updated)

    # Second residual step: position-wise feed-forward, then norm.
    updated = normed + jax.vmap(self.dense)(normed)
    out = jax.vmap(self.norm2)(updated)

    # Zero out padded positions; both scalar and per-node masks are supported.
    return mask * out if jnp.ndim(mask) == 0 else mask[:, None] * out
class Decoder(eqx.Module):
  """The complete decoder module for ProteinMPNN.

  Stacks `num_layers` :class:`DecoderLayer` instances and provides two entry
  points: unconditional decoding (``__call__``) and conditional decoding for
  scoring a fixed sequence (``call_conditional``).
  """

  layers: tuple[DecoderLayer, ...]
  node_features_dim: int = eqx.field(static=True)
  edge_features_dim: int = eqx.field(static=True)

  def __init__(
    self,
    node_features: int,
    edge_features: int,  # Raw edge-feature dimension (e.g., 128)
    hidden_features: int,
    num_layers: int = 3,
    *,
    key: PRNGKeyArray,
  ) -> None:
    """Initialize the decoder.

    Args:
      node_features: Dimension of node features (e.g., 128).
      edge_features: Dimension of edge features (e.g., 128).
      hidden_features: Dimension of hidden layer in decoder layers.
      num_layers: Number of decoder layers (default: 3).
      key: PRNG key for initialization.

    Returns:
      None

    Raises:
      None

    Example:
      >>> key = jax.random.PRNGKey(0)
      >>> decoder = Decoder(128, 128, 128, num_layers=3, key=key)

    """
    self.node_features_dim = node_features
    self.edge_features_dim = edge_features
    keys = jax.random.split(key, num_layers)

    # Each layer sees a per-neighbor context [h_i/s_i, e_ij, h_j/s_j], so its
    # width is 2*node + edge. Computed (rather than hard-coded to 384) so the
    # decoder also works for non-128 feature widths; for the default
    # 128/128 configuration this is exactly 2*128 + 128 = 384 as before.
    edge_context_features = 2 * node_features + edge_features
    self.layers = tuple(
      DecoderLayer(node_features, edge_context_features, hidden_features, key=k)
      for k in keys
    )

  def __call__(
    self,
    node_features: NodeFeatures,
    edge_features: EdgeFeatures,  # Raw edges from the encoder
    neighbor_indices: NeighborIndices,
    mask: AlphaCarbonMask,
  ) -> NodeFeatures:
    """Forward pass for UNCONDITIONAL decoding.

    Args:
      node_features: Node features from encoder of shape (N, 128).
      edge_features: Edge features from encoder of shape (N, K, 128).
      neighbor_indices: Indices of neighbors for each node, shape (N, K).
      mask: Alpha carbon mask of shape (N,).

    Returns:
      Decoded node features of shape (N, 128).

    Raises:
      None

    Example:
      >>> key = jax.random.PRNGKey(0)
      >>> decoder = Decoder(128, 128, 128, num_layers=3, key=key)
      >>> node_feats = jnp.ones((10, 128))
      >>> edge_feats = jnp.ones((10, 30, 128))
      >>> neighbor_idx = jnp.arange(10)[:, None].repeat(30, axis=1)
      >>> mask = jnp.ones((10,))
      >>> output = decoder(node_feats, edge_feats, neighbor_idx, mask)

    """
    # Build the per-neighbor context tensor ONCE; it is constant across layers
    # in the unconditional case. Context is [h_i, e_ij, h_j] with the sequence
    # slot zeroed:
    # step 1: [0_j, e_ij] gathered per neighbor -> (N, K, edge + node)
    zeros_with_edges = concatenate_neighbor_nodes(
      jnp.zeros_like(node_features),
      edge_features,
      neighbor_indices,
    )
    # step 2: prepend the gathered h_j -> (N, K, edge + 2*node), e.g. 384
    layer_edge_features = concatenate_neighbor_nodes(
      node_features,
      zeros_with_edges,
      neighbor_indices,
    )

    loop_node_features = node_features
    for layer in self.layers:
      loop_node_features = layer(
        loop_node_features,
        layer_edge_features,
        mask,
      )
    return loop_node_features

  def call_conditional(
    self,
    node_features: NodeFeatures,  # h_i from encoder
    edge_features: EdgeFeatures,  # e_ij from encoder
    neighbor_indices: NeighborIndices,
    mask: AlphaCarbonMask,
    ar_mask: AutoRegressiveMask,
    one_hot_sequence: OneHotProteinSequence,
    w_s_weight: Array,  # Sequence embedding weight
  ) -> NodeFeatures:
    """Forward pass for CONDITIONAL decoding (scoring).

    Args:
      node_features: Node features from encoder of shape (N, 128).
      edge_features: Edge features from encoder of shape (N, K, 128).
      neighbor_indices: Indices of neighbors for each node, shape (N, K).
      mask: Alpha carbon mask of shape (N,).
      ar_mask: Autoregressive mask of shape (N, N) for conditional decoding.
      one_hot_sequence: One-hot encoded protein sequence, shape (N, 21).
      w_s_weight: Sequence embedding weight matrix, shape (21, node_features).

    Returns:
      Decoded node features of shape (N, 128).

    Raises:
      None

    Example:
      >>> key = jax.random.PRNGKey(0)
      >>> decoder = Decoder(128, 128, 128, num_layers=3, key=key)
      >>> node_feats = jnp.ones((10, 128))
      >>> edge_feats = jnp.ones((10, 30, 128))
      >>> neighbor_indices = jnp.arange(10)[:, None].repeat(30, axis=1)
      >>> mask = jnp.ones((10,))
      >>> ar_mask = jnp.ones((10, 10))
      >>> seq = jax.nn.one_hot(jnp.arange(10), 21)
      >>> w_s = jnp.ones((21, 128))
      >>> output = decoder.call_conditional(
      ...     node_feats, edge_feats, neighbor_indices, mask, ar_mask, seq, w_s
      ... )

    """
    # 1. Embed the sequence: s_i = one_hot @ W_s.
    embedded_sequence = one_hot_sequence @ w_s_weight

    # 2. Build the two context variants.
    # Structure-only context [h_i, e_ij, h_j], same as unconditional decoding:
    temp_node_edge = concatenate_neighbor_nodes(
      jnp.zeros_like(node_features),
      edge_features,
      neighbor_indices,
    )  # [0, e_ij, h_j] -> (N, K, edge + node)
    node_edge_features = concatenate_neighbor_nodes(
      node_features,
      temp_node_edge,
      neighbor_indices,
    )  # [h_i, 0|e_ij|h_j] -> (N, K, edge + 2*node)

    # Sequence context [e_ij, s_j]; concatenate_neighbor_nodes returns
    # [edge_features, gathered neighbor features].
    sequence_edge_features = concatenate_neighbor_nodes(
      embedded_sequence,
      edge_features,
      neighbor_indices,
    )

    # 3. Prepare masks. `attention_mask[i, k]` is the AR-mask entry for edge
    # (i, neighbor_k): backward neighbors (already decoded) see the true
    # sequence, forward neighbors fall back to the structure-only context.
    attention_mask = jnp.take_along_axis(ar_mask, neighbor_indices, axis=1)
    mask_bw = mask[:, None] * attention_mask
    mask_fw = mask[:, None] * (1 - attention_mask)
    masked_node_edge_features = mask_fw[..., None] * node_edge_features

    # 4. Run the decoder stack; the sequence-bearing context must be rebuilt
    # each layer because it gathers the evolving node features h_j.
    loop_node_features = node_features
    for layer in self.layers:
      current_features = concatenate_neighbor_nodes(
        loop_node_features,  # (N, C): gathered per neighbor -> h_j
        sequence_edge_features,  # (N, K, edge + node) = [e_ij, s_j]
        neighbor_indices,
      )  # -> (N, K, edge + 2*node) = [e_ij, s_j, h_j]

      # Blend: backward edges use the sequence context, forward edges the
      # pre-masked structure-only context. Masking is thus already applied,
      # so no attention_mask is passed to the layer.
      layer_edge_features = (
        mask_bw[..., None] * current_features
      ) + masked_node_edge_features

      loop_node_features = layer(
        loop_node_features,
        layer_edge_features,
        mask,
      )
    return loop_node_features