"""Feature extraction module for PrxteinMPNN.
This module contains the ProteinFeatures class that extracts and projects
features from raw protein coordinates.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import equinox as eqx
import jax
import jax.numpy as jnp
from prxteinmpnn.utils.coordinates import (
apply_noise_to_coordinates,
compute_backbone_coordinates,
compute_backbone_distance,
)
from prxteinmpnn.utils.graph import compute_neighbor_offsets
from prxteinmpnn.utils.radial_basis import compute_radial_basis
if TYPE_CHECKING:
from prxteinmpnn.utils.types import (
AlphaCarbonMask,
BackboneNoise,
ChainIndex,
EdgeFeatures,
NeighborIndices,
ResidueIndex,
StructureAtomicCoordinates,
)
# Type alias for PRNG keys (aliased to ``jax.Array`` for use in annotations).
PRNGKeyArray = jax.Array
# Layer normalization: short alias for Equinox's ``LayerNorm`` layer.
LayerNorm = eqx.nn.LayerNorm
# Feature extraction constants
# Relative residue offsets are shifted by this value and clipped to
# [0, 2 * MAXIMUM_RELATIVE_FEATURES] before being one-hot encoded; the one-hot
# width used in this module is 2 * MAXIMUM_RELATIVE_FEATURES + 2.
MAXIMUM_RELATIVE_FEATURES = 32
# Output width of the positional-embedding linear layer (``w_pos``).
POS_EMBED_DIM = 16
# JIT-compiled ``jax.lax.top_k``. ``k`` is declared static, so each distinct
# ``k`` value is a separate compilation rather than a traced argument.
top_k = jax.jit(jax.lax.top_k, static_argnames=("k",))
[docs]
class ProteinFeatures(eqx.Module):
    """Turn raw protein coordinates into projected edge features.

    The module selects nearest neighbours from masked backbone distances,
    computes radial-basis features for the selected edges, embeds a one-hot
    encoding of each neighbour's relative position, and passes the
    concatenated result through an edge embedding, a layer norm, and a final
    linear projection.

    Note: the final edge projection (``w_e_proj``, labelled "W_e in
    ColabDesign" in this file) is owned and applied by this module.
    """

    # NOTE: field order is part of the pytree structure; keep it stable.
    w_pos: eqx.nn.Linear  # one-hot relative position -> positional embedding
    w_e: eqx.nn.Linear  # bias-free embedding of [positional, RBF] edge inputs
    norm_edges: LayerNorm  # layer norm applied to the embedded edges
    w_e_proj: eqx.nn.Linear  # Final edge projection
    k_neighbors: int = eqx.field(static=True)
    rbf_dim: int = eqx.field(static=True)
    pos_embed_dim: int = eqx.field(static=True)

    def __init__(
        self,
        node_features: int,  # noqa: ARG002
        edge_features: int,
        k_neighbors: int,
        *,
        key: PRNGKeyArray,
    ) -> None:
        """Build the linear layers and layer norm used for edge features.

        Args:
          node_features: Node feature width; unused, accepted for API compatibility.
          edge_features: Width of the edge features produced by this module.
          k_neighbors: Upper bound on the number of neighbours kept per residue.
          key: PRNG key; split three ways, one subkey per linear layer.

        """
        # One subkey per randomly initialised layer, in declaration order.
        pos_key, embed_key, proj_key = jax.random.split(key, 3)

        one_hot_width = 2 * MAXIMUM_RELATIVE_FEATURES + 2  # 66
        # Match original model's edge embedding input size.
        # NOTE(review): presumably positional width + flattened RBF width;
        # confirm against compute_radial_basis before changing either.
        edge_input_width = 416

        self.w_pos = eqx.nn.Linear(one_hot_width, POS_EMBED_DIM, key=pos_key)
        self.w_e = eqx.nn.Linear(
            edge_input_width,
            edge_features,
            use_bias=False,
            key=embed_key,
        )
        self.norm_edges = LayerNorm(edge_features)
        self.w_e_proj = eqx.nn.Linear(edge_features, edge_features, key=proj_key)

        self.k_neighbors = k_neighbors
        self.rbf_dim = 16
        self.pos_embed_dim = POS_EMBED_DIM

    def __call__(
        self,
        prng_key: PRNGKeyArray,
        structure_coordinates: StructureAtomicCoordinates,
        mask: AlphaCarbonMask,
        residue_index: ResidueIndex,
        chain_index: ChainIndex,
        backbone_noise: BackboneNoise | None,
    ) -> tuple[EdgeFeatures, NeighborIndices, PRNGKeyArray]:
        """Compute projected edge features and neighbour indices.

        Args:
          prng_key: PRNG key forwarded to the coordinate-noise helper.
          structure_coordinates: Atomic coordinates (N, CA, C, O).
          mask: Alpha carbon mask; pairs with a zero entry get infinite distance.
          residue_index: Residue indices used to compute neighbour offsets.
          chain_index: Chain indices used to tell same-chain from cross-chain pairs.
          backbone_noise: Noise level for the backbone; ``None`` means 0.0.

        Returns:
          Tuple of (edge_features, neighbor_indices, updated_prng_key).

        """

        def over_edges(layer):
            """Lift a per-edge layer over the two leading axes (N, K)."""
            return jax.vmap(jax.vmap(layer))

        noise_level = (
            jnp.array(0.0, dtype=jnp.float32) if backbone_noise is None else backbone_noise
        )
        noised_coordinates, prng_key = apply_noise_to_coordinates(
            prng_key,
            structure_coordinates,
            backbone_noise=noise_level,
        )
        backbone = compute_backbone_coordinates(noised_coordinates)

        # Pairs involving a masked-out residue get +inf so that, once negated,
        # they rank last in the top-k selection.
        pair_distances = compute_backbone_distance(backbone)
        pair_is_valid = (mask[:, None] * mask[None, :]).astype(bool)
        masked_distances = jnp.array(jnp.where(pair_is_valid, pair_distances, jnp.inf))

        # Never request more neighbours than the first dimension provides.
        num_neighbors = min(self.k_neighbors, structure_coordinates.shape[0])
        # top_k picks the largest values, so negate to get the closest pairs.
        nearest = top_k(-masked_distances, num_neighbors)
        neighbor_indices = jnp.array(nearest[1], dtype=jnp.int32)

        rbf = compute_radial_basis(backbone, neighbor_indices)
        neighbor_offsets = compute_neighbor_offsets(residue_index, neighbor_indices)

        # 1 where both residues share a chain, 0 otherwise, gathered per neighbour.
        same_chain = (chain_index[:, None] == chain_index[None, :]).astype(int)
        same_chain_neighbors = jnp.take_along_axis(same_chain, neighbor_indices, axis=1)

        # Same-chain pairs: offset shifted into [0, 2 * MAXIMUM_RELATIVE_FEATURES].
        shifted_offsets = jnp.clip(
            neighbor_offsets + MAXIMUM_RELATIVE_FEATURES,
            0,
            2 * MAXIMUM_RELATIVE_FEATURES,
        )
        # Cross-chain pairs: all mapped to the single extra class at the end.
        cross_chain_class = (1 - same_chain_neighbors) * (2 * MAXIMUM_RELATIVE_FEATURES + 1)
        position_class = shifted_offsets * same_chain_neighbors + cross_chain_class
        position_one_hot = jax.nn.one_hot(
            position_class,
            2 * MAXIMUM_RELATIVE_FEATURES + 2,
        )
        encoded_positions = over_edges(self.w_pos)(position_one_hot)

        # Embed edges, normalise, then apply the final projection
        # (W_e in ColabDesign).
        edge_features = jnp.concatenate([encoded_positions, rbf], axis=-1)
        for layer in (self.w_e, self.norm_edges, self.w_e_proj):
            edge_features = over_edges(layer)(edge_features)

        return edge_features, neighbor_indices, prng_key