# Source code for prxteinmpnn.utils.data_structures

"""Dataclasses for the PrxteinMPNN project.

prxteinmpnn.utils.data_structures
"""

from __future__ import annotations

from collections.abc import Generator, Sequence
from typing import TYPE_CHECKING, Literal, NamedTuple

import jax.numpy as jnp
from flax.struct import dataclass

if TYPE_CHECKING:
  import numpy as np
  from jaxtyping import Int

  from prxteinmpnn.utils.types import (
    AtomMask,
    BackboneDihedrals,
    ChainIndex,
    OneHotProteinSequence,
    ProteinSequence,
    ResidueIndex,
    StructureAtomicCoordinates,
  )

from dataclasses import dataclass as dc


class ProteinTuple(NamedTuple):
  """Tuple-based protein structure representation.

  Attributes:
    coordinates (StructureAtomicCoordinates): Cartesian atom positions in
      angstroms. Shape is (num_res, num_atom_type, 3): residues, atom types,
      and the spatial x/y/z axis. The atom types correspond to
      ``residue_constants.atom_types``. NOTE(review): earlier docs said the
      first three atom types are N, CA, CB; AlphaFold-style ``atom_types``
      begins N, CA, C, CB, O -- confirm the ordering used here.
    aatype (ProteinSequence): Amino-acid type for each residue as an integer
      between 0 and 20, where 20 is 'X'. Shape is [num_res].
    atom_mask (AtomMask): Binary float mask, 1.0 where an atom is present and
      0.0 where it is not; intended for loss masking.
      Shape is [num_res, num_atom_type].
    residue_index (ResidueIndex): Residue numbering as used in the PDB file.
      Not necessarily contiguous or 0-based. Shape is [num_res].
    chain_index (ChainIndex): Chain index for each residue. Shape is [num_res].
    full_coordinates (np.ndarray | None): Optional second coordinate array,
      converted to float32 by ``Protein.from_tuple``. Defaults to None.
      NOTE(review): presumably the unreduced full-atom coordinates -- confirm
      against the loader that fills it.
    dihedrals (BackboneDihedrals | None): Backbone dihedral angles
      (phi, psi, omega). Shape is [num_res, 3]. Defaults to None.
    source (str | None): Optional provenance string. Defaults to None.
      NOTE(review): presumably a file path or identifier -- confirm.
    mapping (np.ndarray | None): Optional integer mapping array, converted to
      int32 by ``Protein.from_tuple``. Defaults to None.

  """

  coordinates: np.ndarray
  aatype: np.ndarray
  atom_mask: np.ndarray
  residue_index: np.ndarray
  chain_index: np.ndarray
  full_coordinates: np.ndarray | None = None
  dihedrals: np.ndarray | None = None
  source: str | None = None
  mapping: np.ndarray | None = None
@dc
class TrajectoryStaticFeatures:
  """Protein features that stay fixed across every frame of a trajectory.

  Instances hold values that are pre-computed once and do not vary from
  frame to frame.

  """

  # Per-residue amino-acid types (presumably encoded like ProteinTuple.aatype).
  aatype: np.ndarray
  # Atom-presence mask; the "37" suffix presumably refers to a 37-atom-type
  # layout -- confirm.
  static_atom_mask_37: np.ndarray
  # Per-residue indices.
  residue_indices: np.ndarray
  # Per-residue chain identifiers.
  chain_index: np.ndarray
  # Mask of atoms considered valid (selection criteria set by the producer).
  valid_atom_mask: np.ndarray
  # NOTE(review): presumably flags nitrogen atoms -- confirm with the producer.
  nitrogen_mask: np.ndarray
  # Number of residues in the structure.
  num_residues: int
