# Source code for the prxteinmpnn.model.encoder module.

"""Encoder module for the PrxteinMPNN model."""

from __future__ import annotations

from typing import TYPE_CHECKING

import jax
import jax.numpy as jnp

from prxteinmpnn.utils.concatenate import concatenate_neighbor_nodes
from prxteinmpnn.utils.gelu import GeLU
from prxteinmpnn.utils.normalize import layer_normalization

if TYPE_CHECKING:
  from collections.abc import Callable

  from jaxtyping import Int

  from prxteinmpnn.utils.types import (
    AtomMask,
    AttentionMask,
    EdgeFeatures,
    Message,
    ModelParameters,
    NeighborIndices,
    NodeFeatures,
  )


from .dense import dense_layer
from .masked_attention import MaskedAttentionEnum, mask_attention


def encoder_parameter_pytree(
  model_parameters: ModelParameters,
  num_encoder_layers: int = 3,
) -> ModelParameters:
  """Gather every encoder layer's weights and stack them into one PyTree.

  Layer ``i`` is looked up under the key prefix ``protein_mpnn/~/enc_layer``
  (for ``i == 0``) or ``protein_mpnn/~/enc_layer_{i}`` (for ``i > 0``), with
  per-parameter suffixes ``enc{i}_<name>``. The feed-forward weights live one
  level deeper, under ``position_wise_feed_forward``. The per-layer dicts are
  stacked leaf-wise with ``jnp.stack``, so every leaf gains a leading axis of
  length ``num_encoder_layers`` and layer ``i`` can be selected by indexing.

  Args:
    model_parameters: Mapping from parameter path strings to that module's
      parameters.
    num_encoder_layers: Number of encoder layers to collect. Must be >= 1.

  Returns:
    A dict with keys ``W1``, ``W2``, ``W3``, ``norm1``, ``dense_W_in``,
    ``dense_W_out``, ``norm2``, ``W11``, ``W12``, ``W13`` and ``norm3`` whose
    leaves are stacked across layers.

  Raises:
    ValueError: If ``num_encoder_layers`` is less than 1 (there would be
      nothing to stack).
    KeyError: If an expected parameter path is missing from
      ``model_parameters``.
  """
  if num_encoder_layers < 1:
    msg = f"num_encoder_layers must be >= 1, got {num_encoder_layers}."
    raise ValueError(msg)

  all_encoder_layer_params_list = []
  for i in range(num_encoder_layers):
    # Layer 0 has no numeric suffix on its module name; later layers do.
    prefix = "protein_mpnn/~/enc_layer" if i == 0 else f"protein_mpnn/~/enc_layer_{i}"
    layer_name_suffix = f"enc{i}"
    ffn_prefix = f"{prefix}/~/position_wise_feed_forward/~/{layer_name_suffix}"
    layer_params_dict = {
      "W1": model_parameters[f"{prefix}/~/{layer_name_suffix}_W1"],
      "W2": model_parameters[f"{prefix}/~/{layer_name_suffix}_W2"],
      "W3": model_parameters[f"{prefix}/~/{layer_name_suffix}_W3"],
      "norm1": model_parameters[f"{prefix}/~/{layer_name_suffix}_norm1"],
      "dense_W_in": model_parameters[f"{ffn_prefix}_dense_W_in"],
      "dense_W_out": model_parameters[f"{ffn_prefix}_dense_W_out"],
      "norm2": model_parameters[f"{prefix}/~/{layer_name_suffix}_norm2"],
      "W11": model_parameters[f"{prefix}/~/{layer_name_suffix}_W11"],
      "W12": model_parameters[f"{prefix}/~/{layer_name_suffix}_W12"],
      "W13": model_parameters[f"{prefix}/~/{layer_name_suffix}_W13"],
      "norm3": model_parameters[f"{prefix}/~/{layer_name_suffix}_norm3"],
    }
    all_encoder_layer_params_list.append(layer_params_dict)

  # Stack matching leaves across layers: leaf shape (...) -> (num_layers, ...).
  return jax.tree_util.tree_map(
    lambda *args: jnp.stack(args),
    *all_encoder_layer_params_list,
  )
