"""Dataclasses for the PrxteinMPNN project.
prxteinmpnn.utils.data_structures
"""
from __future__ import annotations
from collections.abc import Generator, Sequence
from typing import TYPE_CHECKING, Literal, NamedTuple
import jax.numpy as jnp
from flax.struct import dataclass
from prxteinmpnn.utils.residue_constants import atom_order
if TYPE_CHECKING:
import numpy as np
from jaxtyping import Int
from prxteinmpnn.utils.types import (
BIC,
AlphaCarbonMask,
AtomMask,
BackboneDihedrals,
ChainIndex,
ComponentCounts,
Converged,
Covariances,
EnsembleData,
LogLikelihood,
Means,
OneHotProteinSequence,
ProteinSequence,
ResidueIndex,
Responsibilities,
StructureAtomicCoordinates,
Weights,
)
from dataclasses import dataclass as dc
[docs]
class ProteinTuple(NamedTuple):
    """Immutable, tuple-backed record describing one protein structure.

    The first five fields are required; every remaining field is optional and
    defaults to None.

    Attributes:
        coordinates (StructureAtomicCoordinates): Cartesian atom positions in angstroms.
            The atom types correspond to residue_constants.atom_types, i.e. the first
            three are N, CA, CB. Shape is (num_res, num_atom_type, 3), where num_res is
            the number of residues, num_atom_type is the number of atom types
            (e.g., N, CA, CB, C, O), and 3 is the spatial dimension (x, y, z).
        aatype (ProteinSequence): Amino-acid type of each residue, an integer between 0
            and 20, where 20 is 'X'. Shape is [num_res].
        atom_mask (AtomMask): Binary float mask marking which atoms are present
            (1.0 present, 0.0 absent). This should be used for loss masking.
            Shape is [num_res, num_atom_type].
        residue_index (ResidueIndex): Residue index as used in PDB. It is not
            necessarily continuous or 0-indexed. Shape is [num_res].
        chain_index (ChainIndex): Chain index of each residue. Shape is [num_res].
        full_coordinates: Optional full atomic coordinates including all heavy atoms.
        dihedrals (BackboneDihedrals | None): Optional backbone dihedral angles
            (phi, psi, omega). Shape is [num_res, 3].
        source: Optional string; presumably identifies where the structure came
            from -- TODO confirm against the producers of this tuple.
        mapping: Optional array mapping residues in an ensemble to original
            structure indices.
        charges: Optional array of atomic charges.
        radii: Optional array of atomic radii.
        estat_backbone_mask: Optional boolean array indicating backbone atoms.
        estat_resid: Optional integer array of residue numbers.
        estat_chain_index: Optional integer array of chain indices (ord value).
    """

    # Required structural fields (positional order is part of the interface).
    coordinates: np.ndarray
    aatype: np.ndarray
    atom_mask: np.ndarray
    residue_index: np.ndarray
    chain_index: np.ndarray

    # Optional structural extras.
    full_coordinates: np.ndarray | None = None
    dihedrals: np.ndarray | None = None
    source: str | None = None
    mapping: np.ndarray | None = None

    # Optional electrostatics data.
    charges: np.ndarray | None = None
    radii: np.ndarray | None = None
    estat_backbone_mask: np.ndarray | None = None
    estat_resid: np.ndarray | None = None
    estat_chain_index: np.ndarray | None = None
[docs]
@dc
class TrajectoryStaticFeatures:
    """Bundle of protein features that are computed ahead of time and do not vary between frames.

    Holds six array-valued fields plus an integer residue count.
    NOTE(review): array shapes and dtypes are not established in this module;
    confirm them against the code that builds this container.
    """

    # Array-valued, frame-invariant features.
    aatype: np.ndarray
    static_atom_mask_37: np.ndarray
    residue_indices: np.ndarray
    chain_index: np.ndarray
    valid_atom_mask: np.ndarray
    nitrogen_mask: np.ndarray

    # Scalar residue count.
    num_residues: int
[docs]
def include_feature(feature_name: str, include_features: Sequence[str] | None) -> bool:
    """Decide whether a named feature has been requested.

    Args:
        feature_name (str): Name of the feature being queried.
        include_features (Sequence[str] | None): Requested feature names. The
            special entry "all" requests every feature. None requests nothing.

    Returns:
        bool: True when `feature_name` or "all" is contained in
        `include_features`; False otherwise, including when it is None.
    """
    if include_features is not None:
        named_explicitly = feature_name in include_features
        # Only fall back to the "all" wildcard when the name itself is absent.
        return named_explicitly or "all" in include_features
    return False
