model
A ProteinMPNN implementation with a functional JAX interface.
- prxteinmpnn.model.dense_layer(layer_parameters, node_features)
Apply a dense layer to node features.
- Return type:
Int[Array, 'num_atoms num_features']
- Parameters:
layer_parameters (PyTree[str, 'P'])
node_features (Int[Array, 'num_atoms num_features'])
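A dense layer of this kind is an affine map applied independently to every row of node_features. The minimal sketch below assumes the layer's PyTree is a dict with hypothetical keys "w" (weight matrix) and "b" (bias); prxteinmpnn's actual parameter layout is not documented here and may differ. It uses floating-point features for illustration.

```python
import jax
import jax.numpy as jnp


def dense_layer_sketch(layer_parameters: dict, node_features: jax.Array) -> jax.Array:
    """Apply x @ w + b to every row (node) of `node_features`.

    The keys "w" and "b" are hypothetical; prxteinmpnn's PyTree may use
    different names or nesting.
    """
    w = layer_parameters["w"]  # (num_features_in, num_features_out)
    b = layer_parameters["b"]  # (num_features_out,)
    return node_features @ w + b


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = {"w": jax.random.normal(key_w, (128, 64)), "b": jnp.zeros((64,))}
nodes = jax.random.normal(key_x, (10, 128))  # 10 nodes, 128 features each

# Parameters are an ordinary PyTree argument, so the function can be jitted as-is.
out = jax.jit(dense_layer_sketch)(params, nodes)
print(out.shape)  # (10, 64)
```

Passing parameters explicitly, rather than storing them on an object, is what keeps the function pure and compatible with jax.jit, jax.grad, and jax.vmap.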
- prxteinmpnn.model.extract_features(prng_key, model_parameters, structure_coordinates, mask, residue_index, chain_index, k_neighbors=48, backbone_noise=None)
Extract features from protein structure coordinates.
- Parameters:
prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – JAX random key for stochastic operations.
model_parameters (PyTree[str, 'P']) – Model parameters for the feature extraction.
structure_coordinates (Float[Array, 'num_residues num_atoms 3']) – Atomic coordinates of the protein structure.
mask (Int[Array, 'num_residues num_atoms']) – Mask indicating valid atoms in the structure.
residue_index (Int[Array, 'num_residues']) – Residue index of each residue.
chain_index (Int[Array, 'num_residues']) – Chain index of each residue.
k_neighbors (int) – Maximum number of neighbors to consider for each atom.
backbone_noise (Float[Array, 'n'] | None) – Standard deviation for Gaussian noise augmentation.
- Returns:
edge_features – Edge features after concatenation and normalization.
edge_indices – Indices of neighboring atoms.
- Return type:
A tuple (edge_features, edge_indices).
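extract_features builds a k-nearest-neighbor graph over the structure. One common way to pick neighbors in JAX is to compute all pairwise distances and take the k smallest with jax.lax.top_k. The sketch below is hypothetical, not prxteinmpnn's implementation: it works on a single representative coordinate per residue (for example the C-alpha atom, which would be index 1 under the common N, CA, C, O atom ordering, itself an assumption), adds optional Gaussian noise scaled by backbone_noise, and omits masking and the edge-feature computation entirely.

```python
import jax
import jax.numpy as jnp


def knn_edges_sketch(prng_key, ca_coordinates, k_neighbors=48, backbone_noise=None):
    """Return (distances, indices) of the k nearest residues to each residue.

    Hypothetical illustration only: no masking, no edge features.
    ca_coordinates has shape (num_residues, 3).
    """
    if backbone_noise is not None:
        # Gaussian augmentation: backbone_noise is the standard deviation.
        noise = jax.random.normal(prng_key, ca_coordinates.shape)
        ca_coordinates = ca_coordinates + backbone_noise * noise
    # Pairwise Euclidean distances, shape (num_residues, num_residues);
    # the small epsilon keeps sqrt differentiable at zero distance.
    diffs = ca_coordinates[:, None, :] - ca_coordinates[None, :, :]
    distances = jnp.sqrt(jnp.sum(diffs**2, axis=-1) + 1e-6)
    # top_k returns the largest values, so negate to select the smallest distances.
    k = min(k_neighbors, ca_coordinates.shape[0])
    neg_distances, edge_indices = jax.lax.top_k(-distances, k)
    return -neg_distances, edge_indices


# Five residues on a straight line, 3.8 apart.
ca = jnp.stack([jnp.arange(5.0) * 3.8, jnp.zeros(5), jnp.zeros(5)], axis=-1)
distances, edge_indices = knn_edges_sketch(jax.random.PRNGKey(0), ca, k_neighbors=3)
print(edge_indices[0])  # [0 1 2]: each residue is its own nearest neighbor
```

Clamping k to the number of residues keeps top_k valid for structures shorter than k_neighbors.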
- prxteinmpnn.model.final_projection(model_parameters, node_features)
Convert node features to logits.
- Parameters:
model_parameters (PyTree[str, 'P']) – Model parameters for the final projection.
node_features (Int[Array, 'num_atoms num_features']) – Node features after the last MPNN layer.
- Returns:
The final logits for the model.
- Return type:
Logits
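final_projection maps each residue's features to one logit per amino-acid class. The sketch below assumes a plain linear map with hypothetical "w" and "b" parameters and 21 output classes (the alphabet size of the original ProteinMPNN: 20 amino acids plus X). It then shows standard ways to consume logits with jax.nn.softmax, jnp.argmax, and temperature-scaled sampling via jax.random.categorical; the temperature value is an arbitrary illustration.

```python
import jax
import jax.numpy as jnp

NUM_CLASSES = 21  # assumed: 20 amino acids plus "X", as in the original ProteinMPNN


def final_projection_sketch(model_parameters: dict, node_features: jax.Array) -> jax.Array:
    """Linear map from per-residue features to per-residue logits.

    "w" and "b" are hypothetical parameter names.
    """
    return node_features @ model_parameters["w"] + model_parameters["b"]


key_x, key_s = jax.random.split(jax.random.PRNGKey(0))
params = {"w": jnp.zeros((128, NUM_CLASSES)), "b": jnp.zeros((NUM_CLASSES,))}
nodes = jax.random.normal(key_x, (7, 128))  # 7 residues

logits = final_projection_sketch(params, nodes)  # (7, 21)
probs = jax.nn.softmax(logits, axis=-1)  # each row sums to 1
greedy = jnp.argmax(logits, axis=-1)  # most likely class per residue
temperature = 0.1
# Dividing by a temperature below 1 sharpens the distribution before sampling.
sampled = jax.random.categorical(key_s, logits / temperature, axis=-1)  # (7,)
```

With all-zero weights every logit is 0, so each row of probs is uniform at 1/21.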
- prxteinmpnn.model.make_decoder(model_parameters, attention_mask_type, decoding_approach='unconditional', num_decoder_layers=3, scale=30.0)
Create a function that runs the decoder with the given model parameters.
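make_decoder is a factory: it takes parameters and configuration and hands back a callable, which is why its return type is a union of callable signatures. The toy sketch below illustrates that pattern only and is not prxteinmpnn's decoder. Everything in it is an assumption for illustration: the parameter key "w", the one-matrix message function, and the "conditional" option name. It also shows one common way to build an autoregressive mask from a decoding order, the kind of Bool[Array, 'num_residues num_residues'] input that appears in the return type.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def autoregressive_mask(decoding_order: jax.Array) -> jax.Array:
    """Return an (n, n) bool mask; mask[i, j] means residue j is decoded before residue i."""
    # rank[r] is the step at which residue r is decoded (inverse permutation).
    rank = jnp.argsort(decoding_order)
    return rank[None, :] < rank[:, None]


def make_decoder_sketch(
    model_parameters: dict, decoding_approach: str = "unconditional", scale: float = 30.0
) -> Callable:
    """Hypothetical factory: resolve configuration once, return a specialised closure."""
    w = model_parameters["w"]  # hypothetical key, (num_features, num_features)

    def message_pass(node_features, message_mask):
        # Each node adds the sum of transformed features from the nodes the mask
        # allows, divided by `scale` to keep the update small.
        messages = message_mask.astype(node_features.dtype) @ (node_features @ w)
        return node_features + messages / scale

    if decoding_approach == "unconditional":
        def decode(node_features):
            n = node_features.shape[0]
            return message_pass(node_features, jnp.ones((n, n), dtype=bool))
        return decode

    if decoding_approach == "conditional":  # hypothetical option name
        def decode(node_features, decoding_order):
            return message_pass(node_features, autoregressive_mask(decoding_order))
        return decode

    raise ValueError(f"unknown decoding_approach: {decoding_approach!r}")


params = {"w": jnp.eye(4)}
nodes = jnp.ones((3, 4))
full = make_decoder_sketch(params)(nodes)  # every entry is 1 + 3/30 = 1.1
ar = make_decoder_sketch(params, "conditional")(nodes, jnp.array([2, 0, 1]))
```

Because decoding_approach is resolved in plain Python before the closure is returned, each returned function has one fixed signature and traces cleanly under jax.jit.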
- Return type:
One of the following callables:
Callable[[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms']], Int[Array, 'num_atoms num_features']] |
Callable[[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_atoms num_atoms']], Int[Array, 'num_atoms num_features']] |
Callable[[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Float | None, Float[Array, 'num_residues num_classes'] | None], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]] |
Callable[[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Int[Array, 'num_residues']], Int[Array, 'num_atoms num_features']]
- Parameters: