encoder

Encoder module for the PrxteinMPNN model.

prxteinmpnn.model.encoder.encoder_parameter_pytree(model_parameters, num_encoder_layers=3)

Make the model weights accessible as a PyTree.

Parameters:
  • model_parameters (PyTree[str, 'P']) – Model parameters for the encoder.

  • num_encoder_layers (int) – Number of encoder layers to set up.

Returns:

A tuple containing the encoder parameters as a PyTree and the initial node features.

Return type:

tuple
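This page does not show how encoder_parameter_pytree arranges the per-layer weights. One common JAX layout, sketched here purely as an assumption, stacks identically structured per-layer PyTrees along a new leading axis so the layers can later be iterated with jax.lax.scan. The helper stack_layer_params and the "W"/"b" keys are hypothetical and do not reflect prxteinmpnn's actual parameter names.

```python
import jax
import jax.numpy as jnp


def stack_layer_params(per_layer_params):
    """Stack identically structured per-layer PyTrees (hypothetical helper).

    Every leaf in the result gains a leading axis of size len(per_layer_params),
    which is the layout jax.lax.scan expects when it iterates over layers.
    """
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *per_layer_params)


# Hypothetical per-layer weights: one dense matrix and one bias per layer.
num_encoder_layers = 3
layers = [
    {"W": jnp.full((4, 4), float(i)), "b": jnp.zeros(4)}
    for i in range(num_encoder_layers)
]
stacked = stack_layer_params(layers)
print(stacked["W"].shape)  # (3, 4, 4)
```

Indexing stacked["W"][i] recovers the weights of layer i, so the stacked form loses no information compared with a list of dictionaries.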

prxteinmpnn.model.encoder.encode(node_features, edge_features, neighbor_indices, layer_params)

Encode node and edge features into messages.

Parameters:
  • node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).

  • edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features of shape (num_atoms, num_neighbors, num_features).

  • neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighboring nodes of shape (num_atoms, num_neighbors).

  • layer_params (PyTree[str, 'P']) – Model parameters for the encoding layer.

Returns:

Encoded messages of shape (num_atoms, num_neighbors, num_features).

Return type:

Message
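For reference, the original ProteinMPNN encoder builds each message by concatenating a node's own features h_i, the edge features e_ij, and the gathered neighbor features h_j, then passing the result through a three-layer MLP with GELU activations. The plain-JAX sketch below assumes PrxteinMPNN follows the same scheme; encode_messages and the W1/b1 through W3/b3 parameter keys are hypothetical, not this module's API.

```python
import jax
import jax.numpy as jnp


def encode_messages(node_features, edge_features, neighbor_indices, params):
    """Hypothetical ProteinMPNN-style message construction.

    node_features:    (num_atoms, C)
    edge_features:    (num_atoms, num_neighbors, C)
    neighbor_indices: (num_atoms, num_neighbors) integer indices into node_features
    Returns messages of shape (num_atoms, num_neighbors, C).
    """
    num_neighbors = neighbor_indices.shape[1]
    # h_j: features of each neighbor, gathered by index -> (N, K, C)
    neighbor_nodes = node_features[neighbor_indices]
    # h_i: each node's own features, repeated across its K neighbors -> (N, K, C)
    self_nodes = jnp.repeat(node_features[:, None, :], num_neighbors, axis=1)
    # Concatenate [h_i, e_ij, h_j] along the feature axis -> (N, K, 3C)
    h = jnp.concatenate([self_nodes, edge_features, neighbor_nodes], axis=-1)
    # Three dense layers with GELU between them, projecting back to C features
    h = jax.nn.gelu(h @ params["W1"] + params["b1"])
    h = jax.nn.gelu(h @ params["W2"] + params["b2"])
    return h @ params["W3"] + params["b3"]


# Example with random (hypothetical) weights
N, K, C = 5, 3, 8
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W1": jax.random.normal(k1, (3 * C, C)) * 0.1, "b1": jnp.zeros(C),
    "W2": jax.random.normal(k2, (C, C)) * 0.1, "b2": jnp.zeros(C),
    "W3": jax.random.normal(k3, (C, C)) * 0.1, "b3": jnp.zeros(C),
}
nodes = jnp.zeros((N, C))
edges = jnp.ones((N, K, C))
idx = jnp.array([[1, 2, 3]] * N)
msg = encode_messages(nodes, edges, idx, params)
print(msg.shape)  # (5, 3, 8)
```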

prxteinmpnn.model.encoder.encoder_normalize(message, node_features, edge_features, neighbor_indices, mask, layer_params, scale=30.0)

Normalize the encoded messages and update node features.

Parameters:
  • message (Float[Array, 'num_atoms num_neighbors num_features']) – Encoded messages of shape (num_atoms, num_neighbors, num_features).

  • node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).

  • edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features of shape (num_atoms, num_neighbors, num_features).

  • neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighboring nodes of shape (num_atoms, num_neighbors).

  • mask (Int[Array, 'num_residues num_atoms']) – Atom mask indicating valid atoms.

  • layer_params (PyTree[str, 'P']) – Model parameters for the normalization layer.

  • scale (float) – Scaling factor used to normalize the messages aggregated over neighbors.

Returns:

Updated node features and edge features after normalization.

Return type:

tuple
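In the original ProteinMPNN encoder layer, messages are summed over the neighbor axis, divided by a fixed scale (30 by default, matching the scale default here), added to the node features as a residual, and layer-normalized; padded positions are then zeroed with the mask. The sketch below assumes the same node update and, for brevity, omits the dropout, the position-wise feed-forward sublayer, and the edge-feature update that the full layer also performs. update_nodes, layer_norm, and the g1/b1 parameter keys are hypothetical, and the sketch takes a per-atom mask of shape (num_atoms,), a simplification of the two-dimensional mask documented above.

```python
import jax.numpy as jnp


def layer_norm(x, gamma, beta, eps=1e-5):
    """Layer normalization over the last (feature) axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + eps) + beta


def update_nodes(message, node_features, node_mask, params, scale=30.0):
    """Hypothetical residual node update from aggregated neighbor messages.

    message:       (num_atoms, num_neighbors, C)
    node_features: (num_atoms, C)
    node_mask:     (num_atoms,) with 1 for valid atoms and 0 for padding
    """
    # Sum over the neighbor axis, then divide by a fixed constant rather than
    # by the actual number of neighbors.
    aggregated = message.sum(axis=1) / scale
    h = layer_norm(node_features + aggregated, params["g1"], params["b1"])
    # Zero out padded positions so they do not contribute downstream.
    return h * node_mask[:, None]


C = 4
msg = jnp.arange(2 * 3 * C, dtype=jnp.float32).reshape(2, 3, C)
nodes = jnp.zeros((2, C))
mask = jnp.array([1.0, 0.0])  # second atom is padding
params = {"g1": jnp.ones(C), "b1": jnp.zeros(C)}
out = update_nodes(msg, nodes, mask, params)
print(out.shape)  # (2, 4)
```

Dividing by a constant instead of the neighbor count keeps the update magnitude independent of how many neighbors happen to be valid, at the cost of tying the model to the neighborhood size it was trained with.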

prxteinmpnn.model.encoder.make_encode_layer(attention_mask_enum)

Create a function that runs a single encoder layer, with attention masking selected by attention_mask_enum.

Return type:

Callable[..., Float[Array, 'num_atoms num_neighbors num_features']]

Parameters:

attention_mask_enum (MaskedAttentionEnum)

prxteinmpnn.model.encoder.initialize_node_features(model_parameters, edge_features)

Initialize node features from the model parameters and edge features.

Return type:

Int[Array, 'num_atoms num_features']

Parameters:
  • model_parameters (PyTree[str, 'P'])

  • edge_features (Float[Array, 'num_atoms num_neighbors num_features'])
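In the original ProteinMPNN model, node features start as zeros whose width matches the edge embedding, and the encoder layers populate them from edge messages. Assuming initialize_node_features follows that convention, the computation reduces to the sketch below; init_node_features is a hypothetical name, and this version ignores model_parameters entirely, which the real function may not.

```python
import jax.numpy as jnp


def init_node_features(edge_features):
    """Hypothetical zero initialization sized to match the edge feature width.

    edge_features: (num_atoms, num_neighbors, C) -> returns (num_atoms, C)
    """
    num_atoms, _, hidden_dim = edge_features.shape
    return jnp.zeros((num_atoms, hidden_dim), dtype=edge_features.dtype)


edges = jnp.ones((10, 30, 128))
h = init_node_features(edges)
print(h.shape)  # (10, 128)
```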

prxteinmpnn.model.encoder.setup_encoder(model_parameters, attention_mask_enum, num_encoder_layers=3)

Set up the encoder parameters and the encoder layer function.

Return type:

tuple[PyTree[str, 'P'], Callable[..., Float[Array, 'num_atoms num_neighbors num_features']]]

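Parameters:
  • model_parameters (PyTree[str, 'P'])

  • attention_mask_enum (MaskedAttentionEnum)

  • num_encoder_layers (int)
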
prxteinmpnn.model.encoder.make_encoder(model_parameters, attention_mask_enum, num_encoder_layers=3, scale=30.0)

Create a function that runs the full encoder with the given model parameters and returns the updated node and edge features.

Return type:

Callable[..., tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features']]]

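Parameters:
  • model_parameters (PyTree[str, 'P'])

  • attention_mask_enum (MaskedAttentionEnum)

  • num_encoder_layers (int)

  • scale (float)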
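make_encoder returns a function that runs the encoder with the given model parameters. The sketch below illustrates that factory pattern under stated assumptions: per-layer parameters are stacked along a leading axis, the layers are applied in order with jax.lax.scan, and the returned function is compiled with jax.jit. toy_layer, make_toy_encoder, and the "W" key are hypothetical placeholders; a real encoder layer would also thread edge features, neighbor indices, and the mask through the scan.

```python
import jax
import jax.numpy as jnp


def toy_layer(node_features, layer_params):
    """Hypothetical stand-in for one encoder layer (residual tanh update)."""
    return node_features + jnp.tanh(node_features @ layer_params["W"])


def make_toy_encoder(stacked_params):
    """Close over stacked per-layer parameters and return a jitted encoder."""

    def step(h, layer_params):
        # Carry the node features into the next layer; emit no per-layer output.
        return toy_layer(h, layer_params), None

    @jax.jit
    def encoder(node_features):
        h_final, _ = jax.lax.scan(step, node_features, stacked_params)
        return h_final

    return encoder


num_encoder_layers, C = 3, 4
stacked = {"W": jnp.zeros((num_encoder_layers, C, C))}  # one (C, C) matrix per layer
encoder = make_toy_encoder(stacked)
out = encoder(jnp.ones((5, C)))
print(out.shape)  # (5, 4)
```

Using scan rather than a Python loop keeps the traced program the same size regardless of num_encoder_layers, which shortens compile time for deeper encoders.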