[docs]
@dataclass
class EstatInfo:
    """Per-atom electrostatics data read from a PQR file.

    Attributes:
        charges: Numpy array of atomic charges.
        radii: Numpy array of atomic radii.
        estat_backbone_mask: Boolean numpy array flagging backbone atoms.
        estat_resid: Integer numpy array of residue numbers.
        estat_chain_index: Integer numpy array of chain indices (ord value).
    """

    # Physical per-atom quantities.
    charges: np.ndarray
    radii: np.ndarray

    # Per-atom bookkeeping arrays.
    estat_backbone_mask: np.ndarray
    estat_resid: np.ndarray
    estat_chain_index: np.ndarray
[docs]
@dataclass(frozen=True)
class Protein:
    """Protein structure or ensemble representation.

    Attributes:
        coordinates (StructureAtomicCoordinates): Atom positions in the structure,
            represented as a 3D array. Cartesian coordinates of atoms in angstroms. The
            atom types correspond to residue_constants.atom_types, i.e. the first three
            are N, CA, CB. Shape is (num_res, num_atom_type, 3), where num_res is the
            number of residues, num_atom_type is the number of atom types
            (e.g., N, CA, CB, C, O), and 3 is the spatial dimension (x, y, z).
        aatype (ProteinSequence): Amino-acid type for each residue represented as an
            integer between 0 and 20, where 20 is 'X'. Shape is [num_res].
        one_hot_sequence (OneHotProteinSequence): One-hot encoding of `aatype` over 21
            classes. Shape is [num_res, 21] when built by `from_tuple`.
        mask (AlphaCarbonMask): Binary float mask to indicate presence of the alpha
            carbon atom. 1.0 if the atom is present and 0.0 if not. This should be used
            for loss masking. `from_tuple` takes it from the CA column of the tuple's
            atom mask, so it holds one value per residue (shape [num_res]).
        residue_index (ResidueIndex): Residue index as used in PDB. It is not
            necessarily continuous or 0-indexed. Shape is [num_res].
        chain_index (ChainIndex): Chain index for each residue. Shape is [num_res].
        dihedrals (BackboneDihedrals | None): Dihedral angles for backbone atoms
            (phi, psi, omega). Shape is [num_res, 3]. Defaults to None.
        mapping (Int | None): Optional array mapping residues in the ensemble to
            original structure indices. Shape is [num_res, num_frames]. Defaults to
            None.
        full_coordinates (StructureAtomicCoordinates | None): Full atomic coordinates
            including all heavy atoms. Shape is (num_res, num_full_atom_type, 3), where
            num_full_atom_type is the number of all heavy atom types
            (e.g., N, CA, CB, C, O, CG, etc.), and 3 is the spatial dimension (x, y, z).
            Defaults to None.
        full_atom_mask (AtomMask | None): Binary float mask to indicate presence of a
            particular heavy atom. 1.0 if an atom is present and 0.0 if not. This should
            be used for loss masking. Shape is [num_res, num_full_atom_type]. Defaults
            to None.
        charges (np.ndarray | None): Optional array of atomic charges.
        radii (np.ndarray | None): Optional array of atomic radii.
        estat_backbone_mask (np.ndarray | None): Optional boolean array indicating
            backbone atoms.
        estat_resid (np.ndarray | None): Optional integer array of residue numbers.
        estat_chain_index (np.ndarray | None): Optional integer array of chain indices.
    """

    coordinates: StructureAtomicCoordinates
    aatype: ProteinSequence
    one_hot_sequence: OneHotProteinSequence
    mask: AlphaCarbonMask
    residue_index: ResidueIndex
    chain_index: ChainIndex
    dihedrals: BackboneDihedrals | None = None
    mapping: Int | None = None
    full_coordinates: StructureAtomicCoordinates | None = None
    full_atom_mask: AtomMask | None = None
    charges: np.ndarray | None = None
    radii: np.ndarray | None = None
    estat_backbone_mask: np.ndarray | None = None
    estat_resid: np.ndarray | None = None
    estat_chain_index: np.ndarray | None = None

    @classmethod
    def from_tuple(
        cls,
        protein_tuple: ProteinTuple,
        *,
        include_extras: Sequence[
            Literal["dihedrals", "mapping", "full_coordinates", "full_atom_mask", "all"]
        ]
        | None = None,
    ) -> Protein:
        """Create a Protein instance from a ProteinTuple.

        The required fields are always converted to JAX arrays. The gated fields
        (dihedrals, mapping, full_coordinates, full_atom_mask) are converted only
        when they are requested through `include_extras` and the backing data is
        present on the tuple; otherwise they are None. The electrostatics fields
        (charges, radii, estat_*) are always copied from the tuple unchanged,
        regardless of `include_extras`.

        Args:
            protein_tuple (ProteinTuple): The input protein tuple.
            include_extras: Optional list of gated fields to include from the tuple.
                If 'all' is included, every gated field is included.
                If None, no gated field is included.

        Returns:
            Protein: The output protein dataclass.
        """
        # full_atom_mask is gated on full_coordinates being present on the tuple,
        # although its values are read from the tuple's atom_mask.
        has_full_coordinates = protein_tuple.full_coordinates is not None

        return cls(
            coordinates=jnp.asarray(protein_tuple.coordinates, dtype=jnp.float32),
            aatype=jnp.asarray(protein_tuple.aatype, dtype=jnp.int8),
            # Row i of the 21x21 identity matrix is the one-hot vector for class i.
            one_hot_sequence=jnp.eye(21)[protein_tuple.aatype],
            # Keep only the alpha-carbon column of the per-atom mask.
            mask=jnp.asarray(protein_tuple.atom_mask[:, atom_order["CA"]], dtype=jnp.float32),
            residue_index=jnp.asarray(protein_tuple.residue_index, dtype=jnp.int32),
            chain_index=jnp.asarray(protein_tuple.chain_index, dtype=jnp.int32),
            # NOTE(review): float64 is requested here while every other float field
            # uses float32; JAX honours 64-bit dtypes only when x64 mode is enabled.
            # Confirm this is intended.
            dihedrals=(
                jnp.asarray(protein_tuple.dihedrals, dtype=jnp.float64)
                if protein_tuple.dihedrals is not None
                and include_feature("dihedrals", include_extras)
                else None
            ),
            # Gate through the shared helper, like the other optional fields.
            mapping=(
                jnp.asarray(protein_tuple.mapping, dtype=jnp.int32)
                if protein_tuple.mapping is not None
                and include_feature("mapping", include_extras)
                else None
            ),
            full_coordinates=(
                jnp.asarray(protein_tuple.full_coordinates, dtype=jnp.float32)
                if has_full_coordinates and include_feature("full_coordinates", include_extras)
                else None
            ),
            full_atom_mask=(
                jnp.asarray(protein_tuple.atom_mask, dtype=jnp.float32)
                if has_full_coordinates and include_feature("full_atom_mask", include_extras)
                else None
            ),
            charges=protein_tuple.charges,
            radii=protein_tuple.radii,
            estat_backbone_mask=protein_tuple.estat_backbone_mask,
            estat_resid=protein_tuple.estat_resid,
            estat_chain_index=protein_tuple.estat_chain_index,
        )
