masked_attention

Masked attention functions.

prxteinmpnn.model.masked_attention.mask_attention(message, attention_mask)

Apply an attention mask to the message.

Parameters:
  • message (Float[Array, 'num_atoms num_neighbors num_features']) – The message to be masked.

  • attention_mask (Bool[Array, 'num_atoms num_atoms']) – The attention mask to apply.

Return type:

Float[Array, 'num_atoms num_neighbors num_features']

Returns:

The masked message.
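The signature above does not spell out how the mask is broadcast against the message. Below is a minimal sketch of the common JAX masking pattern. `mask_attention_sketch` is a hypothetical stand-in and not prxteinmpnn's implementation. It makes two assumptions that the documentation does not confirm. First, a True mask entry means "keep". Second, the mask's last axis lines up with the message's neighbor axis, which holds in the dense case where num_neighbors equals num_atoms.

```python
import jax.numpy as jnp
from jax import Array


def mask_attention_sketch(message: Array, attention_mask: Array) -> Array:
    """Zero out message entries whose mask value is False.

    Hypothetical illustration only. Assumes True means "keep" and that the
    mask's last axis is aligned with the message's neighbor axis.
    """
    # Add a trailing feature axis so the boolean mask broadcasts over features:
    # (num_atoms, num_neighbors) -> (num_atoms, num_neighbors, 1)
    mask = attention_mask[..., None]
    # Keep entries where the mask is True and replace the rest with zeros.
    return jnp.where(mask, message, jnp.zeros_like(message))


# Dense case: 3 atoms, each attending to all 3 atoms, 2 features per message.
message = jnp.ones((3, 3, 2))
attention_mask = jnp.array(
    [[True, True, False],
     [True, False, True],
     [False, True, True]]
)
masked = mask_attention_sketch(message, attention_mask)
# Because every message entry is 1.0, each feature slice of `masked`
# reproduces the mask pattern as floats.
print(masked[..., 0])
```

The sketch uses `jnp.where` rather than multiplying by the mask. This keeps masked entries exactly zero even when the message contains NaN or inf values, whereas `nan * 0` would still yield NaN.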