@jax.jit
def encode(
  node_features: NodeFeatures,
  edge_features: EdgeFeatures,
  neighbor_indices: NeighborIndices,
  layer_params: ModelParameters,
) -> Message:
  """Compute per-edge messages with the first encoder MLP (W1 -> W2 -> W3).

  For every (node, neighbor) slot, the MLP input is the node's own features
  concatenated with the output of ``concatenate_neighbor_nodes`` for that
  slot. Two GeLU-activated affine layers (``W1``, ``W2``) are followed by a
  final affine projection (``W3``) with no activation.

  Args:
    node_features: Node features of shape (num_atoms, num_features).
    edge_features: Edge features of shape (num_atoms, num_neighbors, num_features).
    neighbor_indices: Indices of neighboring nodes of shape (num_atoms, num_neighbors).
    layer_params: Parameters for one encoder layer; ``W1``, ``W2`` and ``W3``
      each provide a weight ``"w"`` and a bias ``"b"``.

  Returns:
    Message: Encoded messages of shape (num_atoms, num_neighbors, num_features).
  """
  neighbor_block = concatenate_neighbor_nodes(node_features, edge_features, neighbor_indices)
  num_neighbors = neighbor_block.shape[-2]
  # Repeat each node's own features once per neighbor slot so it can be
  # paired with every edge it owns.
  self_block = jnp.tile(node_features[..., None, :], [1, num_neighbors, 1])
  hidden = jnp.concatenate([self_block, neighbor_block], -1)
  for name in ("W1", "W2"):
    hidden = GeLU(jnp.dot(hidden, layer_params[name]["w"]) + layer_params[name]["b"])
  return jnp.dot(hidden, layer_params["W3"]["w"]) + layer_params["W3"]["b"]
@jax.jit
def encoder_normalize(
  message: Message,
  node_features: NodeFeatures,
  edge_features: EdgeFeatures,
  neighbor_indices: NeighborIndices,
  mask: AtomMask,
  layer_params: ModelParameters,
  scale: float = 30.0,
) -> tuple[NodeFeatures, EdgeFeatures]:
  """Aggregate messages into the nodes, then refresh the edge features.

  Node update: sum messages over the neighbor axis, divide by ``scale``,
  add as a residual and apply ``norm1``; then add the ``dense_layer``
  output as a second residual, apply ``norm2``, and zero out rows where
  ``mask`` is 0.

  Edge update: build the same (self, neighbor) concatenation as ``encode``
  from the *updated* node features, run it through the ``W11`` -> ``W12``
  -> ``W13`` MLP, add it to the incoming edge features and apply ``norm3``.

  Args:
    message: Encoded messages of shape (num_atoms, num_neighbors, num_features).
    node_features: Node features of shape (num_atoms, num_features).
    edge_features: Edge features of shape (num_atoms, num_neighbors, num_features).
    neighbor_indices: Indices of neighboring nodes of shape (num_atoms, num_neighbors).
    mask: Atom mask indicating valid atoms; broadcast over the feature axis.
    layer_params: Parameters for one encoder layer.
    scale: Divisor applied to the summed messages.

  Returns:
    tuple: ``(node_features, edge_features)`` after the update.
  """
  # --- node update ---
  aggregated = jnp.sum(message, -2) / scale
  h_nodes = layer_normalization(node_features + aggregated, layer_params["norm1"])
  h_nodes = layer_normalization(
    h_nodes + dense_layer(layer_params, h_nodes),
    layer_params["norm2"],
  )
  h_nodes = mask[:, None] * h_nodes

  # --- edge update ---
  neighbor_block = concatenate_neighbor_nodes(h_nodes, edge_features, neighbor_indices)
  self_block = jnp.tile(h_nodes[..., None, :], [1, neighbor_block.shape[-2], 1])
  hidden = jnp.concatenate([self_block, neighbor_block], -1)
  for name in ("W11", "W12"):
    hidden = GeLU(jnp.dot(hidden, layer_params[name]["w"]) + layer_params[name]["b"])
  edge_delta = jnp.dot(hidden, layer_params["W13"]["w"]) + layer_params["W13"]["b"]

  return h_nodes, layer_normalization(edge_features + edge_delta, layer_params["norm3"])
