masked_attention

Masked attention utilities: an enum of masked-attention types and a function that applies an attention mask to messages.

class prxteinmpnn.model.masked_attention.MaskedAttentionEnum(value)

Bases: Enum

Enumeration of the supported masked-attention types.

NONE = 'none'
CROSS = 'cross'
CONDITIONAL = 'conditional'
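The enum members above are string-valued, so a mode name read from a config file can be turned into a member with standard `Enum` value lookup. The sketch below mirrors the documented members locally (rather than importing `prxteinmpnn`) so that it runs on its own; `parse_attention_type` is a hypothetical helper, not part of the library. This page does not describe what each mode does, so the sketch only covers parsing.

```python
from enum import Enum


class MaskedAttentionEnum(Enum):
    """Local mirror of the documented enum: same member names and values."""

    NONE = "none"
    CROSS = "cross"
    CONDITIONAL = "conditional"


def parse_attention_type(name: str) -> MaskedAttentionEnum:
    """Hypothetical helper: map a config string onto an enum member."""
    # Calling the enum looks a member up by its value; unknown values raise ValueError.
    return MaskedAttentionEnum(name.strip().lower())


print(parse_attention_type("Cross"))          # MaskedAttentionEnum.CROSS
print(MaskedAttentionEnum.CONDITIONAL.value)  # conditional
```

Lookup by value fails loudly on a typo such as `"crosss"`, which is usually preferable to silently falling back to a default mode.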

prxteinmpnn.model.masked_attention.mask_attention(message, attention_mask)

Apply an attention mask to the message.

Parameters:
  • message (Float[Array, 'num_atoms num_neighbors num_features']) – The message to be masked.

  • attention_mask (Bool[Array, 'num_atoms num_atoms']) – The attention mask to apply.

Return type:

Float[Array, 'num_atoms num_neighbors num_features']

Returns:

The masked message.
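A minimal sketch of what masking a message tensor can look like in JAX, assuming the mask is boolean with `True` marking entries to keep and that masked entries are set to zero. Both conventions are assumptions; this page does not state them. `mask_attention_sketch` is a hypothetical stand-in, not the library function. The documented mask annotation is `num_atoms num_atoms`, which lines up with the message's neighbor axis only when `num_neighbors` equals `num_atoms`, as in the dense example below.

```python
import jax.numpy as jnp


def mask_attention_sketch(message, attention_mask):
    """Zero out masked entries of a (num_atoms, num_neighbors, num_features) message.

    Assumes attention_mask is boolean, True where the message is kept, and that
    its second axis matches the message's neighbor axis.
    """
    # Add a trailing feature axis so the mask broadcasts over num_features.
    return jnp.where(attention_mask[..., None], message, 0.0)


# Dense case: every atom has all 3 atoms as neighbors, so the mask is
# (num_atoms, num_atoms) == (num_atoms, num_neighbors).
message = jnp.ones((3, 3, 4))
mask = jnp.array(
    [[True, False, True],
     [True, True, False],
     [False, False, True]]
)
masked = mask_attention_sketch(message, mask)
print(masked.sum(axis=-1))  # kept entries sum to 4.0, masked ones to 0.0
```

`jnp.where` is used instead of `message * mask[..., None]` so that a NaN or inf in a masked-out entry is replaced by zero rather than propagated (`0 * nan` is `nan`).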