decoder

Decoder module for the PrxteinMPNN model.

class prxteinmpnn.model.decoder.DecodingEnum(value)

Bases: Enum

Enum for different types of decoders.

CONDITIONAL = 'conditional'
UNCONDITIONAL = 'unconditional'
AUTOREGRESSIVE = 'autoregressive'
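
Because the members carry string values, a decoding mode can be looked up from a plain string such as a configuration value. The sketch below uses a local stand-in enum that mirrors the three members listed above; it is not imported from prxteinmpnn, and the `mode_from_config` helper is hypothetical.

```python
from enum import Enum


class DecodingEnum(Enum):
    """Local stand-in mirroring the documented members (not the library class)."""

    CONDITIONAL = "conditional"
    UNCONDITIONAL = "unconditional"
    AUTOREGRESSIVE = "autoregressive"


def mode_from_config(value: str) -> DecodingEnum:
    """Hypothetical helper: map a config string to an enum member."""
    try:
        # Enum lookup by value; lower-casing tolerates "Conditional" etc.
        return DecodingEnum(value.lower())
    except ValueError as err:
        valid = ", ".join(member.value for member in DecodingEnum)
        raise ValueError(
            f"Unknown decoding mode {value!r}; expected one of: {valid}"
        ) from err
```

Looking members up by value keeps configuration files free of Python identifiers, and an unknown string fails early with the list of accepted values.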

prxteinmpnn.model.decoder.decoder_parameter_pytree(model_parameters, num_decoder_layers=3)

Make the model weights accessible as a PyTree.

Parameters:
  • model_parameters (PyTree[str, 'P']) – Model parameters for the decoder.

  • num_decoder_layers (int) – Number of decoder layers to set up.

Return type:

PyTree[str, 'P']

Returns:

Decoder parameters as a PyTree.
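
The internal layout of the returned PyTree is not documented here. As a rough illustration of what "accessible as a PyTree" can mean in JAX, the sketch below stacks hypothetical per-layer parameter dictionaries (the keys `w` and `b` are invented for this example) so that a single layer can be selected by index.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-layer parameters; the real key names are an assumption.
layers = [
    {"w": jnp.full((2, 2), float(i)), "b": jnp.zeros((2,))}
    for i in range(3)
]

# Stack matching leaves across layers: each leaf gains a leading layer axis.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *layers)

# Select the parameters of layer 1 by indexing every leaf.
layer_1 = jax.tree_util.tree_map(lambda leaf: leaf[1], stacked)
```

A stacked layout like this is one common choice because it lets the layers be iterated with `jax.lax.scan`; whether the library does this is not stated in the source.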

prxteinmpnn.model.decoder.embed_sequence(model_parameters, one_hot_sequence)

Embed a one-hot encoded sequence.

Return type:

Int[Array, 'num_atoms num_features']

Parameters:
  • model_parameters (PyTree[str, 'P'])

  • one_hot_sequence (Float[Array, 'num_residues num_classes'])
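
The embedding step is not described in detail here. A minimal sketch, assuming the embedding is a linear projection with a weight matrix of shape (num_classes, num_features): multiplying a one-hot row by that matrix selects the matching row of weights. The `embedding_weights` array is a placeholder, not a parameter name from the library.

```python
import jax
import jax.numpy as jnp

num_classes, num_features = 21, 8

# Placeholder weights; the real parameter names and shapes are an assumption.
embedding_weights = jax.random.normal(
    jax.random.PRNGKey(0), (num_classes, num_features)
)

# One-hot encode an integer sequence of residue classes.
sequence = jnp.array([0, 5, 20, 5])
one_hot_sequence = jax.nn.one_hot(sequence, num_classes)  # (4, num_classes)

# A one-hot row times the weight matrix picks out one row of the weights.
embedded = one_hot_sequence @ embedding_weights  # (4, num_features)
```

The result is identical to indexing `embedding_weights[sequence]`; the matrix form also accepts soft (non-one-hot) inputs such as probabilities.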

prxteinmpnn.model.decoder.initialize_conditional_decoder(one_hot_sequence, node_features, edge_features, neighbor_indices, layer_params)

Initialize the decoder with node and edge features.

Parameters:
  • one_hot_sequence (Float[Array, 'num_residues num_classes']) – One-hot encoded sequence of shape (num_residues, num_classes).

  • node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).

  • edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features of shape (num_atoms, num_neighbors, num_features).

  • neighbor_indices (Int[Array, 'num_atoms num_neighbors']) – Indices of neighboring nodes of shape (num_atoms, num_neighbors).

  • layer_params (PyTree[str, 'P']) – Model parameters for the embedding layer.

Return type:

tuple[Float[Array, 'num_atoms num_neighbors num_features'], Float[Array, 'num_residues num_neighbors num_features']]

Returns:

A tuple of node-edge features and sequence-edge features.
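
How the node and edge features are combined is not spelled out here. A common pattern in message-passing networks, shown below as an assumption rather than the library's implementation, is to gather each node's neighbor features through `neighbor_indices` and concatenate them with the edge features. Note that this sketch widens the last axis, whereas the documented annotation reuses the name `num_features`.

```python
import jax.numpy as jnp

num_nodes, num_neighbors, num_features = 5, 3, 4

node_features = jnp.arange(
    num_nodes * num_features, dtype=jnp.float32
).reshape(num_nodes, num_features)
edge_features = jnp.ones((num_nodes, num_neighbors, num_features))
neighbor_indices = jnp.array(
    [[1, 2, 3], [0, 2, 4], [0, 1, 3], [2, 4, 0], [3, 1, 0]]
)

# Gather: entry [i, k] holds the features of node neighbor_indices[i, k].
neighbor_node_features = node_features[neighbor_indices]

# Concatenate edge features with gathered neighbor features per (node, neighbor).
node_edge_features = jnp.concatenate(
    [edge_features, neighbor_node_features], axis=-1
)
```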

prxteinmpnn.model.decoder.decode_message(node_features, edge_features, layer_params)

Decode node and edge features into messages.

Parameters:
  • node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).

  • edge_features (Float[Array, 'num_atoms num_neighbors num_features']) – Edge features of shape (num_atoms, num_neighbors, num_features).

  • layer_params (PyTree[str, 'P']) – Model parameters for the decoding layer.

Return type:

Message

Returns:

Decoded messages of shape (num_atoms, num_neighbors, num_features).

prxteinmpnn.model.decoder.decoder_normalize(message, node_features, mask, layer_params, scale=30.0)

Normalize the decoded messages and update node features.

Parameters:
  • message (Float[Array, 'num_atoms num_neighbors num_features']) – Decoded messages of shape (num_atoms, num_neighbors, num_features).

  • node_features (Int[Array, 'num_atoms num_features']) – Node features of shape (num_atoms, num_features).

  • mask (Int[Array, 'num_residues num_atoms']) – Atom mask indicating valid atoms.

  • layer_params (PyTree[str, 'P']) – Model parameters for the normalization layer.

  • scale (float) – Scaling factor for normalization.

Return type:

Int[Array, 'num_atoms num_features']

Returns:

Updated node features after normalization.
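
The exact update rule is not given here. A minimal sketch of one plausible reading, assuming messages are summed over the neighbor axis, divided by `scale`, added to the node features as a residual, and layer-normalized without learned gain or bias. The one-dimensional `node_mask` is a simplification of the documented mask shape, and `aggregate_and_normalize` is a hypothetical name.

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    """Normalize the last axis to zero mean and unit variance (no learned gain/bias)."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def aggregate_and_normalize(message, node_features, node_mask, scale=30.0):
    """Hypothetical stand-in for the update described above."""
    # Sum messages over the neighbor axis and damp them by the scale factor.
    aggregated = jnp.sum(message, axis=-2) / scale
    # Residual update followed by normalization.
    updated = layer_norm(node_features + aggregated)
    # Zero out rows that the mask marks as invalid.
    return updated * node_mask[:, None]


message = jnp.ones((3, 2, 4))
node_features = jnp.arange(12.0).reshape(3, 4)
node_mask = jnp.array([1.0, 1.0, 0.0])
updated = aggregate_and_normalize(message, node_features, node_mask)
```

Dividing by a fixed `scale` instead of the neighbor count keeps the magnitude of the aggregated message independent of how many neighbors are masked out.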

prxteinmpnn.model.decoder.make_decode_layer(attention_mask_enum)

Create a function that runs a single decoder layer for the given attention mask type.

Return type:

Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_atoms num_atoms'], PyTree[str, 'P'], float]]], Int[Array, 'num_atoms num_features']] | Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], PyTree[str, 'P'], float]]], Int[Array, 'num_atoms num_features']]

Parameters:

attention_mask_enum (MaskedAttentionEnum)
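
The return type above is a union of two callables, one of which takes an extra boolean attention mask. A hedged sketch of this factory pattern follows; `MaskMode` and its members are hypothetical stand-ins (the real MaskedAttentionEnum members are not listed here), the layer bodies are toy updates, and the mask shape is simplified to (num_nodes, num_neighbors).

```python
from enum import Enum
from typing import Callable

import jax.numpy as jnp


class MaskMode(Enum):
    """Hypothetical stand-in for MaskedAttentionEnum."""

    NONE = "none"
    MASKED = "masked"


def make_layer(mode: MaskMode) -> Callable:
    """Return a layer function whose signature depends on the mask mode."""

    def unmasked_layer(node_features, message, scale=30.0):
        return node_features + jnp.sum(message, axis=-2) / scale

    def masked_layer(node_features, message, attention_mask, scale=30.0):
        # Drop messages from neighbors that the mask marks as not visible.
        visible = jnp.where(attention_mask[..., None], message, 0.0)
        return node_features + jnp.sum(visible, axis=-2) / scale

    if mode is MaskMode.NONE:
        return unmasked_layer
    if mode is MaskMode.MASKED:
        return masked_layer
    raise ValueError(f"Unsupported mask mode: {mode!r}")
```

Choosing the function once, outside the traced computation, avoids a Python branch inside code that is later passed to `jax.jit`.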

prxteinmpnn.model.decoder.setup_decoder(model_parameters, attention_mask_enum, decoding_enum, num_decoder_layers=3)

Set up the decoder parameters and initial node features.

Return type:

tuple[PyTree[str, 'P'], Callable[..., Float[Array, 'num_atoms num_neighbors num_features']]]

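Parameters:
  • model_parameters (PyTree[str, 'P'])

  • attention_mask_enum (MaskedAttentionEnum)

  • decoding_enum (DecodingEnum)

  • num_decoder_layers (int)

prxteinmpnn.model.decoder._check_enums(attention_mask_enum, decoding_enum)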

Check if the provided enums are valid.

Return type:

None

Parameters:
  • attention_mask_enum (MaskedAttentionEnum)

  • decoding_enum (DecodingEnum)

prxteinmpnn.model.decoder.make_decoder(model_parameters, attention_mask_enum, decoding_enum=DecodingEnum.UNCONDITIONAL, num_decoder_layers=3, scale=30.0)

Create a function to run the decoder with given model parameters.

Return type:

Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms']]]], Int[Array, 'num_atoms num_features']] | Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_atoms num_atoms']]]], Int[Array, 'num_atoms num_features']] | Callable[[Unpack[tuple[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], float]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]] | Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Int[Array, 'num_residues']]]], Int[Array, 'num_atoms num_features']]

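Parameters:
  • model_parameters (PyTree[str, 'P'])

  • attention_mask_enum (MaskedAttentionEnum)

  • decoding_enum (DecodingEnum)

  • num_decoder_layers (int)

  • scale (float)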