data_structures
Dataclasses and enums for the PrxteinMPNN project.
prxteinmpnn.utils.data_structures
- class prxteinmpnn.utils.data_structures.ProteinStructure(coordinates, aatype, atom_mask, residue_index, chain_index, dihedrals=None)
Bases:
object
Protein structure representation.
- Parameters:
coordinates (StructureAtomicCoordinates)
aatype (ProteinSequence)
atom_mask (AtomMask)
residue_index (ResidueIndex)
chain_index (ChainIndex)
dihedrals (BackboneDihedrals | None)
- coordinates
Cartesian coordinates of atoms in angstroms, stored as a 3D array. The atom types follow residue_constants.atom_types, i.e. the first three are N, CA, C. Shape is (num_res, num_atom_type, 3), where num_res is the number of residues, num_atom_type is the number of atom types (e.g., N, CA, C, CB, O), and 3 is the spatial dimension (x, y, z).
- Type:
StructureAtomicCoordinates
- aatype
Amino-acid type for each residue, represented as an integer between 0 and 20, where 20 is ‘X’ (unknown). Shape is (num_res,).
- Type:
ProteinSequence
- atom_mask
Binary mask indicating whether each atom is present: 1 if present, 0 if absent. Use it for loss masking. Shape is (num_res, num_atom_type).
- Type:
AtomMask
- residue_index
Residue index as used in the PDB file; it is not necessarily contiguous or 0-indexed. Shape is (num_res,).
- Type:
ResidueIndex
- coordinates: Float[Array, 'num_residues num_atoms 3']
- aatype: Int[Array, 'num_residues']
- atom_mask: Int[Array, 'num_residues num_atoms']
- residue_index: Int[Array, 'num_residues']
- chain_index: Int[Array, 'num_residues']
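To make the atom-type layout concrete, here is a minimal sketch that builds a backbone-only structure. It does not import PrxteinMPNN: `ProteinStructureSketch` and `backbone_only_structure` are hypothetical stand-ins that mirror the documented fields, the 37-slot layout with N, CA, C, CB, O in the first five slots is assumed from AlphaFold's `residue_constants.atom_types`, and NumPy arrays stand in for JAX arrays.

```python
from dataclasses import dataclass

import numpy as np

# Assumed atom37 ordering of the first five slots (AlphaFold residue_constants.atom_types).
ATOM_TYPES_PREFIX = ["N", "CA", "C", "CB", "O"]
NUM_ATOM_TYPES = 37  # assumed atom37 representation


@dataclass(frozen=True)
class ProteinStructureSketch:
    """Hypothetical stand-in mirroring the documented ProteinStructure fields."""

    coordinates: np.ndarray  # (num_res, num_atom_type, 3), angstroms
    aatype: np.ndarray  # (num_res,), integers in [0, 20]
    atom_mask: np.ndarray  # (num_res, num_atom_type), 1 = atom present
    residue_index: np.ndarray  # (num_res,), PDB numbering
    chain_index: np.ndarray  # (num_res,)


def backbone_only_structure(backbone_xyz: np.ndarray) -> ProteinStructureSketch:
    """Place N, CA, C, O (input order) into their atom37 slots; CB stays missing."""
    num_res = backbone_xyz.shape[0]
    coords = np.zeros((num_res, NUM_ATOM_TYPES, 3), dtype=np.float32)
    mask = np.zeros((num_res, NUM_ATOM_TYPES), dtype=np.int32)
    for i, name in enumerate(["N", "CA", "C", "O"]):
        slot = ATOM_TYPES_PREFIX.index(name)
        coords[:, slot] = backbone_xyz[:, i]
        mask[:, slot] = 1
    return ProteinStructureSketch(
        coordinates=coords,
        aatype=np.full(num_res, 20, dtype=np.int32),  # 20 = 'X' (unknown)
        atom_mask=mask,
        residue_index=np.arange(10, 10 + num_res),  # PDB numbering need not start at 0
        chain_index=np.zeros(num_res, dtype=np.int32),
    )


backbone = np.random.default_rng(0).normal(size=(3, 4, 3))
s = backbone_only_structure(backbone)
print(s.coordinates.shape)  # (3, 37, 3)
print(s.atom_mask[0, :5])  # [1 1 1 0 1]
```

The zeroed CB slot shows why `atom_mask` matters: absent atoms still occupy a coordinate slot, and only the mask distinguishes them from real atoms at the origin.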
- replace(**updates)
Returns a new object replacing the specified fields with new values.
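The documented `replace(**updates)` appears to follow the functional-update pattern of Python's `dataclasses.replace`. The sketch below assumes that behaviour; `InputsSketch` is a hypothetical frozen dataclass, not the library class.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class InputsSketch:
    """Hypothetical frozen dataclass standing in for the documented classes."""

    sequence: np.ndarray
    k_neighbors: int = 48
    augment_eps: float = 0.0

    def replace(self, **updates):
        # Build a copy with the given fields swapped; the original is untouched.
        return dataclasses.replace(self, **updates)


original = InputsSketch(sequence=np.array([0, 3, 7]))
noisy = original.replace(augment_eps=0.1)
print(original.augment_eps, noisy.augment_eps)  # 0.0 0.1
print(noisy.sequence is original.sequence)  # True: unchanged fields are shared, not copied
```

Returning a new object instead of mutating in place is a natural fit for JAX's functional style, where inputs are usually treated as immutable.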
- class prxteinmpnn.utils.data_structures.ModelInputs(structure_coordinates=<factory>, sequence=<factory>, mask=<factory>, residue_index=<factory>, chain_index=<factory>, lengths=<factory>, bias=<factory>, k_neighbors=48, augment_eps=0.0)
Bases:
object
Dataclass for general model inputs.
Instances can be stacked together to form a batch of inputs.
- Parameters:
structure_coordinates (Float[Array, 'num_residues num_atoms 3'])
sequence (Int[Array, 'num_residues'])
mask (Int[Array, 'num_residues 3'])
residue_index (Int[Array, 'num_residues'])
chain_index (Int[Array, 'num_residues'])
lengths (Int[Array, 'num_sequences'])
bias (Float[Array, 'num_residues num_classes'] | None)
k_neighbors (int)
augment_eps (float)
- structure_coordinates
Atomic coordinates of the structure. Shape is (num_residues, num_atoms, 3), where num_residues is the number of residues and num_atoms is the number of atoms per residue (e.g., 37 for standard residues).
- Type:
StructureAtomicCoordinates
- sequence
Sequence of amino acids as an array of integers. Shape is (num_residues,), where num_residues is the number of residues.
- Type:
Sequence
- mask
Mask for the model input, indicating valid atoms in the structure. Shape is (num_residues, num_atoms), where num_atoms is the number of atoms per residue.
- Type:
AtomMask
- residue_index
Index of residues in the structure, used to map atoms to their residues. Shape is (num_residues,).
- Type:
ResidueIndex
- chain_index
Index of chains in the structure, used to map atoms to their chains. Shape is (num_residues,).
- Type:
ChainIndex
- lengths
Lengths of the sequences in the batch, used for padding and batching. Shape is (num_sequences,), where num_sequences is the number of sequences in the batch.
- Type:
InputLengths
- bias
Per-residue, per-class bias for the model input, used for classification tasks. Shape is (sum(lengths), num_classes), where num_classes is the number of classes (e.g., 20 for amino acids). Defaults to a zero bias of shape (sum(lengths), 20).
- Type:
InputBias
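Because the default bias has shape (sum(lengths), 20), one plausible reading is that several proteins are concatenated along the residue axis and `lengths` records how many residues each contributes. The hypothetical helper `concatenate_inputs` sketches that layout with NumPy; it illustrates this assumption and is not the library's batching code.

```python
import numpy as np

NUM_CLASSES = 20  # amino-acid classes, per the documented default bias


def concatenate_inputs(sequences):
    """Hypothetical helper: join per-protein sequences and record their lengths."""
    sequence = np.concatenate(sequences)
    lengths = np.array([len(s) for s in sequences], dtype=np.int32)
    # Default bias: zeros of shape (sum(lengths), 20).
    bias = np.zeros((int(lengths.sum()), NUM_CLASSES), dtype=np.float32)
    return sequence, lengths, bias


seq, lengths, bias = concatenate_inputs([np.array([1, 2, 3]), np.array([4, 5])])
print(lengths)  # [3 2]
print(bias.shape)  # (5, 20)

# Recover the per-protein pieces by splitting at the cumulative lengths.
pieces = np.split(seq, np.cumsum(lengths)[:-1])
```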
- structure_coordinates: Float[Array, 'num_residues num_atoms 3']
Structure atomic coordinates for the model input.
- sequence: Int[Array, 'num_residues']
Sequence of amino acids for the model input, encoded as an array of integers in the MPNN alphabet.
- mask: Int[Array, 'num_residues 3']
Mask for the model input, indicating valid atoms in the structure.
- residue_index: Int[Array, 'num_residues']
Index of residues in the structure, used to map atoms to their residues.
- chain_index: Int[Array, 'num_residues']
Index of chains in the structure, used to map atoms to their chains.
- lengths: Int[Array, 'num_sequences']
Lengths of the sequences in the batch, used for padding and batching.
- bias: Float[Array, 'num_residues num_classes'] | None
Bias for the model input, used for classification tasks. Defaults to a zero bias of shape (sum(lengths), 20).
- augment_eps: float = 0.0
Epsilon value for adding noise to the backbone coordinates, used for data augmentation.
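A minimal sketch of backbone noise augmentation controlled by `augment_eps`, assuming the noise is isotropic Gaussian with standard deviation `augment_eps` (in angstroms) and that the backbone atoms N, CA, C, O occupy atom37 slots 0, 1, 2, 4. `augment_backbone` and `BACKBONE_SLOTS` are hypothetical names; the library's own augmentation may differ.

```python
import numpy as np

BACKBONE_SLOTS = [0, 1, 2, 4]  # N, CA, C, O in assumed atom37 order


def augment_backbone(coords, augment_eps, rng):
    """Add Gaussian noise (std = augment_eps) to backbone atoms only."""
    if augment_eps <= 0.0:
        return coords  # the default 0.0 disables augmentation
    noisy = coords.copy()  # leave the caller's array unchanged
    noise = augment_eps * rng.standard_normal(coords[:, BACKBONE_SLOTS].shape)
    noisy[:, BACKBONE_SLOTS] += noise
    return noisy


rng = np.random.default_rng(0)
coords = np.zeros((4, 37, 3))
same = augment_backbone(coords, 0.0, rng)
noisy = augment_backbone(coords, 0.1, rng)
print(same is coords)  # True
print(noisy.shape)  # (4, 37, 3)
```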
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- class prxteinmpnn.utils.data_structures.SamplingInputs(structure_coordinates=<factory>, sequence=<factory>, mask=<factory>, residue_index=<factory>, chain_index=<factory>, lengths=<factory>, bias=<factory>, k_neighbors=48, augment_eps=0.0, prng_key=<factory>, hyperparameters=(0.0,), iterations=1)
Bases:
ModelInputs
Dataclass for inputs used in sequence sampling.
- Parameters:
structure_coordinates (Float[Array, 'num_residues num_atoms 3'])
sequence (Int[Array, 'num_residues'])
mask (Int[Array, 'num_residues 3'])
residue_index (Int[Array, 'num_residues'])
chain_index (Int[Array, 'num_residues'])
lengths (Int[Array, 'num_sequences'])
bias (Float[Array, 'num_residues num_classes'] | None)
k_neighbors (int)
augment_eps (float)
prng_key (Key[Array, ''] | UInt32[Array, '2'])
hyperparameters (tuple[float | int | Array | GradientTransformation, ...])
iterations (int)
- prng_key
Random key for JAX operations.
- Type:
PRNGKeyArray
- sequence
Initial sequence of amino acids as an array of integers.
- Type:
ProteinSequence
- structure_coordinates
Atomic coordinates of the structure.
- Type:
StructureAtomicCoordinates
- mask
Mask indicating valid atoms in the structure.
- Type:
AtomMask
- residue_index
Index of residues in the structure.
- Type:
AtomResidueIndex
- chain_index
Index of chains in the structure.
- Type:
AtomChainIndex
- bias
Bias for the model input, used for classification tasks.
- Type:
InputBias | None
- hyperparameters
Hyperparameters for sampling, such as temperature or top-k.
- Type:
SamplingHyperparameters
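The sketch below illustrates how a per-residue `bias` and a temperature hyperparameter could combine during sampling. It assumes, without confirmation from this page, that the temperature is the first entry of `hyperparameters` and that a temperature of 0.0 (the default `(0.0,)`) means greedy argmax decoding. `sample_sequence` is hypothetical, and a NumPy `Generator` stands in for the JAX `prng_key`; a JAX version would typically draw from `jax.random.categorical` instead.

```python
import numpy as np


def sample_sequence(logits, bias, temperature, rng):
    """Sample one class per residue from softmax((logits + bias) / temperature).

    temperature == 0.0 is treated as greedy argmax decoding (an assumption).
    """
    scores = logits + bias
    if temperature == 0.0:
        return scores.argmax(axis=-1)
    scaled = scores / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(probs.shape[-1], p=p) for p in probs])


rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 20))
bias = np.zeros((5, 20))
bias[:, 0] = 100.0  # strongly favor class 0 at every position
greedy = sample_sequence(logits, bias, 0.0, rng)
sampled = sample_sequence(logits, bias, 1.0, rng)
print(greedy)  # [0 0 0 0 0]
```

A large positive bias on one class effectively fixes that residue type at every position, which is one common way a bias term is used to constrain sequence design.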
- replace(**updates)
Returns a new object replacing the specified fields with new values.