data_structures#
Dataclasses for the PrxteinMPNN project.
prxteinmpnn.utils.data_structures
- class prxteinmpnn.utils.data_structures.ProteinTuple(coordinates, aatype, atom_mask, residue_index, chain_index, full_coordinates=None, dihedrals=None, source=None, mapping=None, charges=None, radii=None, estat_backbone_mask=None, estat_resid=None, estat_chain_index=None)[source]#
Bases: NamedTuple
Tuple-based protein structure representation.
- Parameters:
coordinates (np.ndarray)
aatype (np.ndarray)
atom_mask (np.ndarray)
residue_index (np.ndarray)
chain_index (np.ndarray)
full_coordinates (np.ndarray | None)
dihedrals (np.ndarray | None)
source (str | None)
mapping (np.ndarray | None)
charges (np.ndarray | None)
radii (np.ndarray | None)
estat_backbone_mask (np.ndarray | None)
estat_resid (np.ndarray | None)
estat_chain_index (np.ndarray | None)
- coordinates#
Atom positions in the structure, represented as a 3D array. Cartesian coordinates of atoms in angstroms. The atom types correspond to residue_constants.atom_types, i.e. the first three are N, CA, CB. Shape is (num_res, num_atom_type, 3), where num_res is the number of residues, num_atom_type is the number of atom types (e.g., N, CA, CB, C, O), and 3 is the spatial dimension (x, y, z).
- Type:
StructureAtomicCoordinates
- aatype#
Amino-acid type for each residue represented as an integer between 0 and 20, where 20 is ‘X’. Shape is [num_res].
- Type:
ProteinSequence
- atom_mask#
Binary float mask to indicate presence of a particular atom. 1.0 if an atom is present and 0.0 if not. This should be used for loss masking. Shape is [num_res, num_atom_type].
- Type:
AtomMask
- residue_index#
Residue index as used in PDB. It is not necessarily continuous or 0-indexed. Shape is [num_res].
- Type:
ResidueIndex
- chain_index#
Chain index for each residue. Shape is [num_res].
- Type:
ChainIndex
- dihedrals#
Dihedral angles for backbone atoms (phi, psi, omega). Shape is [num_res, 3]. If not provided, defaults to None.
- Type:
BackboneDihedrals | None
- coordinates: np.ndarray#
Alias for field number 0
- aatype: np.ndarray#
Alias for field number 1
- atom_mask: np.ndarray#
Alias for field number 2
- residue_index: np.ndarray#
Alias for field number 3
- chain_index: np.ndarray#
Alias for field number 4
- full_coordinates: np.ndarray | None#
Alias for field number 5
- dihedrals: np.ndarray | None#
Alias for field number 6
- source: str | None#
Alias for field number 7
- mapping: np.ndarray | None#
Alias for field number 8
- charges: np.ndarray | None#
Alias for field number 9
- radii: np.ndarray | None#
Alias for field number 10
- estat_backbone_mask: np.ndarray | None#
Alias for field number 11
- estat_resid: np.ndarray | None#
Alias for field number 12
- estat_chain_index: np.ndarray | None#
Alias for field number 13
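The field order and shapes documented above can be illustrated with a small stand-in. This is a minimal sketch, not the library class: `ProteinTupleSketch` is a hypothetical NamedTuple that only mirrors the documented field names and order, and the choice of 37 atom types is an assumption made for the example.

```python
from typing import NamedTuple, Optional

import numpy as np


class ProteinTupleSketch(NamedTuple):
    """Hypothetical stand-in mirroring the documented ProteinTuple fields."""

    coordinates: np.ndarray    # (num_res, num_atom_type, 3), angstroms
    aatype: np.ndarray         # (num_res,), integers 0..20, where 20 is 'X'
    atom_mask: np.ndarray      # (num_res, num_atom_type), 1.0 present / 0.0 absent
    residue_index: np.ndarray  # (num_res,), PDB numbering, may contain gaps
    chain_index: np.ndarray    # (num_res,)
    full_coordinates: Optional[np.ndarray] = None
    dihedrals: Optional[np.ndarray] = None  # (num_res, 3) when provided
    source: Optional[str] = None
    mapping: Optional[np.ndarray] = None
    charges: Optional[np.ndarray] = None
    radii: Optional[np.ndarray] = None
    estat_backbone_mask: Optional[np.ndarray] = None
    estat_resid: Optional[np.ndarray] = None
    estat_chain_index: Optional[np.ndarray] = None


num_res, num_atom_type = 4, 37  # 37 atom types is an assumption for this sketch

protein = ProteinTupleSketch(
    coordinates=np.zeros((num_res, num_atom_type, 3), dtype=np.float32),
    aatype=np.array([0, 5, 20, 12], dtype=np.int32),
    atom_mask=np.ones((num_res, num_atom_type), dtype=np.float32),
    residue_index=np.array([10, 11, 14, 15], dtype=np.int32),  # gap after 11
    chain_index=np.zeros(num_res, dtype=np.int32),
)

# Positional access follows the "Alias for field number N" entries.
first_field = protein[0]    # same object as protein.coordinates
seventh_field = protein[6]  # dihedrals, left at its default of None
```

Because the optional fields default to None, only the first five arrays are needed to build a minimal instance.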
- class prxteinmpnn.utils.data_structures.TrajectoryStaticFeatures(aatype, static_atom_mask_37, residue_indices, chain_index, valid_atom_mask, nitrogen_mask, num_residues)[source]#
Bases: object
A container for pre-computed, frame-invariant protein features.
- Parameters:
aatype (np.ndarray)
static_atom_mask_37 (np.ndarray)
residue_indices (np.ndarray)
chain_index (np.ndarray)
valid_atom_mask (np.ndarray)
nitrogen_mask (np.ndarray)
num_residues (int)
- prxteinmpnn.utils.data_structures.include_feature(feature_name, include_features)[source]#
Determine if a feature should be included.
- class prxteinmpnn.utils.data_structures.EstatInfo(charges, radii, estat_backbone_mask, estat_resid, estat_chain_index)[source]#
Bases: object
Electrostatics information extracted from a PQR file.
- Parameters:
- charges#
Numpy array of atomic charges.
- radii#
Numpy array of atomic radii.
- estat_backbone_mask#
Boolean numpy array indicating backbone atoms.
- estat_resid#
Integer numpy array of residue numbers.
- estat_chain_index#
Integer numpy array of chain indices (ord value).
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
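The `replace` behaviour described above (a new object is returned and the original is left unchanged) can be sketched with the standard library. `EstatInfoSketch` is a hypothetical stand-in that copies only the documented field names; using `dataclasses.replace` here is an assumption for illustration, since the library's own implementation is not shown on this page.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class EstatInfoSketch:
    """Hypothetical stand-in with the documented EstatInfo fields."""

    charges: np.ndarray
    radii: np.ndarray
    estat_backbone_mask: np.ndarray
    estat_resid: np.ndarray
    estat_chain_index: np.ndarray

    def replace(self, **updates):
        """Return a new object with the given fields replaced."""
        return dataclasses.replace(self, **updates)


info = EstatInfoSketch(
    charges=np.array([-0.5, 0.5, 0.0]),
    radii=np.array([1.8, 1.9, 1.2]),
    estat_backbone_mask=np.array([True, True, False]),
    estat_resid=np.array([1, 1, 1]),
    estat_chain_index=np.array([ord("A")] * 3),  # chain indices as ord values
)

# Build a copy with zeroed charges; every other field is carried over as-is.
neutral = info.replace(charges=np.zeros_like(info.charges))
```

The untouched fields of `neutral` refer to the same arrays as `info`, so the copy is shallow.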
- class prxteinmpnn.utils.data_structures.Protein(coordinates, aatype, one_hot_sequence, mask, residue_index, chain_index, dihedrals=None, mapping=None, full_coordinates=None, full_atom_mask=None, charges=None, radii=None, estat_backbone_mask=None, estat_resid=None, estat_chain_index=None)[source]#
Bases: object
Protein structure or ensemble representation.
- Parameters:
coordinates (Float[Array, 'num_residues num_atoms 3'])
aatype (Int[Array, 'num_residues'])
one_hot_sequence (Float[Array, 'num_residues num_classes'])
mask (Int[Array, 'num_residues 3'])
residue_index (Int[Array, 'num_residues'])
chain_index (Int[Array, 'num_residues'])
dihedrals (Float[Array, 'num_residues 3'] | None)
mapping (Int | None)
full_coordinates (Float[Array, 'num_residues num_atoms 3'] | None)
full_atom_mask (Int[Array, 'num_residues num_atoms'] | None)
charges (ndarray | None)
radii (ndarray | None)
estat_backbone_mask (ndarray | None)
estat_resid (ndarray | None)
estat_chain_index (ndarray | None)
- coordinates#
Atom positions in the structure, represented as a 3D array. Cartesian coordinates of atoms in angstroms. The atom types correspond to residue_constants.atom_types, i.e. the first three are N, CA, CB. Shape is (num_res, num_atom_type, 3), where num_res is the number of residues, num_atom_type is the number of atom types (e.g., N, CA, CB, C, O), and 3 is the spatial dimension (x, y, z).
- Type:
StructureAtomicCoordinates
- aatype#
Amino-acid type for each residue represented as an integer between 0 and 20, where 20 is ‘X’. Shape is [num_res].
- Type:
Sequence
- mask#
Binary float mask indicating the presence of the alpha-carbon atom: 1.0 if the atom is present and 0.0 if not. This should be used for loss masking. Shape is [num_res, num_atom_type].
- Type:
AlphaCarbonMask
- residue_index#
Residue index as used in PDB. It is not necessarily continuous or 0-indexed. Shape is [num_res].
- Type:
AtomResidueIndex
- chain_index#
Chain index for each residue. Shape is [num_res].
- Type:
ChainIndex
- dihedrals#
Dihedral angles for backbone atoms (phi, psi, omega). Shape is [num_res, 3]. If not provided, defaults to None.
- Type:
BackboneDihedrals | None
- mapping#
Optional array mapping residues in the ensemble to original structure indices. Shape is [num_res, num_frames]. If not provided, defaults to None.
- Type:
jnp.Array | None
- full_coordinates#
Full atomic coordinates including all heavy atoms. Shape is (num_res, num_full_atom_type, 3), where num_full_atom_type is the number of all heavy atom types (e.g., N, CA, CB, C, O, CG, etc.), and 3 is the spatial dimension (x, y, z). If not provided, defaults to None.
- Type:
StructureAtomicCoordinates | None
- full_atom_mask#
Binary float mask to indicate presence of a particular heavy atom. 1.0 if an atom is present and 0.0 if not. This should be used for loss masking. Shape is [num_res, num_full_atom_type]. If not provided, defaults to None.
- Type:
AtomMask | None
- coordinates: Float[Array, 'num_residues num_atoms 3']#
- aatype: Int[Array, 'num_residues']#
- one_hot_sequence: Float[Array, 'num_residues num_classes']#
- mask: Int[Array, 'num_residues 3']#
- residue_index: Int[Array, 'num_residues']#
- chain_index: Int[Array, 'num_residues']#
- classmethod from_tuple(protein_tuple, *, include_extras=None)[source]#
Create a Protein instance from a ProteinTuple.
- Parameters:
protein_tuple (ProteinTuple) – The input protein tuple.
include_extras (Sequence[Literal['dihedrals', 'mapping', 'full_coordinates', 'full_atom_mask', 'all']] | None) – Optional list of extra fields to include from the tuple. If ‘all’ is included, all optional fields are included. If None, no optional fields are included.
- Returns:
The output protein dataclass.
- Return type:
Protein
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
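The `include_extras` rules documented for `from_tuple` (None selects nothing, ‘all’ selects every optional field) can be sketched as a small selection helper. `resolve_extras` and `OPTIONAL_FIELDS` are hypothetical names introduced for illustration; how the library resolves the argument internally is not shown on this page.

```python
from typing import Optional, Sequence, Set

# Optional field names taken from the include_extras type annotation.
OPTIONAL_FIELDS = ("dihedrals", "mapping", "full_coordinates", "full_atom_mask")


def resolve_extras(include_extras: Optional[Sequence[str]]) -> Set[str]:
    """Hypothetical helper: expand include_extras into a set of field names."""
    if include_extras is None:
        return set()                 # None: no optional fields
    if "all" in include_extras:
        return set(OPTIONAL_FIELDS)  # 'all': every optional field
    return set(include_extras)       # otherwise: exactly the named fields


none_selected = resolve_extras(None)
some_selected = resolve_extras(["dihedrals"])
all_selected = resolve_extras(["all"])
```

Fields that are not selected would then stay at their None defaults on the resulting object.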
- class prxteinmpnn.utils.data_structures._EStepState(component_counts, weighted_data, weighted_squared_data, log_likelihood_total)[source]#
Bases: object
State for accumulating statistics during the E-step.
- Parameters:
component_counts (Int[Array, 'n_components'])
weighted_data (Float[Array, 'num_samples num_features'] | Float[Array, 'n_batches n_samples n_features'])
weighted_squared_data (Float[Array, 'num_samples num_features'] | Float[Array, 'n_batches n_samples n_features'])
log_likelihood_total (Float[Array, ''])
- component_counts: Int[Array, 'n_components']#
- weighted_data: Float[Array, 'num_samples num_features'] | Float[Array, 'n_batches n_samples n_features']#
- weighted_squared_data: Float[Array, 'num_samples num_features'] | Float[Array, 'n_batches n_samples n_features']#
- log_likelihood_total: Float[Array, '']#
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- class prxteinmpnn.utils.data_structures.GMM(means, covariances, weights, responsibilities, n_components, n_features)[source]#
Bases: object
Dataclass to hold GMM parameters.
- Parameters:
- means: Float[Array, 'n_components n_features']#
- covariances: Float[Array, 'n_components n_features n_features']#
- weights: Float[Array, 'n_components']#
- responsibilities: Float[Array, 'n_samples n_components']#
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
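To show how the documented shapes of `means`, `covariances`, `weights` and `responsibilities` relate to each other, here is a minimal NumPy sketch of the standard responsibility computation for a full-covariance Gaussian mixture. It is written from the textbook definition and is not taken from the library; the data, the random seed and the helper `log_gaussian` are assumptions for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_components, n_features = 6, 2, 3

data = rng.normal(size=(n_samples, n_features))
means = rng.normal(size=(n_components, n_features))          # (K, D)
covariances = np.stack([np.eye(n_features)] * n_components)  # (K, D, D)
weights = np.full(n_components, 1.0 / n_components)          # (K,)


def log_gaussian(x, mean, cov):
    """Log density of a multivariate normal at each row of x."""
    dim = x.shape[-1]
    diff = x - mean
    solved = np.linalg.solve(cov, diff.T).T  # rows are cov^-1 @ diff_i
    maha = np.sum(diff * solved, axis=-1)    # squared Mahalanobis distance
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (dim * np.log(2.0 * np.pi) + logdet + maha)


# log(weight_k) + log N(x_i | mean_k, cov_k), shape (n_samples, n_components)
log_joint = np.stack(
    [
        np.log(weights[k]) + log_gaussian(data, means[k], covariances[k])
        for k in range(n_components)
    ],
    axis=1,
)

# Normalise per sample in log space for numerical stability.
log_norm = np.logaddexp.reduce(log_joint, axis=1, keepdims=True)
responsibilities = np.exp(log_joint - log_norm)  # (n_samples, n_components)
log_likelihood = float(log_norm.sum())
```

Each row of `responsibilities` sums to one, which is a quick sanity check on any implementation of this step.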
- class prxteinmpnn.utils.data_structures.EMLoopState(gmm, n_iter, log_likelihood, log_likelihood_diff)[source]#
Bases: NamedTuple
State for the in-memory EM loop.
- Parameters:
gmm (GMM)
n_iter (Int)
log_likelihood (Float[Array, ''])
log_likelihood_diff (Float[Array, ''])
- n_iter: Int#
Alias for field number 1
- log_likelihood: Float[Array, '']#
Alias for field number 2
- log_likelihood_diff: Float[Array, '']#
Alias for field number 3
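One plausible way a loop state with these four fields is used is as the carry of an iterate-until-converged loop. The sketch below is an assumption about usage, not the library's loop: `LoopStateSketch`, `run_em_loop`, `toy_step`, `tol` and `max_iter` are all hypothetical, and the toy step merely halves a number in place of a real E-step and M-step.

```python
from typing import Any, Callable, NamedTuple, Tuple


class LoopStateSketch(NamedTuple):
    """Hypothetical stand-in with the documented EMLoopState field order."""

    gmm: Any
    n_iter: int
    log_likelihood: float
    log_likelihood_diff: float


def run_em_loop(
    step: Callable[[Any], Tuple[Any, float]],
    gmm: Any,
    init_ll: float,
    tol: float = 1e-4,    # assumed convergence tolerance
    max_iter: int = 100,  # assumed iteration cap
) -> LoopStateSketch:
    """Iterate until the log-likelihood change is at most tol or max_iter is hit."""
    state = LoopStateSketch(gmm, 0, init_ll, float("inf"))
    while state.n_iter < max_iter and state.log_likelihood_diff > tol:
        new_gmm, new_ll = step(state.gmm)
        state = LoopStateSketch(
            new_gmm,
            state.n_iter + 1,
            new_ll,
            abs(new_ll - state.log_likelihood),
        )
    return state


def toy_step(gmm: float) -> Tuple[float, float]:
    """Stand-in for one EM update: the 'model' is a float that halves each step."""
    updated = gmm * 0.5
    return updated, updated


final_state = run_em_loop(toy_step, gmm=-10.0, init_ll=-10.0)
converged = final_state.log_likelihood_diff <= 1e-4
```

Keeping the state as a NamedTuple of arrays and scalars is what would let the same structure serve as the carry of a compiled loop in JAX, although this sketch uses a plain Python `while`.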
- class prxteinmpnn.utils.data_structures.EMFitterResult(gmm, n_iter, log_likelihood, log_likelihood_diff, converged, features=None, bic=None)[source]#
Bases: object
Result of the Expectation-Maximization fitting process.
Attributes#
- gmm (GMM)
The final fitted Gaussian mixture model.
- n_iter (jax.Array)
The total number of iterations performed.
- log_likelihood (jax.Array)
The log-likelihood of the data under the final model.
- converged (jax.Array)
A boolean indicating if the algorithm converged within the max iterations.
- n_iter: Int#
- log_likelihood: Float[Array, '']#
- log_likelihood_diff: Float[Array, '']#
- converged: Bool[Array, '']#
- features: Float[Array, 'num_samples num_features'] | Float[Array, 'n_batches n_samples n_features'] | None = None#
- Parameters:
gmm (GMM)
n_iter (Int)
log_likelihood (Float[Array, ''])
log_likelihood_diff (Float[Array, ''])
converged (Bool[Array, ''])
features (Float[Array, 'num_samples num_features'] | Float[Array, 'n_batches n_samples n_features'] | None)
bic (Float[Array, ''] | None)
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
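The optional `bic` field of `EMFitterResult` presumably holds a Bayesian information criterion score. As a hedged illustration, the standard definition for a full-covariance Gaussian mixture is sketched below; whether the library counts parameters or scales the log-likelihood in exactly this way is an assumption, and both function names are hypothetical.

```python
import math


def gmm_num_parameters(n_components: int, n_features: int) -> int:
    """Free parameters of a full-covariance Gaussian mixture."""
    mean_params = n_components * n_features
    # Each symmetric covariance matrix has D * (D + 1) / 2 free entries.
    cov_params = n_components * n_features * (n_features + 1) // 2
    # Mixture weights sum to one, so one of them is determined by the rest.
    weight_params = n_components - 1
    return mean_params + cov_params + weight_params


def bic_score(
    log_likelihood: float, n_samples: int, n_components: int, n_features: int
) -> float:
    """Standard BIC: k * ln(n) - 2 * ln(L), with ln(L) summed over all samples."""
    k = gmm_num_parameters(n_components, n_features)
    return k * math.log(n_samples) - 2.0 * log_likelihood


n_params = gmm_num_parameters(n_components=2, n_features=3)
score = bic_score(log_likelihood=-100.0, n_samples=50, n_components=2, n_features=3)
```

Lower scores are better under this convention, so comparing `bic` across fits with different `n_components` gives a simple model-selection rule.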