model

ProteinMPNN implemented in a functional JAX interface.

class prxteinmpnn.model.MaskedAttentionEnum(value)

Bases: Enum

Enum for different types of masked attention.

NONE = 'none'
CROSS = 'cross'
CONDITIONAL = 'conditional'

prxteinmpnn.model.dense_layer(layer_parameters, node_features)
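Because MaskedAttentionEnum is a standard Python Enum, members can be looked up by value: MaskedAttentionEnum("cross") is MaskedAttentionEnum.CROSS. This reference does not say how each variant constructs its mask, so the following sketch shows only the generic mechanism that masked attention relies on: disallowed positions receive zero attention weight. The function name masked_softmax is hypothetical and not part of prxteinmpnn.

```python
import jax
import jax.numpy as jnp


def masked_softmax(logits: jax.Array, mask: jax.Array, axis: int = -1) -> jax.Array:
    """Softmax over `axis` that gives zero weight where `mask` is False.

    Hypothetical helper illustrating masked attention; not the library's code.
    """
    # Replace disallowed logits with the most negative finite value so exp() underflows to 0.
    masked_logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(masked_logits, axis=axis)
    # Zero masked positions explicitly; this also handles rows with no allowed entries.
    return jnp.where(mask, weights, 0.0)


logits = jnp.array([[1.0, 2.0, 3.0]])
mask = jnp.array([[True, True, False]])
weights = masked_softmax(logits, mask)
```

Here the third weight is exactly 0, and the first two are renormalized so that they sum to 1.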

Apply a dense layer to node features.

Return type:

Int[Array, 'num_atoms num_features']

Parameters:
  • layer_parameters (PyTree[str, 'P'])

  • node_features (Int[Array, 'num_atoms num_features'])
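In the generic sense, a dense (fully connected) layer applies the same affine map to every row of the node feature matrix. The sketch below assumes a hypothetical parameter layout with keys "w" and "b"; the real layer_parameters PyTree may use different keys and may include extra steps such as an activation, which this sketch omits.

```python
import jax
import jax.numpy as jnp


def dense_layer_sketch(layer_parameters: dict, node_features: jax.Array) -> jax.Array:
    """Affine map applied to every node (row) of `node_features`.

    The keys "w" ([in_features, out_features]) and "b" ([out_features]) are
    hypothetical; the actual parameter PyTree may be structured differently.
    """
    return node_features @ layer_parameters["w"] + layer_parameters["b"]


num_nodes, in_features, out_features = 5, 8, 4
params = {
    "w": jax.random.normal(jax.random.PRNGKey(0), (in_features, out_features)),
    "b": jnp.zeros((out_features,)),
}
node_features = jnp.ones((num_nodes, in_features))
out = dense_layer_sketch(params, node_features)
print(out.shape)  # (5, 4)
```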

prxteinmpnn.model.extract_features(prng_key, model_parameters, structure_coordinates, mask, residue_index, chain_index, k_neighbors=48, augment_eps=0.0)

Extract features from protein structure coordinates.

Parameters:
  • prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – JAX random key for stochastic operations.

  • model_parameters (PyTree[str, 'P']) – Model parameters for the feature extraction.

  • structure_coordinates (Float[Array, 'num_residues num_atoms 3']) – Atomic coordinates of the protein structure.

  • mask (Int[Array, 'num_residues num_atoms']) – Mask indicating valid atoms in the structure.

  • residue_index (Int[Array, 'num_residues']) – Residue index for each residue.

  • chain_index (Int[Array, 'num_residues']) – Chain index for each residue.

  • k_neighbors (int) – Maximum number of neighbors to consider for each atom.

  • augment_eps (float) – Standard deviation of the Gaussian noise used for augmentation; with the default of 0.0, no noise is added.

Returns:

  • edge_features – Edge features after concatenation and normalization.

  • edge_indices – Indices of neighboring atoms.

Return type:

tuple (edge_features, edge_indices)
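The neighbor-selection part of feature extraction can be sketched as a k-nearest-neighbor search with optional coordinate noise. This assumes, as in the original PyTorch ProteinMPNN, that neighbors are ranked by distance between one representative point per residue (for example Cα) and that augment_eps scales Gaussian noise added to the coordinates. The function knn_neighbors is hypothetical, and the distance-based edge features (and their normalization) are not shown.

```python
import jax
import jax.numpy as jnp


def knn_neighbors(
    key: jax.Array,
    ca_coordinates: jax.Array,  # [num_residues, 3], assumed one point per residue
    residue_mask: jax.Array,    # [num_residues], 1.0 for valid residues
    k_neighbors: int = 48,
    augment_eps: float = 0.0,
) -> tuple[jax.Array, jax.Array]:
    """Hypothetical sketch: distances and indices of the k nearest valid residues."""
    # Optional Gaussian noise on the coordinates (a no-op when augment_eps == 0).
    noise = augment_eps * jax.random.normal(key, ca_coordinates.shape)
    coords = ca_coordinates + noise
    # Pairwise Euclidean distances, [num_residues, num_residues].
    diff = coords[:, None, :] - coords[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-6)
    # Pairs involving a masked residue get infinite distance, so they rank last.
    pair_mask = residue_mask[:, None] * residue_mask[None, :]
    dist = jnp.where(pair_mask > 0, dist, jnp.inf)
    # k cannot exceed the number of residues.
    k = min(k_neighbors, coords.shape[0])
    # top_k picks the largest values, so negate distances to get the nearest ones.
    neg_dist, neighbor_indices = jax.lax.top_k(-dist, k)
    return -neg_dist, neighbor_indices


coords = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
residue_mask = jnp.array([1.0, 1.0, 1.0, 0.0])
distances, neighbor_indices = knn_neighbors(jax.random.PRNGKey(0), coords, residue_mask, k_neighbors=2)
print(neighbor_indices[0].tolist())  # [0, 1]
```

In this sketch each residue's nearest neighbor is itself (distance close to 0), and the masked fourth residue never appears among the neighbors of valid residues.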

prxteinmpnn.model.final_projection(model_parameters, node_features)

Convert node features to logits.

Parameters:
  • model_parameters (PyTree[str, 'P']) – Model parameters for the final projection.

  • node_features (Int[Array, 'num_atoms num_features']) – Node features after the last MPNN layer.

Returns:

The final logits for the model.

Return type:

Logits
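Logits of shape [num_residues, num_classes] are typically turned into a sequence either greedily (argmax) or by temperature sampling. The sketch below shows both with standard JAX calls; decode_logits and the temperature parameter are hypothetical illustrations, not part of the prxteinmpnn API.

```python
import jax
import jax.numpy as jnp


def decode_logits(key: jax.Array, logits: jax.Array, temperature: float = 1.0):
    """Hypothetical helper: log-probabilities, greedy classes, and temperature samples."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)  # [num_residues, num_classes]
    greedy = jnp.argmax(logits, axis=-1)             # [num_residues]
    # Lower temperatures sharpen the distribution toward the greedy choice.
    sampled = jax.random.categorical(key, logits / temperature, axis=-1)  # [num_residues]
    return log_probs, greedy, sampled


logits = jnp.array([[2.0, 0.5, -1.0], [0.0, 3.0, 1.0]])
log_probs, greedy, sampled = decode_logits(jax.random.PRNGKey(0), logits, temperature=0.1)
print(greedy.tolist())  # [0, 1]
```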

prxteinmpnn.model.make_decoder(model_parameters, attention_mask_enum, decoding_enum=DecodingEnum.UNCONDITIONAL, num_decoder_layers=3, scale=30.0)

Create a function to run the decoder with given model parameters.

Return type:

One of the following four callables:

  • Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms']]]], Int[Array, 'num_atoms num_features']]

  • Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_atoms num_atoms']]]], Int[Array, 'num_atoms num_features']]

  • Callable[[Unpack[tuple[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], float]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]]

  • Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Int[Array, 'num_residues']]]], Int[Array, 'num_atoms num_features']]

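Parameters:
  • model_parameters (PyTree[str, 'P'])

  • attention_mask_enum (MaskedAttentionEnum)

  • decoding_enum (DecodingEnum)

  • num_decoder_layers (int)

  • scale (float)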
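Several decoder variants above take a Bool[Array, 'num_residues num_residues'] argument. Assuming, as in the original ProteinMPNN, that decoding proceeds in a (possibly random) order and each residue may only attend to residues decoded earlier, such a mask can be built from the decoding order as sketched below. The function autoregressive_mask is hypothetical, and whether the library's mask follows exactly this convention is an assumption.

```python
import jax
import jax.numpy as jnp


def autoregressive_mask(decoding_order: jax.Array) -> jax.Array:
    """Hypothetical sketch: Bool [num_residues, num_residues] mask where entry (i, j)
    is True when residue j is decoded before residue i in `decoding_order`."""
    num_residues = decoding_order.shape[0]
    # step[r] = position of residue r in the decoding order (inverse permutation).
    step = jnp.zeros(num_residues, dtype=jnp.int32).at[decoding_order].set(
        jnp.arange(num_residues, dtype=jnp.int32)
    )
    return step[None, :] < step[:, None]


order = jnp.array([2, 0, 1])  # residue 2 is decoded first, then 0, then 1
mask = autoregressive_mask(order)
print(mask.astype(int).tolist())  # [[0, 0, 1], [1, 0, 1], [0, 0, 0]]

# A random decoding order can be drawn with jax.random.permutation.
random_order = jax.random.permutation(jax.random.PRNGKey(0), 5)
random_mask = autoregressive_mask(random_order)
```

Whatever the order, the mask is a strictly "earlier-only" relation, so it always contains n(n-1)/2 True entries and no residue attends to itself.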
prxteinmpnn.model.make_encoder(model_parameters, attention_mask_enum, num_encoder_layers=3, scale=30.0)

Create a function to run the encoder with given model parameters.
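Both make_encoder and make_decoder follow a factory pattern: the parameters are bound once and a function over the feature arrays is returned. The sketch below illustrates that pattern with a simplified message-passing layer. It assumes, as in the original PyTorch ProteinMPNN, that scale (default 30.0) divides the sum of neighbor messages instead of averaging them. The factory name, the {"w", "b"} parameter layout, and the single-matrix message function are hypothetical simplifications (no layer norm or feed-forward block).

```python
import jax
import jax.numpy as jnp


def make_message_passing_stack(layer_parameters, num_layers=3, scale=30.0):
    """Hypothetical factory: bind parameters once, return a jitted forward pass."""
    params = layer_parameters[:num_layers]

    @jax.jit
    def run(node_features, edge_features, neighbor_indices):
        # node_features: [N, C]; edge_features: [N, K, C]; neighbor_indices: [N, K]
        h = node_features
        for p in params:  # Python loop is unrolled once at trace time
            h_neighbors = h[neighbor_indices]  # gather neighbor features, [N, K, C]
            h_self = jnp.broadcast_to(h[:, None, :], h_neighbors.shape)
            message_in = jnp.concatenate([h_self, edge_features, h_neighbors], axis=-1)  # [N, K, 3C]
            messages = jax.nn.gelu(message_in @ p["w"] + p["b"])  # [N, K, C]
            # Sum over neighbors, then divide by the constant `scale`.
            h = h + messages.sum(axis=1) / scale
        return h

    return run


N, K, C = 5, 3, 4
keys = jax.random.split(jax.random.PRNGKey(0), 3)
params = [{"w": 0.1 * jax.random.normal(k, (3 * C, C)), "b": jnp.zeros(C)} for k in keys]
encoder = make_message_passing_stack(params)
neighbor_indices = jnp.array([[0, 1, 2], [1, 0, 2], [2, 1, 3], [3, 2, 4], [4, 3, 2]])
edge_features = jnp.ones((N, K, C))
out = encoder(jnp.ones((N, C)), edge_features, neighbor_indices)
print(out.shape)  # (5, 4)
```

Dividing by a fixed constant rather than by the actual neighbor count keeps the update magnitude independent of how many neighbors are present after masking.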

Return type:

Callable[..., tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features']]]

Parameters: