encoder
Encoder module for the PrxteinMPNN model.
- prxteinmpnn.model.encoder.encoder_parameter_pytree(model_parameters, num_encoder_layers=3)
Make the model weights accessible as a PyTree.
- Parameters:
model_parameters (PyTree[str, 'P']) – Model parameters for the encoder.
edge_features – Edge features to initialize the node features.
num_encoder_layers (int) – Number of encoder layers to set up.
- Returns:
A tuple containing the encoder parameters as a PyTree and the initial node features.
- Return type:
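As an illustration of what making weights "accessible as a PyTree" can look like, the sketch below stacks per-layer parameter dictionaries along a leading layer axis using `jax.tree_util.tree_map`. It is not the PrxteinMPNN implementation: the helper `stack_layer_params`, the key prefix `enc_layer_`, and the leaf names `W` and `b` are all hypothetical.

```python
import jax
import jax.numpy as jnp


def stack_layer_params(model_parameters, num_encoder_layers=3, prefix="enc_layer_"):
    """Stack per-layer parameter dicts into one PyTree with a leading layer axis.

    Assumption (hypothetical layout): layer i is stored under f"{prefix}{i}".
    """
    layers = [model_parameters[f"{prefix}{i}"] for i in range(num_encoder_layers)]
    # With several trees, tree_map calls the function on matching leaves.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves, axis=0), *layers)


# Toy parameters: three layers, each with a weight matrix and a bias vector.
toy_params = {
    f"enc_layer_{i}": {"W": jnp.full((4, 4), float(i)), "b": jnp.zeros((4,))}
    for i in range(3)
}
stacked = stack_layer_params(toy_params)
```

A stacked layout like this makes it possible to index one layer at a time or to iterate over layers with `jax.lax.scan`; whether the library does either is not stated in this page.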
- prxteinmpnn.model.encoder.encode(node_features, edge_features, neighbor_indices, layer_params)
Encode node and edge features into messages.
- Parameters:
node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features of shape (num_atoms, num_neighbors, num_features).
neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighboring nodes of shape (num_atoms, num_neighbors).
layer_params (PyTree[str, 'P']) – Model parameters for the encoding layer.
- Returns:
Encoded messages of shape (num_atoms, num_neighbors, num_features).
- Return type:
Message
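The shapes above suggest a gather-and-concatenate message computation, which is common in message-passing encoders. The following is a minimal sketch under that assumption, not the library's code: the function `encode_sketch`, the two-layer MLP, and the parameter keys `W1`, `b1`, `W2`, `b2` are hypothetical, and the toy inputs use float node features for simplicity.

```python
import jax
import jax.numpy as jnp


def encode_sketch(node_features, edge_features, neighbor_indices, layer_params):
    """Build per-edge messages from self, edge, and neighbor features."""
    n_atoms, n_neighbors = neighbor_indices.shape
    # Gather each node's neighbor features: (n_atoms, n_neighbors, n_features).
    neighbor_features = node_features[neighbor_indices]
    # Repeat each node's own features once per neighbor.
    self_features = jnp.broadcast_to(
        node_features[:, None, :], (n_atoms, n_neighbors, node_features.shape[-1])
    )
    # Concatenate [self, edge, neighbor] along the feature axis.
    x = jnp.concatenate([self_features, edge_features, neighbor_features], axis=-1)
    # Hypothetical two-layer MLP that maps back to n_features channels.
    hidden = jax.nn.gelu(x @ layer_params["W1"] + layer_params["b1"])
    return hidden @ layer_params["W2"] + layer_params["b2"]


# Toy inputs with the documented shapes.
num_atoms, num_neighbors, num_features = 5, 3, 8
k_node, k_edge, k_idx, k_w1, k_w2 = jax.random.split(jax.random.PRNGKey(0), 5)
node_features = jax.random.normal(k_node, (num_atoms, num_features))
edge_features = jax.random.normal(k_edge, (num_atoms, num_neighbors, num_features))
neighbor_indices = jax.random.randint(k_idx, (num_atoms, num_neighbors), 0, num_atoms)
layer_params = {
    "W1": 0.1 * jax.random.normal(k_w1, (3 * num_features, num_features)),
    "b1": jnp.zeros((num_features,)),
    "W2": 0.1 * jax.random.normal(k_w2, (num_features, num_features)),
    "b2": jnp.zeros((num_features,)),
}
message = encode_sketch(node_features, edge_features, neighbor_indices, layer_params)
```

The result has shape (num_atoms, num_neighbors, num_features), matching the documented return shape.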
- prxteinmpnn.model.encoder.encoder_normalize(message, node_features, edge_features, neighbor_indices, mask, layer_params, scale=30.0)
Normalize the encoded messages and update node features.
- Parameters:
message (Float[Array, 'num_atoms num_neighbors num_features']) – Encoded messages of shape (num_atoms, num_neighbors, num_features).
node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features of shape (num_atoms, num_neighbors, num_features).
neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighboring nodes of shape (num_atoms, num_neighbors).
mask (Int[Array, 'num_residues num_atoms']) – Atom mask indicating valid atoms.
layer_params (PyTree[str, 'P']) – Model parameters for the normalization layer.
scale (float) – Scaling factor for normalization.
- Returns:
Updated node features and edge features after normalization.
- Return type:
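One plausible reading of the `scale` parameter is that messages are summed over neighbors and divided by a fixed constant before a residual update and layer normalization. The sketch below shows only that node-feature step under this assumption. It is not the library's code: `normalize_sketch`, `layer_norm`, the keys `gamma` and `beta`, and the one-dimensional per-node mask are hypothetical, and the edge-feature update mentioned in the return description is not covered.

```python
import jax.numpy as jnp


def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize the last axis to zero mean and unit variance, then rescale."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + eps) + beta


def normalize_sketch(message, node_features, mask, layer_params, scale=30.0):
    """Aggregate messages, apply a residual update, normalize, and mask."""
    # Sum over the neighbor axis and divide by a fixed constant.
    aggregated = message.sum(axis=1) / scale
    updated = layer_norm(
        node_features + aggregated, layer_params["gamma"], layer_params["beta"]
    )
    # Zero out nodes flagged as invalid by the (assumed 1-D) mask.
    return updated * mask[:, None]


# Toy inputs: 5 nodes, 3 neighbors, 8 features; the last node is masked out.
toy_message = jnp.ones((5, 3, 8))
toy_nodes = jnp.arange(40.0).reshape(5, 8)
toy_mask = jnp.array([1.0, 1.0, 1.0, 1.0, 0.0])
toy_norm_params = {"gamma": jnp.ones((8,)), "beta": jnp.zeros((8,))}
updated_nodes = normalize_sketch(toy_message, toy_nodes, toy_mask, toy_norm_params)
```

Dividing by a constant instead of the neighbor count keeps the aggregate independent of how many neighbors are valid for a given node, at the cost of a hand-chosen hyperparameter.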
- prxteinmpnn.model.encoder.make_encode_layer(attention_mask_enum)
Create a function to run the encoder with given model parameters.
- Return type:
Callable[..., Float[Array, 'num_atoms num_neighbors num_features']]
- Parameters:
attention_mask_enum (MaskedAttentionEnum)
- prxteinmpnn.model.encoder.initialize_node_features(model_parameters, edge_features)
Initialize node features based on model parameters.
- Return type:
Int[Array, 'num_atoms num_features']
- Parameters:
model_parameters (PyTree[str, 'P'])
edge_features (Float[Array, 'num_atoms num_neighbors num_features'])
- prxteinmpnn.model.encoder.setup_encoder(model_parameters, attention_mask_enum, num_encoder_layers=3)
Set up the encoder parameters and initial node features.