# Source code for prxteinmpnn.utils.data_structures

"""Dataclasses and enums for the PrxteinMPNN project.

prxteinmpnn.utils.data_structures
"""

from __future__ import annotations

import enum
from collections.abc import Iterator
from typing import TYPE_CHECKING

import jax
import jax.numpy as jnp
from flax.struct import dataclass, field

if TYPE_CHECKING:
  from jaxtyping import PRNGKeyArray

  from prxteinmpnn.utils.types import (
    AlphaCarbonMask,
    AtomMask,
    BackboneDihedrals,
    ChainIndex,
    InputBias,
    InputLengths,
    ProteinSequence,
    ResidueIndex,
    SamplingHyperparameters,
    StructureAtomicCoordinates,
  )


@dataclass(frozen=True)
class ProteinStructure:
  """Immutable record describing a single protein structure.

  Attributes:
    coordinates (StructureAtomicCoordinates): Cartesian coordinates of the atoms
      in angstroms. Shape is (num_res, num_atom_type, 3), where num_res is the
      number of residues, num_atom_type is the number of atom types and 3 is the
      spatial dimension (x, y, z). The atom types correspond to
      residue_constants.atom_types.
      NOTE(review): earlier documentation stated that the first three atom types
      are N, CA, CB -- confirm the order against residue_constants.atom_types.
    aatype (ProteinSequence): Amino-acid type of each residue, an integer between
      0 and 20, where 20 is 'X'. Shape is (num_res,).
    atom_mask (AtomMask): Binary float mask marking atom presence: 1.0 if the atom
      is present and 0.0 if not. This should be used for loss masking.
      Shape is (num_res, num_atom_type).
    residue_index (ResidueIndex): Residue index as used in PDB. It is not
      necessarily continuous or 0-indexed. Shape is (num_res,).
    chain_index (ChainIndex): Chain index of each residue.
      Presumably shape (num_res,) -- TODO confirm.
    dihedrals (BackboneDihedrals | None): Optional backbone dihedral angles.
      Defaults to None when they are not provided.

  """

  coordinates: StructureAtomicCoordinates
  aatype: ProteinSequence
  atom_mask: AtomMask
  residue_index: ResidueIndex
  chain_index: ChainIndex
  dihedrals: BackboneDihedrals | None = None

# Type alias: an iterator that yields ProteinStructure instances. The element
# type is written as a string forward reference.
# NOTE(review): presumably each yielded structure is one member (model/frame)
# of an ensemble -- confirm against callers.
ProteinEnsemble = Iterator["ProteinStructure"]
@dataclass(frozen=True)
class ModelInputs:
  """Dataclass for general model inputs.

  Note that any of these can be stacked together to form a batch of inputs.
  Every field has a default, so ``ModelInputs()`` is constructible; the array
  fields default to an empty array created by ``jnp.array([])``.

  Attributes:
    structure_coordinates (StructureAtomicCoordinates): Atomic coordinates of the
      structure. Shape is (num_residues, num_atoms, 3), where num_residues is the
      number of residues and num_atoms is the number of atoms per residue
      (e.g., 37 for standard residues).
    sequence (ProteinSequence): Sequence of amino acids as an array of integers
      (MPNN alphabet). Shape is (num_residues,).
    mask (AlphaCarbonMask): Mask for the model input, indicating valid positions
      in the structure.
      NOTE(review): earlier documentation described this field as an AtomMask of
      shape (num_residues, num_atoms), but the annotation is AlphaCarbonMask --
      confirm the intended shape.
    residue_index (ResidueIndex): Index of residues in the structure, used for
      mapping atoms in structures to their residues. Shape is (num_residues,).
    chain_index (ChainIndex): Index of chains in the structure, used for mapping
      atoms in structures to their chains. Shape is (num_residues,).
    lengths (InputLengths): Lengths of the sequences in the batch, used for
      padding and batching. Shape is (num_sequences,).
    bias (InputBias | None): Bias for the model input, used for classification
      tasks. Shape is (sum(lengths), num_classes), e.g. num_classes = 20 for
      amino acids. Defaults to None.
      NOTE(review): earlier documentation said the default is a zero bias of
      shape (sum(lengths), 20); presumably consumers substitute that when they
      receive None -- confirm.
    k_neighbors (int): Number of neighbors to consider for each atom in the
      structure. Defaults to 48.
    augment_eps (float): Epsilon value for adding noise to the backbone
      coordinates, used for data augmentation. Defaults to 0.0.

  """

  structure_coordinates: StructureAtomicCoordinates = field(
    default_factory=lambda: jnp.array([]),
  )
  """Structure atomic coordinates for the model input."""

  sequence: ProteinSequence = field(default_factory=lambda: jnp.array([]))
  """A sequence of amino acids for the model input.

  As MPNN-alphabet based array of integers."""

  mask: AlphaCarbonMask = field(default_factory=lambda: jnp.array([]))
  """Mask for the model input, indicating valid positions in the structure."""

  residue_index: ResidueIndex = field(default_factory=lambda: jnp.array([]))
  """Index of residues in the structure, used for mapping atoms to their residues."""

  chain_index: ChainIndex = field(default_factory=lambda: jnp.array([]))
  """Index of chains in the structure, used for mapping atoms to their chains."""

  lengths: InputLengths = field(default_factory=lambda: jnp.array([]))
  """Lengths of the sequences in the batch, used for padding and batching."""

  # A plain ``None`` default replaces ``field(default_factory=lambda: None)``:
  # None is immutable, so no factory is needed and the default is unchanged.
  bias: InputBias | None = None
  """Bias for the model input, used for classification tasks. Defaults to None."""

  k_neighbors: int = 48
  """Number of neighbors to consider for each atom in the structure."""

  augment_eps: float = 0.0
  """Epsilon value for adding noise to the backbone coordinates (data augmentation)."""

@dataclass(frozen=True)
class SamplingInputs(ModelInputs):
  """Dataclass for inputs used in sequence sampling.

  Extends ``ModelInputs`` with the fields listed below; all ``ModelInputs``
  fields (structure_coordinates, sequence, mask, residue_index, chain_index,
  lengths, bias, k_neighbors, augment_eps) are inherited unchanged.

  NOTE(review): earlier documentation listed an ``initial_sequence`` attribute.
  No such field is declared on this class; the inherited ``sequence`` field is
  presumably what was meant -- confirm.

  Attributes:
    prng_key (PRNGKeyArray): Random key for JAX operations. Defaults to
      ``jax.random.PRNGKey(0)``, a fixed seed; pass an explicit key to obtain a
      different random stream.
    hyperparameters (SamplingHyperparameters): Hyperparameters for sampling,
      e.g., temperature, top-k, etc. Defaults to the one-element tuple ``(0.0,)``.
    iterations (int): Number of iterations for sampling. Defaults to 1.

  """

  prng_key: PRNGKeyArray = field(default_factory=lambda: jax.random.PRNGKey(0))
  """Random key for JAX operations."""

  hyperparameters: SamplingHyperparameters = (0.0,)
  """Hyperparameters for sampling, e.g., temperature, top-k, etc."""

  iterations: int = 1
  """Number of iterations for sampling."""

class OligomerType(enum.Enum):
  """Enum for different types of oligomers.

  Each member's value is its lowercase string name, so a member can be looked
  up from that string, e.g. ``OligomerType("monomer")``.
  """

  MONOMER = "monomer"
  HETEROMER = "heteromer"
  HOMOOLIGOMER = "homooligomer"
  TIED_HOMOOLIGOMER = "tied_homooligomer"