model
ProteinMPNN implemented in a functional JAX interface.
- class prxteinmpnn.model.MaskedAttentionEnum(value)
Bases: Enum
Enum for the different types of masked attention.
- NONE = 'none'
- CROSS = 'cross'
- CONDITIONAL = 'conditional'
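Because the members are string-valued, a mode read from a config file can be turned back into a member by calling the enum with its value. The sketch below uses a hypothetical stand-in class that mirrors the three documented members, plus a hypothetical `parse_attention_mode` helper; it relies only on the standard library `enum` module, not on prxteinmpnn itself.

```python
from enum import Enum


# Hypothetical stand-in mirroring the documented members of
# prxteinmpnn.model.MaskedAttentionEnum.
class MaskedAttentionEnum(Enum):
    NONE = "none"
    CROSS = "cross"
    CONDITIONAL = "conditional"


def parse_attention_mode(name: str) -> MaskedAttentionEnum:
    """Map a config string such as "cross" to its enum member (case-insensitive).

    Raises ValueError for strings that match no member's value.
    """
    return MaskedAttentionEnum(name.lower())


mode = parse_attention_mode("Cross")
print(mode)        # MaskedAttentionEnum.CROSS
print(mode.value)  # cross
```

Looking members up by value keeps config files readable ("cross" rather than an integer code) and fails loudly on typos.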
- prxteinmpnn.model.dense_layer(layer_parameters, node_features)
Apply a dense layer to node features.
- Return type:
Int[Array, 'num_atoms num_features']
- Parameters:
layer_parameters (PyTree[str, 'P'])
node_features (Int[Array, 'num_atoms num_features'])
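A dense layer applied to node features is an affine map shared across all nodes, so a single matrix product handles every node at once without an explicit `vmap`. The sketch below is a hypothetical re-implementation, not the library's code; in particular, the parameter keys `"kernel"` and `"bias"` are assumptions, and the real PyTree layout may use different names.

```python
import jax
import jax.numpy as jnp


def dense_layer(layer_parameters: dict, node_features: jax.Array) -> jax.Array:
    """Affine map applied independently to every node: x @ W + b."""
    # Assumed (hypothetical) layout: "kernel" has shape (in_features, out_features)
    # and "bias" has shape (out_features,).
    return node_features @ layer_parameters["kernel"] + layer_parameters["bias"]


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = {
    "kernel": jax.random.normal(key_w, (128, 64)),
    "bias": jnp.zeros((64,)),
}
x = jax.random.normal(key_x, (10, 128))  # 10 nodes, 128 features each
y = dense_layer(params, x)
print(y.shape)  # (10, 64)
```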
- prxteinmpnn.model.extract_features(prng_key, model_parameters, structure_coordinates, mask, residue_index, chain_index, k_neighbors=48, augment_eps=0.0)
Extract features from protein structure coordinates.
- Parameters:
prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – JAX random key for stochastic operations.
model_parameters (PyTree[str, 'P']) – Model parameters for the feature extraction.
structure_coordinates (Float[Array, 'num_residues num_atoms 3']) – Atomic coordinates of the protein structure.
mask (Int[Array, 'num_residues num_atoms']) – Mask indicating valid atoms in the structure.
residue_index (Int[Array, 'num_residues']) – Residue index of each residue.
chain_index (Int[Array, 'num_residues']) – Chain index of each residue.
k_neighbors (int) – Maximum number of neighbors to consider for each atom.
augment_eps (float) – Standard deviation of the Gaussian noise used for augmentation; 0.0 disables the noise.
- Returns:
edge_features – Edge features after concatenation and normalization.
edge_indices – Indices of neighboring atoms.
- Return type:
tuple (edge_features, edge_indices)
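The `k_neighbors` and `augment_eps` parameters suggest a k-nearest-neighbor graph built on optionally noised coordinates. The sketch below shows only that neighbor-selection step, under stated assumptions: it uses a single hypothetical representative coordinate per residue (for example Cα) instead of the full `num_residues num_atoms 3` array, and it computes no edge features. `knn_edges` is a hypothetical helper, not part of prxteinmpnn.

```python
import jax
import jax.numpy as jnp


def knn_edges(prng_key, ca_coords, residue_mask, k_neighbors=48, augment_eps=0.0):
    """Return (distances, indices) of the k nearest valid residues for each residue.

    Hypothetical sketch: one coordinate per residue, no edge features.
    """
    # Gaussian noise augmentation; augment_eps is the standard deviation.
    coords = ca_coords + augment_eps * jax.random.normal(prng_key, ca_coords.shape)
    # Pairwise Euclidean distances, shape (N, N); the epsilon keeps sqrt finite.
    diff = coords[:, None, :] - coords[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-6)
    # Pairs involving a masked-out residue get infinite distance.
    pair_mask = residue_mask[:, None] * residue_mask[None, :]
    dist = jnp.where(pair_mask > 0, dist, jnp.inf)
    # lax.top_k picks the largest values along the last axis, so rank -dist.
    k = min(k_neighbors, ca_coords.shape[0])
    neg_dist, indices = jax.lax.top_k(-dist, k)
    return -neg_dist, indices


ca = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
mask = jnp.array([1, 1, 1, 0])  # the last residue is invalid
d, idx = knn_edges(jax.random.PRNGKey(0), ca, mask, k_neighbors=2)
print(idx[:3].tolist())  # [[0, 1], [1, 0], [2, 1]]
```

Each valid residue lists itself first, because its self-distance is the smallest. Using `top_k` on negated distances avoids a full sort and keeps the output shape static, which suits `jit`.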
- prxteinmpnn.model.final_projection(model_parameters, node_features)
Convert node features to logits.
- Parameters:
model_parameters (PyTree[str, 'P']) – Model parameters for the final projection.
node_features (Int[Array, 'num_atoms num_features']) – Node features after the last MPNN layer.
- Returns:
The final logits for the model.
- Return type:
Logits
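Logits can be converted into per-residue probabilities with a softmax and into a sequence with an argmax. The sketch below assumes the projection is a plain affine map. `project_to_logits`, `logits_to_sequence`, the parameter keys, and the 20-letter `ALPHABET` ordering are all hypothetical; the library's actual class order and vocabulary size may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical class ordering; the real model's vocabulary may differ.
ALPHABET = "ACDEFGHIKLMNPQRSTVWY"


def project_to_logits(params: dict, node_features: jax.Array) -> jax.Array:
    """Assumed affine projection from node features to per-class logits."""
    return node_features @ params["kernel"] + params["bias"]


def logits_to_sequence(logits: jax.Array) -> tuple[jax.Array, str]:
    """Return per-residue probabilities and the argmax sequence."""
    probs = jax.nn.softmax(logits, axis=-1)
    best = jnp.argmax(logits, axis=-1)
    return probs, "".join(ALPHABET[i] for i in best.tolist())


# Toy input: one-hot features for classes 0, 2 and 19, projected by a scaled identity.
features = jax.nn.one_hot(jnp.array([0, 2, 19]), 20)
params = {"kernel": 5.0 * jnp.eye(20), "bias": jnp.zeros(20)}
probs, seq = logits_to_sequence(project_to_logits(params, features))
print(seq)  # ADY
```

Argmax decoding is shown only because it is deterministic. Sampling from the softmax, for example with `jax.random.categorical`, is the usual alternative when diverse sequences are wanted.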
- prxteinmpnn.model.make_decoder(model_parameters, attention_mask_enum, decoding_enum=DecodingEnum.UNCONDITIONAL, num_decoder_layers=3, scale=30.0)
Create a function that runs the decoder with the given model parameters.
- Return type:
The returned callable has one of the following signatures:
Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms']]]], Int[Array, 'num_atoms num_features']]
| Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_atoms num_atoms']]]], Int[Array, 'num_atoms num_features']]
| Callable[[Unpack[tuple[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], float]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]]
| Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Int[Array, 'num_residues']]]], Int[Array, 'num_atoms num_features']]
- Parameters:
model_parameters (PyTree[str, 'P'])
attention_mask_enum (MaskedAttentionEnum)
decoding_enum (DecodingEnum)
num_decoder_layers (int)
scale (float)
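`make_decoder` follows a factory pattern: parameters and configuration are fixed once, and the returned function takes only the per-call arrays. The toy sketch below illustrates that pattern with a hypothetical message-passing layer. It assumes, as in the original ProteinMPNN, that `scale` divides the summed neighbor messages; the layer body, the `"w"` parameter key, and `make_toy_decoder` itself are illustrative assumptions, not the library's decoder.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def make_toy_decoder(layer_params: list, scale: float = 30.0) -> Callable:
    """Close over parameters and return a jitted function of the inputs only."""

    def decoder_layer(params, node_features, edge_features, neighbor_indices):
        # Gather each node's neighbor features: (N, K, F).
        neighbor_nodes = node_features[neighbor_indices]
        # Messages from [neighbor features, edge features]: (N, K, F + E) @ (F + E, F).
        messages = jnp.concatenate([neighbor_nodes, edge_features], axis=-1) @ params["w"]
        # Sum over neighbors and divide by a fixed scale (assumed ProteinMPNN-style).
        update = jnp.sum(jax.nn.gelu(messages), axis=1) / scale
        return node_features + update

    @jax.jit
    def run(node_features, edge_features, neighbor_indices):
        # The Python loop over layers is unrolled once at trace time.
        for params in layer_params:
            node_features = decoder_layer(params, node_features, edge_features, neighbor_indices)
        return node_features

    return run


n, k, f, e = 5, 3, 4, 6
keys = jax.random.split(jax.random.PRNGKey(0), 4)
layers = [{"w": 0.1 * jax.random.normal(kk, (f + e, f))} for kk in keys[:3]]  # 3 layers
decoder = make_toy_decoder(layers)
h = jax.random.normal(keys[3], (n, f))
edges = jnp.ones((n, k, e))
nbrs = jnp.array([[0, 1, 2], [1, 0, 2], [2, 1, 3], [3, 2, 4], [4, 3, 2]])
out = decoder(h, edges, nbrs)
print(out.shape)  # (5, 4)
```

Closing over the parameters makes them compile-time constants of the jitted function. That is convenient for repeated inference with fixed weights, but any change to the parameters requires building a new decoder.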