@dataclass(frozen=True)
class Protein:
  """Protein structure or ensemble representation backed by JAX arrays.

  Attributes:
    coordinates (StructureAtomicCoordinates): Cartesian atom positions in
      angstroms. Shape is (num_res, num_atom_type, 3). The atom types
      correspond to ``residue_constants.atom_types``.
    aatype (ProteinSequence): Amino-acid type for each residue as an integer
      between 0 and 20, where 20 is 'X'. Shape is [num_res].
    one_hot_sequence (OneHotProteinSequence): One-hot encoding of ``aatype``
      over 21 classes; ``from_tuple`` builds it with shape [num_res, 21].
    atom_mask (AtomMask): Binary float mask, 1.0 where an atom is present and
      0.0 where it is not; intended for loss masking.
      Shape is [num_res, num_atom_type].
    residue_index (ResidueIndex): Residue numbering as used in the PDB file.
      Not necessarily contiguous or 0-based. Shape is [num_res].
    chain_index (ChainIndex): Chain index for each residue. Shape is [num_res].
    dihedrals (BackboneDihedrals | None): Backbone dihedral angles
      (phi, psi, omega). Shape is [num_res, 3]. Defaults to None.
    mapping (Int | None): Optional integer mapping array. Defaults to None.
    full_coordinates (StructureAtomicCoordinates | None): Optional second
      coordinate array. Defaults to None. NOTE(review): presumably the
      full-atom coordinates -- confirm against the loader.

  """

  coordinates: StructureAtomicCoordinates
  aatype: ProteinSequence
  one_hot_sequence: OneHotProteinSequence
  atom_mask: AtomMask
  residue_index: ResidueIndex
  chain_index: ChainIndex
  dihedrals: BackboneDihedrals | None = None
  mapping: Int | None = None
  full_coordinates: StructureAtomicCoordinates | None = None

  @classmethod
  def from_tuple(cls, protein_tuple: ProteinTuple) -> Protein:
    """Create a Protein instance from a ProteinTuple.

    Every array is converted to a JAX array with a fixed dtype: float32 for
    coordinates, masks and angles, int8 for ``aatype``, and int32 for index
    arrays. ``one_hot_sequence`` is derived from ``aatype``. Optional fields
    that are None in the tuple stay None.

    Args:
      protein_tuple (ProteinTuple): The input protein tuple.

    Returns:
      Protein: The output protein dataclass.

    """

    def _maybe_asarray(value, dtype):
      # Absent optional fields are passed through as None.
      return None if value is None else jnp.asarray(value, dtype=dtype)

    return cls(
      coordinates=jnp.asarray(protein_tuple.coordinates, dtype=jnp.float32),
      aatype=jnp.asarray(protein_tuple.aatype, dtype=jnp.int8),
      # Row i of the 21x21 identity matrix is the one-hot vector for class i.
      one_hot_sequence=jnp.eye(21)[protein_tuple.aatype],
      atom_mask=jnp.asarray(protein_tuple.atom_mask, dtype=jnp.float32),
      residue_index=jnp.asarray(protein_tuple.residue_index, dtype=jnp.int32),
      chain_index=jnp.asarray(protein_tuple.chain_index, dtype=jnp.int32),
      # float32 matches the coordinates the angles are derived from. Asking
      # for float64 is truncated to float32 (with a UserWarning) unless
      # jax_enable_x64 is set, so it bought nothing by default.
      dihedrals=_maybe_asarray(protein_tuple.dihedrals, jnp.float32),
      mapping=_maybe_asarray(protein_tuple.mapping, jnp.int32),
      full_coordinates=_maybe_asarray(protein_tuple.full_coordinates, jnp.float32),
    )
# Generator yielding ProteinTuple items. The explicit (yield, send, return)
# form is valid for every type checker and Python version; the two-argument
# form relies on the send/return defaults only added in Python 3.13.
ProteinStream = Generator[ProteinTuple, None, None]
# Any sequence of Protein dataclass instances.
ProteinBatch = Sequence[Protein]
@dataclass(frozen=True)
class ProteinEnsemble:
  """Container for a protein structure or an ensemble of structures.

  Attributes:
    coordinates (StructureAtomicCoordinates): Cartesian atom positions in
      angstroms. Shape is (num_res, num_atom_type, 3). The atom types
      correspond to ``residue_constants.atom_types``.
    aatype (ProteinSequence): Amino-acid type for each residue as an integer
      between 0 and 20, where 20 is 'X'. Shape is [num_res].
    one_hot_sequence (OneHotProteinSequence): One-hot encoding of the residue
      sequence. NOTE(review): presumably [num_res, 21] as in ``Protein`` --
      confirm.
    atom_mask (AtomMask): Binary float mask, 1.0 where an atom is present and
      0.0 where it is not; intended for loss masking.
      Shape is [num_res, num_atom_type].
    residue_index (ResidueIndex): Residue numbering as used in the PDB file.
      Not necessarily contiguous or 0-based. Shape is [num_res].
    chain_index (ChainIndex): Chain index for each residue. Shape is [num_res].
    dihedrals (BackboneDihedrals | None): Backbone dihedral angles
      (phi, psi, omega). Shape is [num_res, 3]. Defaults to None.
    mapping (Int | None): Optional array mapping residues in the ensemble to
      original structure indices. Shape is [num_res, num_frames].
      Defaults to None.

  """

  # Required fields, in positional order.
  coordinates: StructureAtomicCoordinates
  aatype: ProteinSequence
  one_hot_sequence: OneHotProteinSequence
  atom_mask: AtomMask
  residue_index: ResidueIndex
  chain_index: ChainIndex
  # Optional fields.
  dihedrals: BackboneDihedrals | None = None
  mapping: Int | None = None
# Allowed string values for the oligomer type.
OligomerType = Literal[
  "monomer",
  "heteromer",
  "homooligomer",
  "tied_homooligomer",
]