# Source code for prxteinmpnn.model.masked_attention
"""Masked attention functions."""
import enum
import jax
import jax.numpy as jnp
from prxteinmpnn.utils.types import AttentionMask, Message
# [docs]
class MaskedAttentionEnum(enum.Enum):
    """Enum for different types of masked attention.

    Each member's value is the lowercase string form of its name, so a
    member can be looked up from that string, e.g.
    ``MaskedAttentionEnum("cross")``.
    """

    # NOTE(review): what each mode means for the model is decided by the
    # code that consumes this enum, which is not visible here — confirm
    # against the callers before relying on the names alone.
    NONE = "none"
    CROSS = "cross"
    CONDITIONAL = "conditional"
# [docs]
@jax.jit
def mask_attention(message: Message, attention_mask: AttentionMask) -> Message:
    """Apply attention mask to the message.

    The mask is given a new trailing axis of length 1 and multiplied
    element-wise with the message, so each mask entry scales every value
    along the last axis of the corresponding message slice.

    Args:
      message: The message to be masked.
      attention_mask: The attention mask to apply. Assumed to have the
        same leading shape as ``message`` minus its last axis — TODO
        confirm against the type aliases in ``prxteinmpnn.utils.types``.

    Returns:
      The masked message (the broadcast product of mask and message).

    """
    # Append a size-1 axis so the mask broadcasts over the message's last axis.
    broadcastable_mask = jnp.expand_dims(attention_mask, axis=-1)
    return broadcastable_mask * message