decoder
Decoder module for the PrxteinMPNN model.
- prxteinmpnn.model.decoder.decoder_parameter_pytree(model_parameters, num_decoder_layers=3)
Make the model weights accessible as a PyTree.
- Parameters:
model_parameters (PyTree[str, 'P']) – Model parameters for the decoder.
num_decoder_layers (int) – Number of decoder layers to set up.
- Return type:
PyTree[str, 'P']
- Returns:
Decoder parameters as a PyTree.
- prxteinmpnn.model.decoder.embed_sequence(model_parameters, one_hot_sequence)
Embed a one-hot encoded sequence.
- Parameters:
model_parameters (PyTree[str, 'P'])
one_hot_sequence (Float[Array, 'num_residues num_classes'])
- Return type:
Int[Array, 'num_atoms num_features']
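As an illustration of what embedding a one-hot sequence amounts to, the sketch below multiplies a one-hot matrix by an embedding table, which is equivalent to selecting one row of the table per residue. It uses NumPy rather than JAX to keep the example small. The name `embedding_weights`, the 21-class alphabet, and the feature width are assumptions made for this example and are not taken from PrxteinMPNN.

```python
import numpy as np

num_classes, num_features = 21, 8  # assumed sizes, for illustration only

rng = np.random.default_rng(0)
# Hypothetical embedding table: one row of features per amino-acid class.
embedding_weights = rng.normal(size=(num_classes, num_features))

# Integer-encoded toy sequence and its one-hot form (num_residues, num_classes).
sequence = np.array([0, 5, 20, 5])
one_hot_sequence = np.eye(num_classes)[sequence]

# Multiplying by the table picks out the row for each residue's class.
embedded = one_hot_sequence @ embedding_weights

# The same result via direct row lookup.
looked_up = embedding_weights[sequence]
```

The matrix-product form is differentiable with respect to `one_hot_sequence`, which is one reason a one-hot input can be preferable to integer indices.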
- prxteinmpnn.model.decoder.initialize_conditional_decoder(one_hot_sequence, node_features, edge_features, neighbor_indices, layer_params)
Initialize the decoder with node and edge features.
- Parameters:
one_hot_sequence (Float[Array, 'num_residues num_classes']) – One-hot encoded sequence of shape (num_residues, num_classes).
node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features of shape (num_atoms, num_neighbors, num_features).
neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighboring nodes of shape (num_atoms, num_neighbors).
layer_params (PyTree[str, 'P']) – Model parameters for the embedding layer.
- Return type:
tuple[Float[Array, 'num_atoms num_neighbors num_features'], Float[Array, 'num_residues num_neighbors num_features']]
- Returns:
A tuple of node-edge features and sequence-edge features.
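A common way to combine per-node features with per-edge features in message-passing models is to gather each node's neighbors through `neighbor_indices` and concatenate them onto the edge features. The NumPy sketch below shows that pattern only; whether `initialize_conditional_decoder` does exactly this is an assumption, and the helper names `gather_neighbors` and `concat_neighbor_features` are hypothetical.

```python
import numpy as np


def gather_neighbors(features, neighbor_indices):
    """Look up neighbor rows: result has shape (num_nodes, num_neighbors, num_features)."""
    return features[neighbor_indices]


def concat_neighbor_features(features, edge_features, neighbor_indices):
    """Append each neighbor's features to the matching edge features."""
    neighbor_features = gather_neighbors(features, neighbor_indices)
    return np.concatenate([edge_features, neighbor_features], axis=-1)


num_nodes, num_neighbors, num_features = 5, 3, 4  # toy sizes
rng = np.random.default_rng(0)
node_features = rng.normal(size=(num_nodes, num_features))
edge_features = rng.normal(size=(num_nodes, num_neighbors, num_features))
neighbor_indices = rng.integers(0, num_nodes, size=(num_nodes, num_neighbors))

# Last axis grows from num_features to 2 * num_features.
node_edge_features = concat_neighbor_features(
    node_features, edge_features, neighbor_indices
)
```

The same helper applied to an embedded sequence instead of `node_features` would give a sequence-edge tensor of matching layout.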
- prxteinmpnn.model.decoder.decode_message(node_features, edge_features, layer_params)
Decode node and edge features into messages.
- Parameters:
node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features of shape (num_atoms, num_neighbors, num_features).
layer_params (PyTree[str, 'P']) – Model parameters for the encoding layer.
- Returns:
Decoded messages of shape (num_atoms, num_neighbors, num_features).
- Return type:
Message
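To make the shapes above concrete, the following NumPy sketch computes one message per (node, neighbor) pair by tiling each node's features across its neighbors, concatenating them with the edge features, and passing the result through a small MLP. The three-layer MLP, the GELU activation, the weight layout, and the name `message_mlp` are all assumptions for illustration; the actual layer structure is defined by `layer_params` in the library and is not documented here.

```python
import numpy as np


def gelu(x):
    # Tanh approximation of the GELU activation.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def message_mlp(node_features, edge_features, weights):
    """Hypothetical message function: MLP over [node_i, edge_ij] per neighbor."""
    num_neighbors = edge_features.shape[1]
    # Repeat each node's features once per neighbor: (N, K, F_node).
    tiled = np.repeat(node_features[:, None, :], num_neighbors, axis=1)
    x = np.concatenate([tiled, edge_features], axis=-1)
    for i, (w, b) in enumerate(weights):
        x = x @ w + b
        if i < len(weights) - 1:  # no activation after the last layer
            x = gelu(x)
    return x


num_nodes, num_neighbors, node_dim, edge_dim = 5, 3, 4, 6  # toy sizes
rng = np.random.default_rng(0)
node_features = rng.normal(size=(num_nodes, node_dim))
edge_features = rng.normal(size=(num_nodes, num_neighbors, edge_dim))
weights = [
    (rng.normal(size=(node_dim + edge_dim, node_dim)), np.zeros(node_dim)),
    (rng.normal(size=(node_dim, node_dim)), np.zeros(node_dim)),
    (rng.normal(size=(node_dim, node_dim)), np.zeros(node_dim)),
]

messages = message_mlp(node_features, edge_features, weights)
```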
- prxteinmpnn.model.decoder.decoder_normalize(message, node_features, mask, layer_params, scale=30.0)
Normalize the decoded messages and update node features.
- Parameters:
message (Float[Array, 'num_atoms num_neighbors num_features']) – Decoded messages of shape (num_atoms, num_neighbors, num_features).
node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).
mask (Int[Array, 'num_residues num_atoms']) – Atom mask indicating valid atoms.
layer_params (PyTree[str, 'P']) – Model parameters for the normalization layer.
scale (float) – Scaling factor for normalization.
- Return type:
Int[Array, 'num_atoms num_features']
- Returns:
Updated node features after normalization.
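One plausible reading of this step, sketched below in NumPy, is the usual message-passing update: sum the messages over the neighbor axis, divide by `scale`, add the result to the node features as a residual, apply layer normalization, and zero out masked nodes. This is an assumption about the update rule, not a transcription of the library code. The one-dimensional `node_mask` is a simplification of the documented (num_residues, num_atoms) mask, and `gain`, `bias`, and `aggregate_and_normalize` are hypothetical names.

```python
import numpy as np


def layer_norm(x, gain, bias, eps=1e-5):
    """Normalize the last axis to zero mean and (approximately) unit variance."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gain * (x - mean) / np.sqrt(var + eps) + bias


def aggregate_and_normalize(message, node_features, node_mask, gain, bias, scale=30.0):
    """Hypothetical update: scaled neighbor sum, residual add, layer norm, mask."""
    aggregated = message.sum(axis=1) / scale  # (num_nodes, num_features)
    updated = layer_norm(node_features + aggregated, gain, bias)
    return updated * node_mask[:, None]  # masked nodes become all zeros


num_nodes, num_neighbors, num_features = 5, 3, 4  # toy sizes
rng = np.random.default_rng(0)
message = rng.normal(size=(num_nodes, num_neighbors, num_features))
node_features = rng.normal(size=(num_nodes, num_features))
node_mask = np.array([1.0, 1.0, 0.0, 1.0, 1.0])  # third node is invalid
gain, bias = np.ones(num_features), np.zeros(num_features)

updated = aggregate_and_normalize(message, node_features, node_mask, gain, bias)
```

Dividing by a fixed `scale` instead of the neighbor count keeps the magnitude of the aggregated message independent of how many neighbors are actually valid.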
- prxteinmpnn.model.decoder.make_decode_layer(attention_mask_type)
Create a function to run the decoder with given model parameters.
- Parameters:
attention_mask_type (Literal['none', 'cross', 'conditional'] | None)
- Return type:
Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_atoms num_atoms'], PyTree[str, 'P'], float]]], Int[Array, 'num_atoms num_features']] | Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], PyTree[str, 'P'], float]]], Int[Array, 'num_atoms num_features']]
- prxteinmpnn.model.decoder.setup_decoder(model_parameters, attention_mask_type, decoding_approach, num_decoder_layers=3)
Set up the decoder parameters and initial node features.
- prxteinmpnn.model.decoder.make_decoder(model_parameters, attention_mask_type, decoding_approach='unconditional', num_decoder_layers=3, scale=30.0)
Create a function to run the decoder with given model parameters.
- Return type:
Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms']]]], Int[Array, 'num_atoms num_features']]
| Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_atoms num_atoms']]]], Int[Array, 'num_atoms num_features']]
| Callable[[Unpack[tuple[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Float | None, Float[Array, 'num_residues num_classes'] | None]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]]
| Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Int[Array, 'num_residues']]]], Int[Array, 'num_atoms num_features']]
- Parameters: