data_structures

Dataclasses and enums for the PrxteinMPNN project.

prxteinmpnn.utils.data_structures

class prxteinmpnn.utils.data_structures.ProteinStructure(coordinates, aatype, atom_mask, residue_index, chain_index, dihedrals=None)

Bases: object

Protein structure representation.

Parameters:
  • coordinates (StructureAtomicCoordinates)

  • aatype (ProteinSequence)

  • atom_mask (AtomMask)

  • residue_index (ResidueIndex)

  • chain_index (ChainIndex)

  • dihedrals (BackboneDihedrals | None)

coordinates

Atom positions in the structure, given as Cartesian coordinates in angstroms. The atom types correspond to residue_constants.atom_types, i.e. the first three are N, CA, C. Shape is (num_res, num_atom_type, 3), where num_res is the number of residues, num_atom_type is the number of atom types (e.g., N, CA, C, CB, O), and 3 is the spatial dimension (x, y, z).

Type:

StructureAtomicCoordinates
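To make the layout concrete, the sketch below slices the backbone atoms out of a coordinates array. It assumes the AlphaFold-style residue_constants.atom_types ordering (N, CA, C, CB, O, ...) with 37 atom types and uses NumPy in place of jax.numpy. The helper backbone_atoms is hypothetical, not part of prxteinmpnn, and the coordinates are random placeholder data.

```python
import numpy as np

# Assumed atom37 ordering (AlphaFold residue_constants.atom_types): N, CA, C, CB, O, ...
N_IDX, CA_IDX, C_IDX = 0, 1, 2
NUM_ATOM_TYPES = 37


def backbone_atoms(coordinates: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return N, CA, C positions with shape (num_res, 3, 3)."""
    if coordinates.ndim != 3 or coordinates.shape[-1] != 3:
        raise ValueError("expected shape (num_res, num_atom_type, 3)")
    return coordinates[:, [N_IDX, CA_IDX, C_IDX], :]


rng = np.random.default_rng(0)
coordinates = rng.normal(size=(5, NUM_ATOM_TYPES, 3))  # 5 residues, placeholder values
backbone = backbone_atoms(coordinates)
print(backbone.shape)  # (5, 3, 3)
```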

aatype

Amino-acid type of each residue, represented as an integer from 0 to 20 inclusive, where 20 is ‘X’ (unknown). Shape is [num_res].

Type:

Sequence
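A quick way to enforce the documented value range and to separate unknown residues from known ones is sketched below. validate_aatype is a hypothetical helper (not a prxteinmpnn function). It relies only on the 0 to 20 range and on 20 meaning ‘X’, and makes no assumption about which letter the other indices map to.

```python
import numpy as np

UNKNOWN_AATYPE = 20  # index of 'X', per the description above


def validate_aatype(aatype: np.ndarray) -> np.ndarray:
    """Hypothetical check: raise if a value is outside [0, 20]; return a mask of known residues."""
    if aatype.ndim != 1:
        raise ValueError("aatype must have shape [num_res]")
    if aatype.min() < 0 or aatype.max() > UNKNOWN_AATYPE:
        raise ValueError("aatype values must lie in [0, 20]")
    return aatype != UNKNOWN_AATYPE


aatype = np.array([0, 7, 20, 19, 3])
known = validate_aatype(aatype)
print(known.tolist())  # [True, True, False, True, True]
```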

atom_mask

Binary mask indicating whether each atom is present: 1 if the atom is present and 0 if not. Use it for loss masking, so that missing atoms do not contribute to the loss. Shape is [num_res, num_atom_type].

Type:

AtomMask
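As an illustration of loss masking, the sketch below computes a mean squared coordinate error that ignores absent atoms. The loss function is a hypothetical example, not one defined by prxteinmpnn, and NumPy stands in for jax.numpy.

```python
import numpy as np


def masked_mean_squared_error(pred, target, atom_mask):
    """Mean squared coordinate error over present atoms only.

    pred, target: (num_res, num_atom_type, 3)
    atom_mask:    (num_res, num_atom_type), 1 = present, 0 = absent
    """
    mask = atom_mask.astype(pred.dtype)
    per_atom = np.sum((pred - target) ** 2, axis=-1)  # squared distance per atom
    # Absent atoms add 0 to the numerator and are excluded from the count.
    return np.sum(per_atom * mask) / np.maximum(np.sum(mask), 1.0)


target = np.zeros((2, 3, 3))
pred = np.zeros((2, 3, 3))
pred[0, 0] = [1.0, 0.0, 0.0]  # present atom: squared error 1
pred[1, 2] = [9.0, 9.0, 9.0]  # absent atom: ignored by the mask
atom_mask = np.array([[1, 1, 1], [1, 1, 0]])
print(masked_mean_squared_error(pred, target, atom_mask))  # 0.2
```

Dividing by the number of present atoms, rather than by the full array size, keeps the loss comparable between structures with different amounts of missing density.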

residue_index

Residue index as used in PDB. It is not necessarily continuous or 0-indexed. Shape is [num_res].

Type:

AtomResidueIndex
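Because the numbering comes from the PDB file, residue_index can skip values where residues are missing. The sketch below locates such gaps. find_gaps is a hypothetical helper, not a prxteinmpnn function. In a multi-chain structure, a jump at a chain boundary would also be reported, so it is best combined with chain_index.

```python
import numpy as np


def find_gaps(residue_index: np.ndarray) -> np.ndarray:
    """Return positions i where residue_index[i + 1] != residue_index[i] + 1."""
    return np.nonzero(np.diff(residue_index) != 1)[0]


residue_index = np.array([10, 11, 12, 15, 16])  # starts at 10 and skips 13-14
print(find_gaps(residue_index).tolist())  # [2]
```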

coordinates: Float[Array, 'num_residues num_atoms 3']
aatype: Int[Array, 'num_residues']
atom_mask: Int[Array, 'num_residues num_atoms']
residue_index: Int[Array, 'num_residues']
chain_index: Int[Array, 'num_residues']
dihedrals: Float[Array, 'num_residues 3'] | None = None
replace(**updates)

Returns a new object replacing the specified fields with new values.
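replace follows the immutable-update pattern, so fields are never modified in place. The sketch below reproduces that behavior with a hypothetical two-field ToyStructure built on the standard library's dataclasses.replace. The real ProteinStructure may implement replace differently.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyStructure:
    """Hypothetical two-field stand-in for ProteinStructure."""

    aatype: np.ndarray
    residue_index: np.ndarray

    def replace(self, **updates):
        # Build a new instance; the original object is left untouched.
        return dataclasses.replace(self, **updates)


original = ToyStructure(aatype=np.array([0, 1, 2]), residue_index=np.array([1, 2, 3]))
shifted = original.replace(residue_index=original.residue_index - 1)
print(original.residue_index.tolist(), shifted.residue_index.tolist())  # [1, 2, 3] [0, 1, 2]
```

Fields that are not passed to replace are carried over by reference, so the new object shares the unchanged aatype array with the original.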

class prxteinmpnn.utils.data_structures.ModelInputs(structure_coordinates=<factory>, sequence=<factory>, mask=<factory>, residue_index=<factory>, chain_index=<factory>, lengths=<factory>, bias=<factory>, k_neighbors=48, augment_eps=0.0)

Bases: object

Dataclass for general model inputs.

Note that multiple instances can be stacked together, field by field, to form a batch of inputs.
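Stacking instances into a batch can be sketched as follows. Each array field is stacked across instances, which adds a leading batch axis. ToyInputs and stack_inputs are hypothetical stand-ins with only two of the ModelInputs fields. np.stack requires equal shapes, so inputs of different lengths would need padding first. With JAX, the same operation is commonly written with jax.tree_util.tree_map over pytree-registered dataclasses.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyInputs:
    """Hypothetical stand-in holding two of the ModelInputs fields."""

    sequence: np.ndarray       # (num_residues,)
    residue_index: np.ndarray  # (num_residues,)


def stack_inputs(items):
    """Stack every field across instances, adding a leading batch axis."""
    return ToyInputs(**{
        field.name: np.stack([getattr(item, field.name) for item in items])
        for field in dataclasses.fields(ToyInputs)
    })


a = ToyInputs(sequence=np.array([0, 1, 2]), residue_index=np.array([0, 1, 2]))
b = ToyInputs(sequence=np.array([3, 4, 5]), residue_index=np.array([0, 1, 2]))
batch = stack_inputs([a, b])
print(batch.sequence.shape)  # (2, 3)
```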

Parameters:
  • structure_coordinates (Float[Array, 'num_residues num_atoms 3'])

  • sequence (Int[Array, 'num_residues'])

  • mask (Int[Array, 'num_residues 3'])

  • residue_index (Int[Array, 'num_residues'])

  • chain_index (Int[Array, 'num_residues'])

  • lengths (Int[Array, 'num_sequences'])

  • bias (Float[Array, 'num_residues num_classes'] | None)

  • k_neighbors (int)

  • augment_eps (float)

structure_coordinates

Atomic coordinates of the structure. Shape is (num_residues, num_atoms, 3), where num_residues is the number of residues and num_atoms is the number of atom slots per residue (e.g., 37, one slot for each heavy-atom type found across the standard residues).

Type:

StructureAtomicCoordinates

sequence

Sequence of amino acids as an array of integers. Shape is (num_residues,), where num_residues is the number of residues.

Type:

Sequence

mask

Mask for the model input, indicating valid atoms in the structure. Shape is (num_residues, num_atoms), where num_atoms is the number of atoms per residue.

Type:

AtomMask

residue_index

Index of residues in the structure, used for mapping atoms in structures to their residues. Shape is (num_residues,).

Type:

ResidueIndex

chain_index

Index of chains in the structure, used for mapping atoms in structures to their chains. Shape is (num_residues,).

Type:

ChainIndex
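Per-chain bookkeeping follows directly from chain_index, as the short sketch below shows. It assumes chains are labelled with small non-negative integers, and the values are illustrative only.

```python
import numpy as np

chain_index = np.array([0, 0, 0, 1, 1, 2])  # three chains of lengths 3, 2, 1

# Residue count per chain
chain_ids, residues_per_chain = np.unique(chain_index, return_counts=True)
print(chain_ids.tolist(), residues_per_chain.tolist())  # [0, 1, 2] [3, 2, 1]

# Boolean mask selecting the residues of chain 1
in_chain_1 = chain_index == 1
print(np.nonzero(in_chain_1)[0].tolist())  # [3, 4]
```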

lengths

Lengths of the sequences in the batch, used for padding and batching. Shape is (num_sequences,), where num_sequences is the number of sequences in the batch.

Type:

InputLengths

bias

Bias for the model input, used for classification tasks. Shape is (sum(lengths), num_classes), where num_classes is the number of classes (e.g., 20 for amino acids). Defaults to a zero bias of shape (sum(lengths), 20).

Type:

InputBias
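The documented default can be reproduced as below. All residues of the batch are concatenated along the first axis, which is why the row count is sum(lengths). default_bias is a hypothetical helper, not the factory prxteinmpnn actually uses.

```python
import numpy as np

NUM_AMINO_ACIDS = 20


def default_bias(lengths: np.ndarray) -> np.ndarray:
    """Zero bias of shape (sum(lengths), 20), matching the documented default."""
    return np.zeros((int(np.sum(lengths)), NUM_AMINO_ACIDS), dtype=np.float32)


lengths = np.array([3, 5])  # two sequences of 3 and 5 residues
bias = default_bias(lengths)
print(bias.shape)  # (8, 20)
```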

structure_coordinates: Float[Array, 'num_residues num_atoms 3']

Structure atomic coordinates for the model input.

sequence: Int[Array, 'num_residues']

Sequence of amino acids for the model input, encoded as an array of integers in the MPNN alphabet.

mask: Int[Array, 'num_residues 3']

Mask for the model input, indicating valid atoms in the structure.

residue_index: Int[Array, 'num_residues']

Index of residues in the structure, used for mapping atoms in structures to their residues.

chain_index: Int[Array, 'num_residues']

Index of chains in the structure, used for mapping atoms in structures to their chains.

lengths: Int[Array, 'num_sequences']

Lengths of the sequences in the batch, used for padding and batching.

bias: Float[Array, 'num_residues num_classes'] | None

Bias for the model input, used for classification tasks. Defaults to a zero bias of shape (sum(lengths), 20).

k_neighbors: int = 48

Number of nearest neighbors to consider for each residue in the structure.
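One common way to build such a neighbor list, used in the original ProteinMPNN, is to rank residues by CA–CA distance and keep the k closest. The sketch below assumes that construction. Whether prxteinmpnn builds its graph exactly this way (e.g., with masking or virtual atoms) is not confirmed here, and k_nearest_neighbors is a hypothetical helper.

```python
import numpy as np


def k_nearest_neighbors(ca_coords: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k nearest residues (by CA distance) for each residue, self included.

    ca_coords: (num_residues, 3). Returns shape (num_residues, min(k, num_residues)).
    """
    diff = ca_coords[:, None, :] - ca_coords[None, :, :]
    dist = np.sqrt(np.sum(diff ** 2, axis=-1))  # (N, N) pairwise distances
    k = min(k, ca_coords.shape[0])  # cannot return more neighbors than residues
    return np.argsort(dist, axis=-1, kind="stable")[:, :k]


ca = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0], [6.0, 0.0, 0.0]])
print(k_nearest_neighbors(ca, k=2).tolist())  # [[0, 1], [1, 0], [2, 3], [3, 2]]
```

Capping k at the number of residues lets the default of 48 work for short chains without raising an error.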

augment_eps: float = 0.0

Scale of the random noise added to the backbone coordinates for data augmentation. The default of 0.0 adds no noise.
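A minimal sketch of this kind of augmentation follows. It assumes the ProteinMPNN convention of adding zero-mean Gaussian noise with standard deviation augment_eps to every coordinate. augment_backbone is hypothetical, and NumPy's Generator stands in for jax.random.normal with an explicit key.

```python
import numpy as np


def augment_backbone(coords: np.ndarray, augment_eps: float, rng: np.random.Generator) -> np.ndarray:
    """Add Gaussian noise with standard deviation augment_eps to every coordinate."""
    if augment_eps == 0.0:
        return coords  # the default: no augmentation
    return coords + augment_eps * rng.standard_normal(coords.shape)


rng = np.random.default_rng(0)
coords = np.zeros((4, 4, 3))  # 4 residues x 4 backbone atoms x (x, y, z)
noisy = augment_backbone(coords, augment_eps=0.1, rng=rng)
print(np.array_equal(augment_backbone(coords, 0.0, rng), coords))  # True
```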

replace(**updates)

Returns a new object replacing the specified fields with new values.

class prxteinmpnn.utils.data_structures.SamplingInputs(structure_coordinates=<factory>, sequence=<factory>, mask=<factory>, residue_index=<factory>, chain_index=<factory>, lengths=<factory>, bias=<factory>, k_neighbors=48, augment_eps=0.0, prng_key=<factory>, hyperparameters=(0.0,), iterations=1)

Bases: ModelInputs

Dataclass for inputs used in sequence sampling.

Parameters:
  • structure_coordinates (Float[Array, 'num_residues num_atoms 3'])

  • sequence (Int[Array, 'num_residues'])

  • mask (Int[Array, 'num_residues 3'])

  • residue_index (Int[Array, 'num_residues'])

  • chain_index (Int[Array, 'num_residues'])

  • lengths (Int[Array, 'num_sequences'])

  • bias (Float[Array, 'num_residues num_classes'] | None)

  • k_neighbors (int)

  • augment_eps (float)

  • prng_key (Key[Array, ''] | UInt32[Array, '2'])

  • hyperparameters (tuple[float | int | Array | GradientTransformation, ...])

  • iterations (int)

prng_key

Random key for JAX operations.

Type:

PRNGKeyArray

sequence

Initial sequence of amino acids as an array of integers.

Type:

ProteinSequence

structure_coordinates

Atomic coordinates of the structure.

Type:

StructureAtomicCoordinates

mask

Mask indicating valid atoms in the structure.

Type:

AtomMask

residue_index

Index of residues in the structure.

Type:

AtomResidueIndex

chain_index

Index of chains in the structure.

Type:

AtomChainIndex

bias

Bias for the model input, used for classification tasks.

Type:

InputBias | None

k_neighbors

Number of nearest neighbors to consider for each residue.

Type:

int

augment_eps

Scale of the random noise added to the backbone coordinates for data augmentation.

Type:

float

hyperparameters

Hyperparameters for sampling, e.g., temperature or top-k.

Type:

SamplingHyperparameters
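As one hypothetical reading of the default (0.0,), the sketch below treats hyperparameters[0] as a sampling temperature, with 0.0 falling back to greedy (argmax) decoding. The reading itself is an assumption. The type annotation also admits GradientTransformation objects, so other sampling strategies may interpret the tuple differently. sample_with_temperature is illustrative only.

```python
import numpy as np


def sample_with_temperature(logits: np.ndarray, temperature: float, rng: np.random.Generator) -> np.ndarray:
    """Sample one class per residue from softmax(logits / temperature).

    A temperature of 0.0 falls back to argmax (greedy) decoding.
    """
    if temperature == 0.0:
        return np.argmax(logits, axis=-1)
    scaled = logits / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(len(p), p=p) for p in probs])


hyperparameters = (0.0,)  # hypothetical reading: (temperature,)
logits = np.array([[2.0, 0.5, -1.0], [0.0, 3.0, 1.0]])  # 2 residues x 3 classes
rng = np.random.default_rng(0)
print(sample_with_temperature(logits, hyperparameters[0], rng).tolist())  # [0, 1]
```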

iterations

Number of iterations for sampling.

Type:

int

replace(**updates)

Returns a new object replacing the specified fields with new values.

prng_key: Union[Key[Array, ''], UInt32[Array, '2']]

Random key for JAX operations.

hyperparameters: tuple[float | int | Array | GradientTransformation, ...] = (0.0,)

Hyperparameters for sampling, e.g., temperature or top-k.

iterations: int = 1

Number of iterations for sampling.

class prxteinmpnn.utils.data_structures.OligomerType(value)

Bases: Enum

Enum for different types of oligomers.

MONOMER = 'monomer'
HETEROMER = 'heteromer'
HOMOOLIGOMER = 'homooligomer'
TIED_HOMOOLIGOMER = 'tied_homooligomer'
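Because the members carry string values, a member can be recovered from a plain string, such as one read from a configuration file. To stay self-contained, the sketch below re-declares the enum locally with the members listed above. In real code it would be imported from prxteinmpnn.utils.data_structures.

```python
from enum import Enum


class OligomerType(Enum):
    """Local replica of the members listed above, for illustration only."""

    MONOMER = "monomer"
    HETEROMER = "heteromer"
    HOMOOLIGOMER = "homooligomer"
    TIED_HOMOOLIGOMER = "tied_homooligomer"


# Look up a member by its string value.
kind = OligomerType("tied_homooligomer")
print(kind is OligomerType.TIED_HOMOOLIGOMER)  # True
print(kind.name, kind.value)  # TIED_HOMOOLIGOMER tied_homooligomer

# Values that match no member raise ValueError.
try:
    OligomerType("trimer")
except ValueError as err:
    print("rejected:", err)
```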