model

ProteinMPNN implemented in a functional JAX interface.

prxteinmpnn.model.dense_layer(layer_parameters, node_features)

Apply a dense layer to node features.

Return type:

Int[Array, 'num_atoms num_features']

Parameters:
  • layer_parameters (PyTree[str, 'P'])

  • node_features (Int[Array, 'num_atoms num_features'])
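A minimal sketch of what a dense layer over node features computes. It assumes `layer_parameters` is a dict holding a weight `"w"` of shape `(in_features, out_features)` and a bias `"b"` of shape `(out_features,)`. These key names and the float dtype are illustrative assumptions, not the library's actual parameter layout. Because the function is pure, it can be passed straight to `jax.jit`.

```python
import jax
import jax.numpy as jnp


def dense_layer_sketch(layer_parameters: dict, node_features: jax.Array) -> jax.Array:
    """Row-wise affine map: (num_nodes, in_features) -> (num_nodes, out_features).

    The "w"/"b" keys are hypothetical; they do not describe prxteinmpnn's PyTree.
    """
    return node_features @ layer_parameters["w"] + layer_parameters["b"]


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = {"w": jax.random.normal(key_w, (128, 64)), "b": jnp.zeros((64,))}
x = jax.random.normal(key_x, (10, 128))  # 10 nodes with 128 features each

# Parameters travel as an ordinary PyTree argument, so jit needs no special handling.
y = jax.jit(dense_layer_sketch)(params, x)
print(y.shape)  # (10, 64)
```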

prxteinmpnn.model.extract_features(prng_key, model_parameters, structure_coordinates, mask, residue_index, chain_index, k_neighbors=48, backbone_noise=None)

Extract features from protein structure coordinates.

Parameters:
  • prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – JAX random key for stochastic operations.

  • model_parameters (PyTree[str, 'P']) – Model parameters for the feature extraction.

  • structure_coordinates (Float[Array, 'num_residues num_atoms 3']) – Atomic coordinates of the protein structure.

  • mask (Int[Array, 'num_residues num_atoms']) – Mask indicating valid atoms in the structure.

  • residue_index (Int[Array, 'num_residues']) – Residue index of each residue.

  • chain_index (Int[Array, 'num_residues']) – Chain index of each residue.

  • k_neighbors (int) – Maximum number of neighbors to consider for each atom.

  • backbone_noise (Float[Array, 'n'] | None) – Standard deviation of the Gaussian noise used to augment the backbone coordinates.

Returns:

  • edge_features – Edge features after concatenation and normalization.

  • edge_indices – Indices of neighboring atoms.

Return type:

tuple of (edge_features, edge_indices)
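The neighbor search behind `edge_indices` can be sketched as a k-nearest-neighbor graph built from pairwise distances, with optional Gaussian noise applied to the coordinates first. This is a simplified illustration under stated assumptions, not the library's implementation: it uses a single coordinate per residue (for example the C-alpha atom) rather than the full `(num_residues, num_atoms, 3)` array, treats `backbone_noise` as a scalar standard deviation, and omits the learned edge-feature projection that uses `model_parameters`. The function name `knn_with_backbone_noise` is hypothetical.

```python
import jax
import jax.numpy as jnp


def knn_with_backbone_noise(prng_key, ca_coordinates, residue_mask,
                            k_neighbors=48, backbone_noise=None):
    """Hypothetical k-NN sketch: returns (distances, neighbor_indices), each (N, k)."""
    # Optional augmentation: add N(0, backbone_noise**2) noise to every coordinate.
    if backbone_noise is not None:
        noise = jax.random.normal(prng_key, ca_coordinates.shape)
        ca_coordinates = ca_coordinates + backbone_noise * noise
    # Pairwise Euclidean distances, shape (N, N); the epsilon keeps sqrt differentiable.
    diff = ca_coordinates[:, None, :] - ca_coordinates[None, :, :]
    distances = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-6)
    # Send pairs involving masked-out residues to the end of every neighbor list.
    pair_mask = residue_mask[:, None] * residue_mask[None, :]
    distances = jnp.where(pair_mask > 0, distances, jnp.inf)
    k = min(k_neighbors, ca_coordinates.shape[0])
    # top_k on negated distances picks the k nearest; each residue finds itself first.
    neg_distances, neighbor_indices = jax.lax.top_k(-distances, k)
    return -neg_distances, neighbor_indices


# Five residues evenly spaced along the x axis.
coords = jnp.stack([jnp.arange(5.0), jnp.zeros(5), jnp.zeros(5)], axis=-1)
key = jax.random.PRNGKey(0)
dists, idx = knn_with_backbone_noise(key, coords, jnp.ones(5), k_neighbors=3)
# Masking residue 1 removes it from residue 0's neighbor list.
_, idx_masked = knn_with_backbone_noise(key, coords, jnp.array([1, 0, 1, 1, 1]), k_neighbors=3)
print(idx[0], idx_masked[0])
```

Capping `k` at the number of residues keeps `top_k` valid for chains shorter than `k_neighbors`.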

prxteinmpnn.model.final_projection(model_parameters, node_features)

Convert node features to logits.

Parameters:
  • model_parameters (PyTree[str, 'P']) – Model parameters for the final projection.

  • node_features (Int[Array, 'num_atoms num_features']) – Node features after the last MPNN layer.

Returns:

The final logits for the model.

Return type:

Logits
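A sketch of how per-residue logits are typically consumed downstream: a linear projection to the residue alphabet, then a log-softmax for probabilities and an argmax for a greedy sequence. The `"w"`/`"b"` parameter keys, the function name `final_projection_sketch`, and the alphabet size of 21 (20 amino acids plus an unknown token, as in the original ProteinMPNN) are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def final_projection_sketch(projection_parameters, node_features):
    """Hypothetical projection: (num_residues, hidden) -> (num_residues, num_classes)."""
    return node_features @ projection_parameters["w"] + projection_parameters["b"]


num_residues, hidden, num_classes = 7, 128, 21  # 21 classes is an assumption
key_w, key_h = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w": 0.01 * jax.random.normal(key_w, (hidden, num_classes)),
    "b": jnp.zeros((num_classes,)),
}
h = jax.random.normal(key_h, (num_residues, hidden))

logits = final_projection_sketch(params, h)
log_probs = jax.nn.log_softmax(logits, axis=-1)  # normalized over the class axis
probs = jnp.exp(log_probs)
greedy_sequence = jnp.argmax(logits, axis=-1)    # one class index per residue
```

Working in log space with `log_softmax` avoids the underflow that a separate `softmax` followed by `log` can produce.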

prxteinmpnn.model.make_decoder(model_parameters, attention_mask_type, decoding_approach='unconditional', num_decoder_layers=3, scale=30.0)

Create a function that runs the decoder with the given model parameters.

Return type:

One of:
  • Callable[[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms']], Int[Array, 'num_atoms num_features']]

  • Callable[[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_atoms num_atoms']], Int[Array, 'num_atoms num_features']]

  • Callable[[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Float | None, Float[Array, 'num_residues num_classes'] | None], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]]

  • Callable[[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Int[Array, 'num_residues']], Int[Array, 'num_atoms num_features']]

Parameters:
  • model_parameters (ModelParameters)

  • attention_mask_type (MaskedAttentionType | None)

  • decoding_approach (DecodingApproach)

  • num_decoder_layers (int)

  • scale (float)
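Several decoder signatures take a `Bool[Array, 'num_residues num_residues']` argument. In ProteinMPNN-style autoregressive decoding, such a mask is commonly derived from a random decoding order, so that each residue attends only to residues decoded before it. The sketch below builds that kind of mask. Whether prxteinmpnn constructs or interprets its mask this way is an assumption not confirmed by this page, and `autoregressive_mask` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def autoregressive_mask(prng_key, num_residues):
    """Hypothetical helper: returns (decoding_order, mask) with mask of shape (N, N)."""
    # decoding_order[t] is the residue decoded at step t.
    decoding_order = jax.random.permutation(prng_key, num_residues)
    # Inverse permutation: position[i] is the step at which residue i is decoded.
    position = jnp.argsort(decoding_order)
    # mask[i, j] is True when residue j is decoded strictly before residue i,
    # so residue i may condition on residue j's already-chosen identity.
    mask = position[None, :] < position[:, None]
    return decoding_order, mask


order, mask = autoregressive_mask(jax.random.PRNGKey(0), 6)
print(order)
print(mask.astype(jnp.int32))
```

The first residue in the order has an all-False row, and the last has N - 1 True entries. The mask is therefore strictly lower-triangular after permuting rows and columns into decoding order.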

prxteinmpnn.model.make_encoder(model_parameters, attention_mask_type=None, num_encoder_layers=3, scale=30.0)

Create a function that runs the encoder with the given model parameters.

Return type:

Callable[..., tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features']]]

Parameters:
  • model_parameters (ModelParameters)

  • attention_mask_type (MaskedAttentionType | None)

  • num_encoder_layers (int)

  • scale (float)
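The `make_*` functions follow a factory pattern: they take parameters and configuration once and return a callable that closes over them. The sketch below imitates that pattern with a toy message-passing encoder. It makes several simplifying assumptions. It uses one weight matrix per layer (a hypothetical layout), returns only node features instead of the documented `(node_features, edge_features)` tuple, and ignores `attention_mask_type`. It also assumes that `scale` divides the messages summed over neighbors, as in the reference PyTorch ProteinMPNN layers. `make_encoder_sketch` is not part of the library.

```python
from typing import Callable, Sequence

import jax
import jax.numpy as jnp


def make_encoder_sketch(layer_weights: Sequence[jax.Array],
                        num_encoder_layers: int = 3,
                        scale: float = 30.0) -> Callable:
    """Return a jitted closure over layer_weights (hypothetical layout)."""
    if len(layer_weights) < num_encoder_layers:
        raise ValueError("need one weight matrix per encoder layer")

    def encoder(node_features, neighbor_indices):
        h = node_features
        for layer in range(num_encoder_layers):  # static loop, unrolled by jit
            # Gather neighbor features: (num_nodes, k, hidden).
            neighbor_h = h[neighbor_indices]
            messages = jax.nn.gelu(neighbor_h @ layer_weights[layer])
            # Sum messages over neighbors and damp them by `scale`.
            h = h + jnp.sum(messages, axis=1) / scale  # residual update
        return h

    return jax.jit(encoder)


num_nodes, k, hidden = 6, 3, 16
keys = jax.random.split(jax.random.PRNGKey(0), 4)
weights = [0.1 * jax.random.normal(kk, (hidden, hidden)) for kk in keys[:3]]
h0 = jax.random.normal(keys[3], (num_nodes, hidden))
nbrs = jnp.array([[(i + d) % num_nodes for d in range(k)] for i in range(num_nodes)])

encoder = make_encoder_sketch(weights)
h_out = encoder(h0, nbrs)
```

Closing over the parameters keeps call sites short, and `jax.jit` compiles the whole layer stack once per input shape.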