data_structures

data_structures#

Dataclasses for the PrxteinMPNN project.

prxteinmpnn.utils.data_structures

class prxteinmpnn.utils.data_structures.ProteinTuple(coordinates, aatype, atom_mask, residue_index, chain_index, full_coordinates=None, dihedrals=None, source=None, mapping=None, charges=None, radii=None, estat_backbone_mask=None, estat_resid=None, estat_chain_index=None)[source]#

Bases: NamedTuple

Tuple-based protein structure representation.

Parameters:
  • coordinates (np.ndarray)

  • aatype (np.ndarray)

  • atom_mask (np.ndarray)

  • residue_index (np.ndarray)

  • chain_index (np.ndarray)

  • full_coordinates (np.ndarray | None)

  • dihedrals (np.ndarray | None)

  • source (str | None)

  • mapping (np.ndarray | None)

  • charges (np.ndarray | None)

  • radii (np.ndarray | None)

  • estat_backbone_mask (np.ndarray | None)

  • estat_resid (np.ndarray | None)

  • estat_chain_index (np.ndarray | None)

coordinates#

Cartesian coordinates of the atoms, in angstroms. The atom types correspond to residue_constants.atom_types, i.e. the first three are N, CA, CB. Shape is [num_res, num_atom_type, 3], where num_res is the number of residues, num_atom_type is the number of atom types (e.g., N, CA, CB, C, O), and 3 is the spatial dimension (x, y, z).

Type:

StructureAtomicCoordinates

aatype#

Amino-acid type for each residue represented as an integer between 0 and 20, where 20 is ‘X’. Shape is [num_res].

Type:

ProteinSequence

atom_mask#

Binary float mask to indicate presence of a particular atom. 1.0 if an atom is present and 0.0 if not. This should be used for loss masking. Shape is [num_res, num_atom_type].

Type:

AtomMask
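
The loss masking mentioned for atom_mask can be sketched as follows. This is a minimal NumPy illustration, not the project's training code; the function name masked_mse and the toy shapes are assumptions made for the example.

```python
import numpy as np

def masked_mse(pred, target, atom_mask):
    """Mean squared error over atoms that are present (mask == 1.0)."""
    # Squared distance per atom, shape [num_res, num_atom_type].
    sq_err = np.sum((pred - target) ** 2, axis=-1)
    # Absent atoms contribute nothing; normalize by the count of present atoms.
    return float(np.sum(sq_err * atom_mask) / np.maximum(np.sum(atom_mask), 1.0))

# Toy example: 2 residues, 3 atom types; the last atom of residue 1 is missing.
target = np.zeros((2, 3, 3))
pred = np.ones((2, 3, 3))
pred[1, 2] = 100.0  # garbage at a masked position must not affect the loss
atom_mask = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])
loss = masked_mse(pred, target, atom_mask)
```

Multiplying by the mask before summing keeps placeholder coordinates for missing atoms out of the loss.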

residue_index#

Residue index as used in PDB. It is not necessarily continuous or 0-indexed. Shape is [num_res].

Type:

ResidueIndex
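
Because residue_index follows PDB numbering, it should not be used directly as an array position. The sketch below uses made-up numbering to show how gaps can be located and how positional indices differ from PDB indices. A non-unit step can also have other causes (for example a chain boundary), so treat the gap test as a heuristic.

```python
import numpy as np

# Hypothetical PDB numbering: starts at 5 and skips residues 8-9.
residue_index = np.array([5, 6, 7, 10, 11])

# A step other than 1 between neighbors marks a numbering discontinuity.
gap_after = np.flatnonzero(np.diff(residue_index) != 1)

# Positional (0-based) indices are what array indexing needs.
positions = np.arange(residue_index.shape[0])
```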

chain_index#

Chain index for each residue. Shape is [num_res].

Type:

ChainIndex

dihedrals#

Dihedral angles for backbone atoms (phi, psi, omega). Shape is [num_res, 3]. If not provided, defaults to None.

Type:

BackboneDihedrals | None
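
For orientation, a torsion angle over four points can be computed as in the sketch below. By the standard definitions, phi for residue i uses C(i-1), N(i), CA(i), C(i); psi uses N(i), CA(i), C(i), N(i+1); and omega uses CA(i), C(i), N(i+1), CA(i+1). This is a generic NumPy illustration: the helper name dihedral is an assumption, and the project's own computation, units, and sign convention may differ.

```python
import numpy as np

def dihedral(p0, p1, p2, p3):
    """Signed torsion angle in radians about the p1-p2 bond."""
    b0 = p0 - p1
    b1 = p2 - p1
    b2 = p3 - p2
    b1 = b1 / np.linalg.norm(b1)
    # Project b0 and b2 onto the plane perpendicular to the central bond.
    v = b0 - np.dot(b0, b1) * b1
    w = b2 - np.dot(b2, b1) * b1
    x = np.dot(v, w)
    y = np.dot(np.cross(b1, v), w)
    return float(np.arctan2(y, x))

# Central bond along z; the outer points are placed cis and trans to each other.
p1, p2 = np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
p0 = np.array([1.0, 0.0, 0.0])
cis = dihedral(p0, p1, p2, np.array([1.0, 0.0, 1.0]))
trans = dihedral(p0, p1, p2, np.array([-1.0, 0.0, 1.0]))
```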

coordinates: np.ndarray#

Alias for field number 0

aatype: np.ndarray#

Alias for field number 1

atom_mask: np.ndarray#

Alias for field number 2

residue_index: np.ndarray#

Alias for field number 3

chain_index: np.ndarray#

Alias for field number 4

full_coordinates: np.ndarray | None#

Alias for field number 5

dihedrals: np.ndarray | None#

Alias for field number 6

source: str | None#

Alias for field number 7

mapping: np.ndarray | None#

Alias for field number 8

charges: np.ndarray | None#

Alias for field number 9

radii: np.ndarray | None#

Alias for field number 10

estat_backbone_mask: np.ndarray | None#

Alias for field number 11

estat_resid: np.ndarray | None#

Alias for field number 12

estat_chain_index: np.ndarray | None#

Alias for field number 13
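
The "Alias for field number N" entries reflect that ProteinTuple is a NamedTuple: every field can be read by name or by position. The sketch below uses a hypothetical stand-in, ProteinTupleSketch, that copies only the first seven fields so it runs without the package installed. The atom-type count of 37 is an assumption suggested by the static_atom_mask_37 name elsewhere on this page.

```python
from typing import NamedTuple, Optional
import numpy as np

class ProteinTupleSketch(NamedTuple):
    """Hypothetical stand-in with the first fields of ProteinTuple."""
    coordinates: np.ndarray
    aatype: np.ndarray
    atom_mask: np.ndarray
    residue_index: np.ndarray
    chain_index: np.ndarray
    full_coordinates: Optional[np.ndarray] = None
    dihedrals: Optional[np.ndarray] = None

num_res, num_atom_type = 4, 37  # 37 is an assumed atom-type count
pt = ProteinTupleSketch(
    coordinates=np.zeros((num_res, num_atom_type, 3)),
    aatype=np.zeros(num_res, dtype=np.int32),
    atom_mask=np.ones((num_res, num_atom_type)),
    residue_index=np.arange(1, num_res + 1),
    chain_index=np.zeros(num_res, dtype=np.int32),
)

# Positional and named access return the same object.
same_object = pt[0] is pt.coordinates
# NamedTuples are immutable; _replace returns an updated copy.
pt2 = pt._replace(dihedrals=np.zeros((num_res, 3)))
```

Optional fields that are not supplied stay None, matching the defaults in the class signature.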

class prxteinmpnn.utils.data_structures.TrajectoryStaticFeatures(aatype, static_atom_mask_37, residue_indices, chain_index, valid_atom_mask, nitrogen_mask, num_residues)[source]#

Bases: object

A container for pre-computed, frame-invariant protein features.

Parameters:
  • aatype (np.ndarray)

  • static_atom_mask_37 (np.ndarray)

  • residue_indices (np.ndarray)

  • chain_index (np.ndarray)

  • valid_atom_mask (np.ndarray)

  • nitrogen_mask (np.ndarray)

  • num_residues (int)

aatype: ndarray#
static_atom_mask_37: ndarray#
residue_indices: ndarray#
chain_index: ndarray#
valid_atom_mask: ndarray#
nitrogen_mask: ndarray#
num_residues: int#
prxteinmpnn.utils.data_structures.include_feature(feature_name, include_features)[source]#

Determine if a feature should be included.

Parameters:
  • feature_name (str) – The name of the feature to check.

  • include_features (Sequence[str] | None) – The list of features to include. If None, no features are included.

Returns:

True if the feature should be included, False otherwise.

Return type:

bool
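
The documented contract can be mirrored with a small stand-in. The function include_feature_sketch below is hypothetical: it implements only what the description states (None means nothing is included) and assumes a plain membership test otherwise. The real function may treat special values differently.

```python
from typing import Optional, Sequence

def include_feature_sketch(feature_name: str,
                           include_features: Optional[Sequence[str]]) -> bool:
    """Stand-in mirroring the documented contract of include_feature."""
    if include_features is None:
        return False  # documented: None means no features are included
    return feature_name in include_features  # assumed: plain membership test

a = include_feature_sketch("dihedrals", ["dihedrals", "mapping"])
b = include_feature_sketch("charges", ["dihedrals"])
c = include_feature_sketch("dihedrals", None)
```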

class prxteinmpnn.utils.data_structures.EstatInfo(charges, radii, estat_backbone_mask, estat_resid, estat_chain_index)[source]#

Bases: object

Electrostatics information extracted from a PQR file.

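Parameters:
  • charges (ndarray)

  • radii (ndarray)

  • estat_backbone_mask (ndarray)

  • estat_resid (ndarray)

  • estat_chain_index (ndarray)

charges#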

Numpy array of atomic charges.

radii#

Numpy array of atomic radii.

estat_backbone_mask#

Boolean numpy array indicating backbone atoms.

estat_resid#

Integer numpy array of residue numbers.

estat_chain_index#

Integer numpy array of chain indices (ord value).
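
Assuming "ord value" refers to Python's built-in ord() applied to a single-character chain ID, the encoding can be illustrated as follows. The chain IDs are made up for the example.

```python
# Chain IDs as they might appear in a PQR file (hypothetical values).
chain_ids = ["A", "A", "B"]

# ord() maps each single-character ID to its Unicode code point.
estat_chain_index = [ord(c) for c in chain_ids]

# chr() reverses the mapping.
recovered = [chr(i) for i in estat_chain_index]
```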

charges: ndarray#
radii: ndarray#
estat_backbone_mask: ndarray#
estat_resid: ndarray#
estat_chain_index: ndarray#
replace(**updates)#

Returns a new object replacing the specified fields with new values.
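
The replace(**updates) method, which several classes on this page expose, follows the usual immutable-update pattern. The sketch below reproduces that contract with the standard library's dataclasses.replace on a hypothetical stand-in, EstatInfoSketch; plain tuples take the place of arrays to keep it dependency-free.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class EstatInfoSketch:
    """Hypothetical frozen stand-in with two of the EstatInfo fields."""
    charges: tuple
    radii: tuple

    def replace(self, **updates):
        # Same contract as documented: build a new object, leave self untouched.
        return dataclasses.replace(self, **updates)

info = EstatInfoSketch(charges=(-0.5, 0.5), radii=(1.8, 1.2))
neutral = info.replace(charges=(0.0, 0.0))
```

Fields that are not named in the call are carried over unchanged, and the original object is not modified.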

class prxteinmpnn.utils.data_structures.Protein(coordinates, aatype, one_hot_sequence, mask, residue_index, chain_index, dihedrals=None, mapping=None, full_coordinates=None, full_atom_mask=None, charges=None, radii=None, estat_backbone_mask=None, estat_resid=None, estat_chain_index=None)[source]#

Bases: object

Protein structure or ensemble representation.

Parameters:
  • coordinates (Float[Array, 'num_residues num_atoms 3'])

  • aatype (Int[Array, 'num_residues'])

  • one_hot_sequence (Float[Array, 'num_residues num_classes'])

  • mask (Int[Array, 'num_residues 3'])

  • residue_index (Int[Array, 'num_residues'])

  • chain_index (Int[Array, 'num_residues'])

  • dihedrals (Float[Array, 'num_residues 3'] | None)

  • mapping (Int | None)

  • full_coordinates (Float[Array, 'num_residues num_atoms 3'] | None)

  • full_atom_mask (Int[Array, 'num_residues num_atoms'] | None)

  • charges (ndarray | None)

  • radii (ndarray | None)

  • estat_backbone_mask (ndarray | None)

  • estat_resid (ndarray | None)

  • estat_chain_index (ndarray | None)

coordinates#

Cartesian coordinates of the atoms, in angstroms. The atom types correspond to residue_constants.atom_types, i.e. the first three are N, CA, CB. Shape is [num_res, num_atom_type, 3], where num_res is the number of residues, num_atom_type is the number of atom types (e.g., N, CA, CB, C, O), and 3 is the spatial dimension (x, y, z).

Type:

StructureAtomicCoordinates

aatype#

Amino-acid type for each residue represented as an integer between 0 and 20, where 20 is ‘X’. Shape is [num_res].

Type:

Sequence
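
The one_hot_sequence field in the class signature has shape [num_residues, num_classes]. One way such an encoding can be derived from aatype is sketched below. The value num_classes = 21 is inferred from the documented 0 to 20 range and is an assumption, as is the use of NumPy rather than the project's array library.

```python
import numpy as np

NUM_CLASSES = 21  # assumed: 20 standard amino acids plus 'X' at index 20

aatype = np.array([0, 5, 20])  # hypothetical residue types
# Row i of the identity matrix is the one-hot vector for class i.
one_hot_sequence = np.eye(NUM_CLASSES, dtype=np.float32)[aatype]
```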

mask#

Binary float mask to indicate presence of the alpha carbon atom. 1.0 if the atom is present and 0.0 if not. This should be used for loss masking. Shape is [num_res, num_atom_type].

Type:

AlphaCarbonMask

residue_index#

Residue index as used in PDB. It is not necessarily continuous or 0-indexed. Shape is [num_res].

Type:

AtomResidueIndex

chain_index#

Chain index for each residue. Shape is [num_res].

Type:

ChainIndex

dihedrals#

Dihedral angles for backbone atoms (phi, psi, omega). Shape is [num_res, 3]. If not provided, defaults to None.

Type:

BackboneDihedrals | None

mapping#

Optional array mapping residues in the ensemble to original structure indices. Shape is [num_res, num_frames]. If not provided, defaults to None.

Type:

jnp.Array | None

full_coordinates#

Full atomic coordinates including all heavy atoms. Shape is (num_res, num_full_atom_type, 3), where num_full_atom_type is the number of all heavy atom types (e.g., N, CA, CB, C, O, CG, etc.), and 3 is the spatial dimension (x, y, z). If not provided, defaults to None.

Type:

StructureAtomicCoordinates | None

full_atom_mask#

Binary float mask to indicate presence of a particular heavy atom. 1.0 if an atom is present and 0.0 if not. This should be used for loss masking. Shape is [num_res, num_full_atom_type]. If not provided, defaults to None.

Type:

AtomMask | None

coordinates: Float[Array, 'num_residues num_atoms 3']#
aatype: Int[Array, 'num_residues']#
one_hot_sequence: Float[Array, 'num_residues num_classes']#
mask: Int[Array, 'num_residues 3']#
residue_index: Int[Array, 'num_residues']#
chain_index: Int[Array, 'num_residues']#
dihedrals: Float[Array, 'num_residues 3'] | None = None#
mapping: Int | None = None#
full_coordinates: Float[Array, 'num_residues num_atoms 3'] | None = None#
full_atom_mask: Int[Array, 'num_residues num_atoms'] | None = None#
charges: ndarray | None = None#
radii: ndarray | None = None#
estat_backbone_mask: ndarray | None = None#
estat_resid: ndarray | None = None#
estat_chain_index: ndarray | None = None#
classmethod from_tuple(protein_tuple, *, include_extras=None)[source]#

Create a Protein instance from a ProteinTuple.

Parameters:
  • protein_tuple (ProteinTuple) – The input protein tuple.

  • include_extras (Sequence[Literal['dihedrals', 'mapping', 'full_coordinates', 'full_atom_mask', 'all']] | None) – Optional list of extra fields to include from the tuple. If ‘all’ is included, all optional fields will be included. If None, no optional fields will be included.

Returns:

The output protein dataclass.

Return type:

Protein
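
The include_extras behavior described above can be sketched with simplified stand-ins. TupleSketch and ProteinSketch are hypothetical and carry only one required field and two optional ones; the selection logic follows the parameter description ('all' selects every optional field, None selects none) and is not the project's actual implementation.

```python
import dataclasses
from typing import NamedTuple, Optional, Sequence

class TupleSketch(NamedTuple):
    aatype: tuple
    dihedrals: Optional[tuple] = None
    mapping: Optional[tuple] = None

@dataclasses.dataclass(frozen=True)
class ProteinSketch:
    aatype: tuple
    dihedrals: Optional[tuple] = None
    mapping: Optional[tuple] = None

    @classmethod
    def from_tuple(cls, protein_tuple, *,
                   include_extras: Optional[Sequence[str]] = None):
        extras = tuple(include_extras or ())
        optional_fields = ("dihedrals", "mapping")
        # 'all' selects every optional field; None selects none of them.
        if "all" in extras:
            wanted = optional_fields
        else:
            wanted = tuple(f for f in optional_fields if f in extras)
        kwargs = {name: getattr(protein_tuple, name) for name in wanted}
        return cls(aatype=protein_tuple.aatype, **kwargs)

src = TupleSketch(aatype=(0, 1), dihedrals=((0.0, 0.0, 0.0),) * 2, mapping=(0, 1))
bare = ProteinSketch.from_tuple(src)
some = ProteinSketch.from_tuple(src, include_extras=["dihedrals"])
full = ProteinSketch.from_tuple(src, include_extras=["all"])
```

Leaving optional fields out by default keeps the resulting object small when only the required fields are needed.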

replace(**updates)#

Returns a new object replacing the specified fields with new values.

class prxteinmpnn.utils.data_structures._EStepState(component_counts, weighted_data, weighted_squared_data, log_likelihood_total)[source]#

Bases: object

State for accumulating statistics during the E-step.

Parameters:
  • component_counts (Int[Array, 'n_components'])

  • weighted_data (Float[Array, 'num_samples num_features'] | Float[Array, 'n_batches n_samples n_features'])

  • weighted_squared_data (Float[Array, 'num_samples num_features'] | Float[Array, 'n_batches n_samples n_features'])

  • log_likelihood_total (Float[Array, ''])

component_counts: Int[Array, 'n_components']#
weighted_data: Float[Array, 'num_samples num_features'] | Float[Array, 'n_batches n_samples n_features']#
weighted_squared_data: Float[Array, 'num_samples num_features'] | Float[Array, 'n_batches n_samples n_features']#
log_likelihood_total: Float[Array, '']#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class prxteinmpnn.utils.data_structures.GMM(means, covariances, weights, responsibilities, n_components, n_features)[source]#

Bases: object

Dataclass to hold GMM parameters.

Parameters:
  • means (Float[Array, 'n_components n_features'])

  • covariances (Float[Array, 'n_components n_features n_features'])

  • weights (Float[Array, 'n_components'])

  • responsibilities (Float[Array, 'n_samples n_components'])

  • n_components (int)

  • n_features (int)

means: Float[Array, 'n_components n_features']#
covariances: Float[Array, 'n_components n_features n_features']#
weights: Float[Array, 'n_components']#
responsibilities: Float[Array, 'n_samples n_components']#
n_components: int#
n_features: int#
replace(**updates)#

Returns a new object replacing the specified fields with new values.
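
To make the shapes of the GMM fields concrete, the sketch below computes responsibilities of shape [n_samples, n_components] from means, covariances, and weights using the standard Gaussian-mixture posterior. It is a plain NumPy illustration with made-up data; the function name e_step_sketch is an assumption, and the project's fitting code may be organized differently.

```python
import numpy as np

def e_step_sketch(data, means, covariances, weights):
    """Return (responsibilities [n_samples, n_components], total log-likelihood)."""
    n_features = data.shape[1]
    log_prob = []
    for mean, cov, weight in zip(means, covariances, weights):
        diff = data - mean                          # [n_samples, n_features]
        _, logdet = np.linalg.slogdet(cov)
        # Mahalanobis distance per sample without forming an explicit inverse.
        maha = np.sum(diff * np.linalg.solve(cov, diff.T).T, axis=1)
        log_pdf = -0.5 * (n_features * np.log(2 * np.pi) + logdet + maha)
        log_prob.append(np.log(weight) + log_pdf)
    log_prob = np.stack(log_prob, axis=1)           # [n_samples, n_components]
    # Normalize in log space for numerical stability.
    log_norm = np.logaddexp.reduce(log_prob, axis=1, keepdims=True)
    return np.exp(log_prob - log_norm), float(np.sum(log_norm))

means = np.array([[0.0, 0.0], [5.0, 5.0]])          # [n_components, n_features]
covariances = np.stack([np.eye(2), np.eye(2)])      # [n_components, n_features, n_features]
weights = np.array([0.5, 0.5])                      # [n_components]
data = np.array([[0.1, -0.2], [4.9, 5.3], [2.5, 2.5]])
resp, log_likelihood = e_step_sketch(data, means, covariances, weights)
```

Each row of the responsibilities sums to 1, and a point equidistant from both means is split evenly between the two components.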

class prxteinmpnn.utils.data_structures.EMLoopState(gmm, n_iter, log_likelihood, log_likelihood_diff)[source]#

Bases: NamedTuple

State for the in-memory EM loop.

Parameters:
  • gmm (GMM)

  • n_iter (Int)

  • log_likelihood (Float[Array, ''])

  • log_likelihood_diff (Float[Array, ''])

gmm: GMM#

Alias for field number 0

n_iter: Int#

Alias for field number 1

log_likelihood: Float[Array, '']#

Alias for field number 2

log_likelihood_diff: Float[Array, '']#

Alias for field number 3
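
The fields of EMLoopState suggest a loop that stops once the change in log-likelihood is small or an iteration cap is reached. The sketch below shows that control flow with a hypothetical stand-in, LoopStateSketch, which omits the gmm field. The update rule, tolerance, and iteration cap are placeholders rather than the project's values.

```python
import math
from typing import NamedTuple

class LoopStateSketch(NamedTuple):
    """Hypothetical stand-in for EMLoopState without the gmm field."""
    n_iter: int
    log_likelihood: float
    log_likelihood_diff: float

def toy_em_step(state):
    # Placeholder for a real E-step plus M-step: halve the distance to 0.
    new_ll = 0.5 * state.log_likelihood
    return LoopStateSketch(state.n_iter + 1, new_ll, new_ll - state.log_likelihood)

def run(max_iter=100, tol=0.3):
    state = LoopStateSketch(0, -8.0, math.inf)
    # Continue while the cap is not reached and the last improvement exceeds tol.
    while state.n_iter < max_iter and abs(state.log_likelihood_diff) > tol:
        state = toy_em_step(state)
    converged = abs(state.log_likelihood_diff) <= tol
    return state, converged

final_state, converged = run()
```

Carrying all loop variables in one immutable state object is what allows the same logic to be expressed as a functional while-loop.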

class prxteinmpnn.utils.data_structures.EMFitterResult(gmm, n_iter, log_likelihood, log_likelihood_diff, converged, features=None, bic=None)[source]#

Bases: object

Result of the Expectation-Maximization fitting process.

Attributes#

gmm : GMM

The final fitted Gaussian mixture model.

n_iter : jax.Array

The total number of iterations performed.

log_likelihood : jax.Array

The log-likelihood of the data under the final model.

converged : jax.Array

A boolean indicating if the algorithm converged within the max iterations.

gmm: GMM#
n_iter: Int#
log_likelihood: Float[Array, '']#
log_likelihood_diff: Float[Array, '']#
converged: Bool[Array, '']#
features: Float[Array, 'num_samples num_features'] | Float[Array, 'n_batches n_samples n_features'] | None = None#
bic: Float[Array, ''] | None = None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

Parameters:
  • gmm (GMM)

  • n_iter (Int)

  • log_likelihood (Float[Array, ''])

  • log_likelihood_diff (Float[Array, ''])

  • converged (Bool[Array, ''])

  • features (Float[Array, 'num_samples num_features'] | Float[Array, 'n_batches n_samples n_features'] | None)

  • bic (Float[Array, ''] | None)
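
The bic field presumably holds the Bayesian information criterion. Under that assumption, the textbook formula for a full-covariance mixture is sketched below. The function name gmm_bic and the parameter count are assumptions; the project may count parameters differently (for example, for diagonal covariances).

```python
import math

def gmm_bic(log_likelihood, n_samples, n_components, n_features):
    """BIC = k * ln(n) - 2 * ln(L) for a full-covariance GMM."""
    n_mean_params = n_components * n_features
    # Each symmetric covariance matrix has n_features * (n_features + 1) / 2 free entries.
    n_cov_params = n_components * n_features * (n_features + 1) // 2
    n_weight_params = n_components - 1  # weights sum to 1
    k = n_mean_params + n_cov_params + n_weight_params
    return k * math.log(n_samples) - 2.0 * log_likelihood

bic = gmm_bic(log_likelihood=-250.0, n_samples=100, n_components=2, n_features=2)
```

Lower values indicate a better trade-off between fit and model size, which makes the criterion useful for comparing fits with different n_components.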