decoder
Decoder module for the PrxteinMPNN model.
- class prxteinmpnn.model.decoder.DecodingEnum(value)
Bases: Enum
Enum for different types of decoders.
- CONDITIONAL = 'conditional'
- UNCONDITIONAL = 'unconditional'
- AUTOREGRESSIVE = 'autoregressive'
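Because each member carries a string value, a decoding mode can be looked up directly from a configuration string. The snippet below is a local stand-in that mirrors the three documented members; it is written for illustration and does not import prxteinmpnn.

```python
from enum import Enum


class DecodingEnum(Enum):
    """Local stand-in mirroring the documented members (illustrative only)."""

    CONDITIONAL = "conditional"
    UNCONDITIONAL = "unconditional"
    AUTOREGRESSIVE = "autoregressive"


# Look a member up from a configuration string, e.g. a command-line flag value.
mode = DecodingEnum("autoregressive")
print(mode is DecodingEnum.AUTOREGRESSIVE)  # True
print([m.value for m in DecodingEnum])
```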
- prxteinmpnn.model.decoder.decoder_parameter_pytree(model_parameters, num_decoder_layers=3)
Make the model weights accessible as a PyTree.
- Parameters:
model_parameters (PyTree[str, 'P']) – Model parameters for the decoder.
num_decoder_layers (int) – Number of decoder layers to set up.
- Return type:
PyTree[str, 'P']
- Returns:
Decoder parameters as a PyTree.
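One common way to expose the weights of several identical layers as a single PyTree is to stack each leaf along a new leading "layer" axis. The sketch below assumes a flat dictionary keyed by hypothetical names such as "dec_layer_0/w1"; the actual key layout of model_parameters is not documented on this page, and this is not a copy of decoder_parameter_pytree.

```python
import jax
import jax.numpy as jnp


def stack_layer_params(flat_params, num_layers, prefix="dec_layer"):
    """Stack per-layer weights along a new leading axis (hypothetical key scheme)."""
    per_layer = []
    for i in range(num_layers):
        # Hypothetical naming convention: "<prefix>_<i>/<weight_name>".
        layer_prefix = f"{prefix}_{i}/"
        per_layer.append(
            {
                key[len(layer_prefix):]: value
                for key, value in flat_params.items()
                if key.startswith(layer_prefix)
            }
        )
    # tree_map over several trees zips matching leaves; jnp.stack adds the layer axis.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *per_layer)


flat = {
    f"dec_layer_{i}/{name}": jnp.full(shape, float(i))
    for i in range(3)
    for name, shape in [("w1", (4, 4)), ("b1", (4,))]
}
stacked = stack_layer_params(flat, num_layers=3)
print(stacked["w1"].shape, stacked["b1"].shape)  # (3, 4, 4) (3, 4)
```

Stacking keeps the number of leaves constant regardless of depth, which makes the tree easy to iterate over with jax.lax.scan.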
- prxteinmpnn.model.decoder.embed_sequence(model_parameters, one_hot_sequence)
Embed a one-hot encoded sequence.
- Return type:
Int[Array, 'num_atoms num_features']
- Parameters:
model_parameters (PyTree[str, 'P'])
one_hot_sequence (Float[Array, 'num_residues num_classes'])
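Multiplying a one-hot sequence by a (num_classes, num_features) weight matrix selects one row per residue, so the product is equivalent to an embedding lookup. This sketch assumes a single dense matrix named w_s, a hypothetical name for the relevant entry of model_parameters; whether embed_sequence applies any additional transform is not stated here.

```python
import jax
import jax.numpy as jnp

num_residues, num_classes, num_features = 5, 21, 8

# Hypothetical embedding matrix standing in for the entry in model_parameters.
w_s = jax.random.normal(jax.random.PRNGKey(0), (num_classes, num_features))

sequence = jnp.array([0, 3, 3, 7, 20])  # integer residue codes
one_hot_sequence = jax.nn.one_hot(sequence, num_classes)  # (num_residues, num_classes)

# One-hot @ matrix picks out one row of w_s per residue.
embedded = one_hot_sequence @ w_s  # (num_residues, num_features)

print(embedded.shape)  # (5, 8)
print(jnp.allclose(embedded, w_s[sequence]))  # True
```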
- prxteinmpnn.model.decoder.initialize_conditional_decoder(one_hot_sequence, node_features, edge_features, neighbor_indices, layer_params)
Initialize the decoder with node and edge features.
- Parameters:
one_hot_sequence (Float[Array, 'num_residues num_classes']) – One-hot encoded sequence of shape (num_residues, num_classes).
node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features of shape (num_atoms, num_neighbors, num_features).
neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighboring nodes of shape (num_atoms, num_neighbors).
layer_params (PyTree[str, 'P']) – Model parameters for the embedding layer.
- Return type:
tuple[Float[Array, 'num_atoms num_neighbors num_features'], Float[Array, 'num_residues num_neighbors num_features']]
- Returns:
A tuple of node-edge features and sequence-edge features.
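Pairing each edge with features of the node it points to is typically done by gathering with neighbor_indices and concatenating along the feature axis, as in the ProteinMPNN reference decoder. The helper name cat_neighbors_nodes and the choice of inputs below are assumptions modeled on that design; the sketch is not a reproduction of initialize_conditional_decoder.

```python
import jax
import jax.numpy as jnp


def cat_neighbors_nodes(nodes, edges, neighbor_indices):
    """Concatenate each edge with the features of the node it points to.

    nodes: (num_nodes, d_node); edges: (num_nodes, k, d_edge);
    neighbor_indices: (num_nodes, k) integer indices into `nodes`.
    Returns (num_nodes, k, d_edge + d_node).
    """
    gathered = nodes[neighbor_indices]  # advanced indexing -> (num_nodes, k, d_node)
    return jnp.concatenate([edges, gathered], axis=-1)


key_v, key_e, key_s, key_i = jax.random.split(jax.random.PRNGKey(0), 4)
num_nodes, k, d = 6, 3, 4
node_features = jax.random.normal(key_v, (num_nodes, d))
edge_features = jax.random.normal(key_e, (num_nodes, k, d))
sequence_embedding = jax.random.normal(key_s, (num_nodes, d))  # e.g. an embedded one-hot sequence
neighbor_indices = jax.random.randint(key_i, (num_nodes, k), 0, num_nodes)

node_edge = cat_neighbors_nodes(node_features, edge_features, neighbor_indices)
sequence_edge = cat_neighbors_nodes(sequence_embedding, edge_features, neighbor_indices)
print(node_edge.shape, sequence_edge.shape)  # (6, 3, 8) (6, 3, 8)
```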
- prxteinmpnn.model.decoder.decode_message(node_features, edge_features, layer_params)
Decode node and edge features into messages.
- Parameters:
node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).
edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features of shape (num_atoms, num_neighbors, num_features).
layer_params (PyTree[str, 'P']) – Model parameters for the decoder layer.
- Returns:
Decoded messages of shape (num_atoms, num_neighbors, num_features).
- Return type:
Message
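In the ProteinMPNN reference decoder layer, a message is formed by repeating each node's features across its neighbors, concatenating them with the edge features, and passing the result through a three-layer MLP with GELU activations. The sketch follows that pattern under stated assumptions: the parameter names w1, w2, w3 and b1, b2, b3 and the layer widths are hypothetical, not the keys used in layer_params.

```python
import jax
import jax.numpy as jnp


def decode_message_sketch(node_features, edge_features, params):
    """Compute per-edge messages (num_nodes, k, hidden) from nodes and edges.

    `params` is a hypothetical dict with keys w1, b1, w2, b2, w3, b3.
    """
    k = edge_features.shape[1]
    # Repeat each node's features once per neighbor: (num_nodes, k, d_node).
    expanded = jnp.repeat(node_features[:, None, :], k, axis=1)
    h = jnp.concatenate([expanded, edge_features], axis=-1)
    h = jax.nn.gelu(h @ params["w1"] + params["b1"])
    h = jax.nn.gelu(h @ params["w2"] + params["b2"])
    return h @ params["w3"] + params["b3"]


num_nodes, k, d_node, d_edge, hidden = 6, 3, 4, 8, 16
keys = jax.random.split(jax.random.PRNGKey(1), 5)
params = {
    "w1": jax.random.normal(keys[0], (d_node + d_edge, hidden)) * 0.1,
    "b1": jnp.zeros(hidden),
    "w2": jax.random.normal(keys[1], (hidden, hidden)) * 0.1,
    "b2": jnp.zeros(hidden),
    "w3": jax.random.normal(keys[2], (hidden, hidden)) * 0.1,
    "b3": jnp.zeros(hidden),
}
node_features = jax.random.normal(keys[3], (num_nodes, d_node))
edge_features = jax.random.normal(keys[4], (num_nodes, k, d_edge))

messages = decode_message_sketch(node_features, edge_features, params)
print(messages.shape)  # (6, 3, 16)
```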
- prxteinmpnn.model.decoder.decoder_normalize(message, node_features, mask, layer_params, scale=30.0)
Normalize the decoded messages and update the node features.
- Parameters:
message (Float[Array, 'num_atoms num_neighbors num_features']) – Decoded messages of shape (num_atoms, num_neighbors, num_features).
node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).
mask (Int[Array, 'num_residues num_atoms']) – Atom mask indicating valid atoms.
layer_params (PyTree[str, 'P']) – Model parameters for the normalization layer.
scale (float) – Scaling factor for normalization.
- Return type:
Int[Array, 'num_atoms num_features']
- Returns:
Updated node features after normalization.
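A common form of this step, following the ProteinMPNN reference decoder layer, sums the messages over the neighbor axis, divides by a constant scale (the documented default here is 30.0), adds the result to the node features as a residual, normalizes, and zeroes out padded nodes. The reference layer also applies a position-wise feed-forward block and dropout, which this sketch omits; the 1-D node_mask is a simplification of the documented (num_residues, num_atoms) mask, and the sketch is not claimed to match decoder_normalize exactly.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    """Layer normalization over the last axis, without learned scale/offset."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def normalize_sketch(message, node_features, node_mask, scale=30.0):
    """Aggregate messages over neighbors and update node features.

    message: (num_nodes, k, d); node_features: (num_nodes, d);
    node_mask: (num_nodes,) with 1 for valid nodes, 0 for padding (simplified mask).
    """
    # Sum over the neighbor axis and divide by a fixed scale.
    aggregated = message.sum(axis=1) / scale
    # Residual connection followed by normalization.
    updated = layer_norm(node_features + aggregated)
    # Zero out padded nodes so they do not leak into later layers.
    return updated * node_mask[:, None]


key_m, key_h = jax.random.split(jax.random.PRNGKey(2))
message = jax.random.normal(key_m, (5, 30, 8))
node_features = jax.random.normal(key_h, (5, 8))
node_mask = jnp.array([1, 1, 1, 0, 0])

updated = normalize_sketch(message, node_features, node_mask)
print(updated.shape)  # (5, 8)
```

Dividing by a fixed scale rather than the actual neighbor count keeps the update magnitude independent of how many neighbors are masked out.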
- prxteinmpnn.model.decoder.make_decode_layer(attention_mask_enum)
Create a function that runs a single decoder layer. The returned function receives the layer parameters and scale as arguments; depending on attention_mask_enum, it either also takes a boolean attention mask (first form below) or does not (second form).
- Return type:
Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_atoms num_atoms'], PyTree[str, 'P'], float]]], Int[Array, 'num_atoms num_features']]
| Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], PyTree[str, 'P'], float]]], Int[Array, 'num_atoms num_features']]
- Parameters:
attention_mask_enum (MaskedAttentionEnum)
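The union return type points to a factory pattern: the enum is inspected once, outside any traced code, and a closure with the matching argument list is returned. The sketch uses a hypothetical two-member AttentionMode enum in place of MaskedAttentionEnum (whose members are not listed on this page), a toy layer body, and a simplified (num_nodes, k) attention mask rather than the documented (num_atoms, num_atoms) one.

```python
from enum import Enum

import jax.numpy as jnp


class AttentionMode(Enum):
    """Hypothetical stand-in for MaskedAttentionEnum."""

    NONE = "none"
    MASKED = "masked"


def make_layer(mode):
    """Return a layer function whose signature depends on `mode`."""

    def aggregate(node_features, edge_features, scale):
        # Toy layer body: residual update from summed edge features.
        return node_features + edge_features.sum(axis=1) / scale

    if mode is AttentionMode.NONE:

        def layer(node_features, edge_features, mask, layer_params, scale):
            # layer_params is unused by this toy body.
            return aggregate(node_features, edge_features, scale) * mask[:, None]

        return layer

    def masked_layer(node_features, edge_features, mask, attention_mask, layer_params, scale):
        # attention_mask: (num_nodes, k) booleans; masked-out edges contribute nothing.
        edges = edge_features * attention_mask[..., None]
        return aggregate(node_features, edges, scale) * mask[:, None]

    return masked_layer


h = jnp.ones((4, 2))
e = jnp.ones((4, 3, 2))
mask = jnp.ones(4)
plain = make_layer(AttentionMode.NONE)(h, e, mask, None, 30.0)
attn = jnp.array([[True, False, False]] * 4)
masked = make_layer(AttentionMode.MASKED)(h, e, mask, attn, None, 30.0)
```

Branching in the factory means each returned function has a fixed signature, which avoids Python-level conditionals inside jitted code.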
- prxteinmpnn.model.decoder.setup_decoder(model_parameters, attention_mask_enum, decoding_enum, num_decoder_layers=3)
Set up the decoder parameters and initial node features.
- Return type:
tuple[PyTree[str, 'P'], Callable[..., Float[Array, 'num_atoms num_neighbors num_features']]]
- Parameters:
model_parameters (PyTree[str, 'P'])
attention_mask_enum (MaskedAttentionEnum)
decoding_enum (DecodingEnum)
num_decoder_layers (int)
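When num_decoder_layers identical layers share one parameter PyTree whose leaves carry a leading layer axis, the stack can be applied with jax.lax.scan instead of a Python loop, keeping compile time independent of depth. This sketch assumes that stacked layout and uses a hypothetical layer_fn; it illustrates the technique rather than what setup_decoder returns.

```python
import jax
import jax.numpy as jnp


def layer_fn(node_features, layer_params):
    """Hypothetical single layer: residual update with a tanh MLP."""
    update = jnp.tanh(node_features @ layer_params["w"] + layer_params["b"])
    return node_features + update


def run_layers(node_features, stacked_params):
    """Apply every layer in `stacked_params` (leading axis = layer) in order."""

    def step(carry, params_for_layer):
        return layer_fn(carry, params_for_layer), None

    final, _ = jax.lax.scan(step, node_features, stacked_params)
    return final


num_layers, num_nodes, d = 3, 5, 8
key_w, key_h = jax.random.split(jax.random.PRNGKey(3))
stacked_params = {
    "w": jax.random.normal(key_w, (num_layers, d, d)) * 0.1,
    "b": jnp.zeros((num_layers, d)),
}
h0 = jax.random.normal(key_h, (num_nodes, d))
h_scan = run_layers(h0, stacked_params)

# Reference: the same computation as an explicit Python loop.
h_loop = h0
for i in range(num_layers):
    h_loop = layer_fn(h_loop, jax.tree_util.tree_map(lambda x: x[i], stacked_params))
print(jnp.allclose(h_scan, h_loop, atol=1e-5))  # True
```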
- prxteinmpnn.model.decoder._check_enums(attention_mask_enum, decoding_enum)
Check if the provided enums are valid.
- Return type:
- Parameters:
attention_mask_enum (MaskedAttentionEnum)
decoding_enum (DecodingEnum)
- prxteinmpnn.model.decoder.make_decoder(model_parameters, attention_mask_enum, decoding_enum=DecodingEnum.UNCONDITIONAL, num_decoder_layers=3, scale=30.0)
Create a function to run the decoder with given model parameters.
- Return type:
Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms']]]], Int[Array, 'num_atoms num_features']]
| Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_atoms num_atoms']]]], Int[Array, 'num_atoms num_features']]
| Callable[[Unpack[tuple[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], float]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]]
| Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Int[Array, 'num_residues']]]], Int[Array, 'num_atoms num_features']]
- Parameters:
model_parameters (PyTree[str, 'P'])
attention_mask_enum (MaskedAttentionEnum)
decoding_enum (DecodingEnum)
num_decoder_layers (int)
scale (float)
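One form of the make_decoder return type takes a PRNG key and a Bool[Array, 'num_residues num_residues'] mask and returns a pair of (num_residues, num_classes) arrays, which is consistent with autoregressive sampling. In ProteinMPNN-style decoding such a mask is typically derived from a random decoding order: residue i may see residue j only if j is decoded earlier. The sketch builds a mask under that convention; the function name is hypothetical, and the exact convention make_decoder expects is not stated on this page.

```python
import jax
import jax.numpy as jnp


def autoregressive_mask(key, num_residues):
    """Build a (num_residues, num_residues) mask from a random decoding order.

    mask[i, j] is True when residue j is decoded strictly before residue i.
    """
    order = jax.random.permutation(key, num_residues)  # order[t] = residue decoded at step t
    # position[r] = step at which residue r is decoded (inverse permutation).
    position = jnp.zeros(num_residues, dtype=jnp.int32).at[order].set(
        jnp.arange(num_residues, dtype=jnp.int32)
    )
    return position[None, :] < position[:, None], order


mask, order = autoregressive_mask(jax.random.PRNGKey(4), 6)
print(mask.sum())  # 15, i.e. 6 * 5 / 2 visible pairs
```

The first residue in the order sees nothing and the last sees every other residue, so each row of the mask grows by one as decoding proceeds.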