# Source code for prxteinmpnn.model.features
"""Feature extraction for protein structures in the PrxteinMPNN model."""
from functools import partial

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int, PRNGKeyArray

from prxteinmpnn.utils.coordinates import (
    apply_noise_to_coordinates,
    compute_backbone_coordinates,
    compute_backbone_distance,
)
from prxteinmpnn.utils.graph import NeighborOffsets, compute_neighbor_offsets
from prxteinmpnn.utils.normalize import layer_normalization
from prxteinmpnn.utils.radial_basis import compute_radial_basis
from prxteinmpnn.utils.types import (
    AtomMask,
    ChainIndex,
    EdgeFeatures,
    ModelParameters,
    NeighborIndices,
    ResidueIndex,
    StructureAtomicCoordinates,
)
# Offsets (neighbor_offsets + this value) are clipped to
# [0, 2 * MAXIMUM_RELATIVE_FEATURES] in ``encode_positions``; one extra bucket,
# 2 * MAXIMUM_RELATIVE_FEATURES + 1, is reserved for edges whose same-chain flag is 0.
MAXIMUM_RELATIVE_FEATURES = 32

# Integer same-chain flag per (row, neighbor) edge, as produced by
# ``get_edge_chains_neighbors``.
EdgeChainNeighbors = Int[Array, "num_atoms num_neighbors"]

# Per-edge features such as the output of ``encode_positions``
# (``one_hot(bucket) @ w + b``), hence floating point. The trailing axis is
# the dense layer's output width, not the one-hot width. It is a single
# space-free name because jaxtyping treats each whitespace-separated token as
# its own axis.
EncodedPositions = Float[Array, "num_atoms num_neighbors num_encoding_features"]

# ``jax.lax.top_k`` compiled with ``k`` treated as a static (compile-time) argument.
top_k = jax.jit(jax.lax.top_k, static_argnames=("k",))
[docs]
@jax.jit
def get_edge_chains_neighbors(
    chain_index: ChainIndex,
    neighbor_indices: NeighborIndices,
) -> EdgeChainNeighbors:
    """Flag which neighbor edges connect positions carrying the same chain label.

    Args:
        chain_index: One chain label per position; labels are only compared
            for equality.
        neighbor_indices: For each row ``i``, indices into ``chain_index`` of
            that row's neighbors.

    Returns:
        Integer array with the broadcast shape of ``chain_index[:, None]`` and
        ``neighbor_indices``, whose entry ``[i, k]`` is 1 when
        ``chain_index[i] == chain_index[neighbor_indices[i, k]]`` and 0 otherwise.

    """
    # Look up each neighbor's chain label directly. This costs O(N*K) rather
    # than materialising the dense N x N same-chain matrix and then selecting
    # K columns per row (O(N^2) memory). The results match for in-range
    # neighbor indices.
    neighbor_chains = chain_index[neighbor_indices]
    return (chain_index[:, None] == neighbor_chains).astype(int)
[docs]
@jax.jit
def encode_positions(
    neighbor_offsets: NeighborOffsets,
    edge_chains_neighbors: EdgeChainNeighbors,
    model_parameters: ModelParameters,
) -> EncodedPositions:
    """Embed bucketed relative offsets for every neighbor edge.

    Each offset is shifted by ``MAXIMUM_RELATIVE_FEATURES`` and clipped to
    ``[0, 2 * MAXIMUM_RELATIVE_FEATURES]``. Edges whose flag in
    ``edge_chains_neighbors`` is 0 are instead assigned the single extra bucket
    ``2 * MAXIMUM_RELATIVE_FEATURES + 1``. The bucket id is one-hot encoded over
    ``2 * MAXIMUM_RELATIVE_FEATURES + 2`` classes and passed through the
    positional-encoding dense layer.

    Args:
        neighbor_offsets: Relative offset for each neighbor edge.
        edge_chains_neighbors: Same-chain flag per edge. It is expected to be
            0/1, as produced by ``get_edge_chains_neighbors``; the arithmetic
            select below relies on that.
        model_parameters: Parameter tree holding the
            ``positional_encodings/~/embedding_linear`` weights ``"w"`` and
            bias ``"b"``.

    Returns:
        ``one_hot(bucket) @ w + b`` for every edge.

    """
    num_buckets = 2 * MAXIMUM_RELATIVE_FEATURES + 2
    inter_chain_bucket = num_buckets - 1  # reserved for flag == 0 edges
    shifted = neighbor_offsets + MAXIMUM_RELATIVE_FEATURES
    intra_chain_bucket = jnp.clip(shifted, 0, 2 * MAXIMUM_RELATIVE_FEATURES)
    # Arithmetic select: a flag of 1 keeps the clipped offset, a flag of 0
    # picks the inter-chain bucket.
    bucket = (
        intra_chain_bucket * edge_chains_neighbors
        + (1 - edge_chains_neighbors) * inter_chain_bucket
    )
    one_hot_bucket = jax.nn.one_hot(bucket, num_buckets)
    layer = model_parameters[
        "protein_mpnn/~/protein_features/~/positional_encodings/~/embedding_linear"
    ]
    return jnp.dot(one_hot_bucket, layer["w"]) + layer["b"]
[docs]
@jax.jit
def embed_edges(
edge_features: EncodedPositions,
model_parameters: ModelParameters,
) -> EdgeFeatures:
"""Embed edge features using model parameters."""
edge_emb_params = model_parameters["protein_mpnn/~/protein_features/~/edge_embedding"]
return jnp.dot(edge_features, edge_emb_params["w"])
[docs]
@jax.jit
def project_features(
model_parameters: ModelParameters,
edge_features: EdgeFeatures,
) -> EdgeFeatures:
"""Project edge features using model parameters."""
w_e, b_e = (
model_parameters["protein_mpnn/~/W_e"]["w"],
model_parameters["protein_mpnn/~/W_e"]["b"],
)
return jnp.dot(edge_features, w_e) + b_e