def make_encode_layer(
  attention_mask_enum: MaskedAttentionEnum,
) -> Callable[..., tuple[NodeFeatures, EdgeFeatures]]:
  """Create a jitted function that runs a single encoder layer.

  Args:
    attention_mask_enum: Which variant to build. Any value other than
      ``MaskedAttentionEnum.NONE`` produces a layer function that takes an
      extra ``attention_mask`` argument (positioned between ``mask`` and
      ``layer_params``) and passes the messages through ``mask_attention``
      before they are aggregated.

  Returns:
    A jitted callable that returns the ``(node_features, edge_features)``
    pair produced by ``encoder_normalize``.
  """
  if attention_mask_enum is not MaskedAttentionEnum.NONE:

    @jax.jit
    def masked_attn_encoder_fn(
      node_features: NodeFeatures,
      edge_features: EdgeFeatures,
      neighbor_indices: NeighborIndices,
      mask: AtomMask,
      attention_mask: AttentionMask,
      layer_params: ModelParameters,
      scale: float = 30.0,
    ) -> tuple[NodeFeatures, EdgeFeatures]:
      """Run one encoder layer, masking the messages with ``attention_mask``."""
      message = encode(node_features, edge_features, neighbor_indices, layer_params)
      message = mask_attention(message, attention_mask)
      return encoder_normalize(
        message,
        node_features,
        edge_features,
        neighbor_indices,
        mask,
        layer_params,
        scale,
      )

    return masked_attn_encoder_fn

  @jax.jit
  def encoder_fn(
    node_features: NodeFeatures,
    edge_features: EdgeFeatures,
    neighbor_indices: NeighborIndices,
    mask: AtomMask,
    layer_params: ModelParameters,
    scale: float = 30.0,
  ) -> tuple[NodeFeatures, EdgeFeatures]:
    """Run one encoder layer without attention masking."""
    message = encode(node_features, edge_features, neighbor_indices, layer_params)
    return encoder_normalize(
      message,
      node_features,
      edge_features,
      neighbor_indices,
      mask,
      layer_params,
      scale,
    )

  return encoder_fn
@jax.jit
def initialize_node_features(
  model_parameters: ModelParameters,
  edge_features: EdgeFeatures,
) -> NodeFeatures:
  """Return an all-zero node-feature array to seed the encoder.

  The number of rows is taken from ``edge_features.shape[0]``. The feature
  width is the length of the ``protein_mpnn/~/W_e`` bias.

  Args:
    model_parameters: Model parameters; only the shape of
      ``model_parameters["protein_mpnn/~/W_e"]["b"]`` is used.
    edge_features: Edge features; only the leading dimension is used.

  Returns:
    Zeros of shape ``(edge_features.shape[0], len(W_e bias))``.
  """
  hidden_dim = model_parameters["protein_mpnn/~/W_e"]["b"].shape[0]
  num_nodes = edge_features.shape[0]
  return jnp.zeros((num_nodes, hidden_dim))
def setup_encoder(
  model_parameters: ModelParameters,
  attention_mask_enum: MaskedAttentionEnum,
  num_encoder_layers: int = 3,
) -> tuple[ModelParameters, Callable[..., tuple[NodeFeatures, EdgeFeatures]]]:
  """Prepare the stacked encoder parameters and the single-layer function.

  Args:
    model_parameters: Full model parameters, keyed by parameter path.
    attention_mask_enum: Selects the masked or unmasked layer variant
      (forwarded to ``make_encode_layer``).
    num_encoder_layers: Number of encoder layers whose weights are stacked.

  Returns:
    A tuple ``(all_encoder_layer_params, encode_layer_fn)``:
    ``all_encoder_layer_params`` is the output of
    ``encoder_parameter_pytree`` (leaves stacked across layers), and
    ``encode_layer_fn`` is the jitted single-layer function returned by
    ``make_encode_layer``. No initial node features are created here.
  """
  all_encoder_layer_params = encoder_parameter_pytree(model_parameters, num_encoder_layers)
  encode_layer_fn = make_encode_layer(attention_mask_enum=attention_mask_enum)
  return all_encoder_layer_params, encode_layer_fn