# Generator that yields ProteinTuple items. The send and return types are spelled
# out explicitly because Generator only gained defaults for them in Python 3.13.
ProteinStream = Generator[ProteinTuple, None, None]
# Any sequence of Protein instances.
ProteinBatch = Sequence[Protein]
[docs]
@dataclass
class _EStepState:
"""State for accumulating statistics during the E-step."""
component_counts: ComponentCounts
weighted_data: EnsembleData
weighted_squared_data: EnsembleData
log_likelihood_total: LogLikelihood
[docs]
@dataclass
class GMM:
    """Container for the parameters of a Gaussian mixture model (GMM).

    Holds the array-valued parameters (means, covariances, weights and
    responsibilities) together with two integer sizes.
    """

    # Array-valued model parameters.
    means: Means
    covariances: Covariances
    weights: Weights
    responsibilities: Responsibilities

    # Integer sizes.
    n_components: int
    n_features: int
[docs]
class EMLoopState(NamedTuple):
    """Loop-carried state of the in-memory EM loop.

    Bundles the current model, an iteration counter, and two log-likelihood
    values. NOTE(review): `log_likelihood_diff` is presumably the change in
    log-likelihood between iterations -- confirm against the loop body.
    """

    # Current model.
    gmm: GMM

    # Iteration counter.
    n_iter: Int

    # Log-likelihood tracking.
    log_likelihood: LogLikelihood
    log_likelihood_diff: LogLikelihood
[docs]
@dataclass
class EMFitterResult:
    """Outcome of an Expectation-Maximization fit.

    Attributes:
        gmm (GMM): The final fitted Gaussian mixture model.
        n_iter (jax.Array): The total number of iterations performed.
        log_likelihood (jax.Array): The log-likelihood of the data under the final
            model.
        log_likelihood_diff (LogLikelihood): Presumably the last change in
            log-likelihood between iterations -- TODO confirm against the fitter.
        converged (jax.Array): A boolean indicating if the algorithm converged within
            the max iterations.
        features (EnsembleData | None): Optional ensemble data attached to the
            result. Defaults to None.
        bic (BIC | None): Optional BIC value (presumably the Bayesian information
            criterion -- confirm). Defaults to None.
    """

    # Always populated by the fit.
    gmm: GMM
    n_iter: Int
    log_likelihood: LogLikelihood
    log_likelihood_diff: LogLikelihood
    converged: Converged

    # Optional extras.
    features: EnsembleData | None = None
    bic: BIC | None = None
# Closed set of string tags accepted wherever an oligomer type is annotated;
# any other string is rejected by static type checkers.
OligomerType = Literal["monomer", "heteromer", "homooligomer", "tied_homooligomer"]