masked_attention
Masked attention functions.
- class prxteinmpnn.model.masked_attention.MaskedAttentionEnum(value)
Bases: Enum
Enum for different types of masked attention.
- NONE = 'none'
- CROSS = 'cross'
- CONDITIONAL = 'conditional'
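Because the members carry string values, a mode can be selected from a plain string such as a config entry. The sketch below mirrors the three documented members in a local class so it runs without prxteinmpnn installed. The helper `parse_attention_mode` is hypothetical, and how the library dispatches on each mode is not documented here.

```python
from enum import Enum


class MaskedAttentionEnum(Enum):
    """Local mirror of the documented enum (same member names and values)."""

    NONE = "none"
    CROSS = "cross"
    CONDITIONAL = "conditional"


def parse_attention_mode(name: str) -> MaskedAttentionEnum:
    """Hypothetical helper: look up a member by its string value.

    Standard Enum value lookup raises ValueError for unknown strings.
    """
    return MaskedAttentionEnum(name.lower())


mode = parse_attention_mode("Cross")
print(mode)        # MaskedAttentionEnum.CROSS
print(mode.value)  # cross
```

Looking members up by value rather than by name keeps config files decoupled from the Python identifiers.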
- prxteinmpnn.model.masked_attention.mask_attention(message, attention_mask)
Apply attention mask to the message.
- Parameters:
  - message (Float[Array, 'num_atoms num_neighbors num_features']) – The message to be masked.
  - attention_mask (Bool[Array, 'num_atoms num_atoms']) – The attention mask to apply.
- Return type: Float[Array, 'num_atoms num_neighbors num_features']
- Returns: The masked message.
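The documented shapes differ in their second axis (`num_neighbors` for the message, `num_atoms` for the mask), so the mask only lines up directly with the message when every atom's neighbor list covers all atoms. The following is a minimal JAX sketch under that assumption. It zeroes message entries where the mask is `False`. `mask_attention_sketch` is a hypothetical name, and the zero-fill choice is an assumption, not the library's confirmed behavior.

```python
import jax
import jax.numpy as jnp


def mask_attention_sketch(message: jax.Array, attention_mask: jax.Array) -> jax.Array:
    """Hypothetical sketch, not the prxteinmpnn implementation.

    Assumes num_neighbors == num_atoms so the (num_atoms, num_atoms) mask
    aligns with the first two axes of the (num_atoms, num_neighbors,
    num_features) message. Masked-out entries are replaced with 0.0.
    """
    # Add a trailing axis so the boolean mask broadcasts over num_features.
    return jnp.where(attention_mask[..., None], message, 0.0)


num_atoms, num_features = 3, 2
message = jnp.ones((num_atoms, num_atoms, num_features))
# Example mask: atom i may attend only to atoms j <= i (lower triangle).
idx = jnp.arange(num_atoms)
attention_mask = idx[:, None] >= idx[None, :]
masked = mask_attention_sketch(message, attention_mask)
```

Using `jnp.where` rather than multiplying by the mask keeps the operation well defined even if masked-out message entries contain `inf` or `NaN`.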