def make_encoder(
  model_parameters: ModelParameters,
  attention_mask_enum: MaskedAttentionEnum,
  num_encoder_layers: int = 3,
  scale: float = 30.0,
) -> Callable[..., tuple[NodeFeatures, EdgeFeatures]]:
  """Create a jitted function that runs the full stack of encoder layers.

  Node features start as zeros (``initialize_node_features``). Each of the
  ``num_encoder_layers`` layers then updates both node and edge features
  inside a ``jax.lax.fori_loop``, selecting layer ``i``'s weights from the
  stacked parameter PyTree.

  Args:
    model_parameters: Full model parameters, keyed by parameter path.
    attention_mask_enum: ``MaskedAttentionEnum.NONE`` builds the unmasked
      encoder; any other value builds the attention-masked encoder.
    num_encoder_layers: Number of encoder layers to run.
    scale: Divisor applied to the summed messages in every layer.

  Returns:
    ``run_encoder(edge_features, neighbor_indices, mask)`` when
    ``attention_mask_enum`` is ``MaskedAttentionEnum.NONE``, otherwise
    ``run_masked_attention_encoder(edge_features, neighbor_indices, mask,
    attention_mask)``. Both return ``(node_features, edge_features)``.
  """
  all_encoder_layer_params, encode_layer_fn = setup_encoder(
    model_parameters,
    attention_mask_enum,
    num_encoder_layers,
  )

  def _run_layers(
    edge_features: EdgeFeatures,
    neighbor_indices: NeighborIndices,
    mask: AtomMask,
    *attention_args: AttentionMask,
  ) -> tuple[NodeFeatures, EdgeFeatures]:
    # Shared layer loop for both variants. ``attention_args`` is empty for the
    # unmasked layer and ``(attention_mask,)`` for the masked one, which is
    # exactly the extra positional argument the masked layer fn expects
    # between ``mask`` and ``layer_params``.
    node_features_encoder = initialize_node_features(model_parameters, edge_features)

    def encoder_loop_body(
      i: Int,
      carry: tuple[NodeFeatures, EdgeFeatures],
    ) -> tuple[NodeFeatures, EdgeFeatures]:
      node_features, edge_features = carry
      # Pick layer ``i``'s slice out of every stacked parameter leaf.
      current_layer_params = jax.tree_util.tree_map(lambda x: x[i], all_encoder_layer_params)
      node_features, edge_features = encode_layer_fn(
        node_features,
        edge_features,
        neighbor_indices,
        mask,
        *attention_args,
        current_layer_params,
        scale,
      )
      return (node_features, edge_features)

    return jax.lax.fori_loop(
      0,
      num_encoder_layers,
      encoder_loop_body,
      (node_features_encoder, edge_features),
    )

  if attention_mask_enum is MaskedAttentionEnum.NONE:

    @jax.jit
    def run_encoder(
      edge_features: EdgeFeatures,
      neighbor_indices: NeighborIndices,
      mask: AtomMask,
    ) -> tuple[NodeFeatures, EdgeFeatures]:
      """Run the encoder with the provided edge features and neighbor indices."""
      return _run_layers(edge_features, neighbor_indices, mask)

    return run_encoder

  @jax.jit
  def run_masked_attention_encoder(
    edge_features: EdgeFeatures,
    neighbor_indices: NeighborIndices,
    mask: AtomMask,
    attention_mask: AttentionMask,
  ) -> tuple[NodeFeatures, EdgeFeatures]:
    """Run the encoder, masking every layer's messages with ``attention_mask``."""
    return _run_layers(edge_features, neighbor_indices, mask, attention_mask)

  return run_masked_attention_encoder