sampling module
Sampling utilities for PrXteinMPNN.
- prxteinmpnn.sampling.make_conditional_logits_fn(model)
Create a function to compute conditional logits for a given sequence.
Conditional logits evaluate how well a sequence fits a structure by running the model with the sequence as input.
- Parameters:
model (PrxteinMPNN) – A PrxteinMPNN Equinox model instance.
- Return type:
Callable[[Union[Key[Array, ''], UInt32[Array, '2']], Float[Array, 'num_residues num_atoms 3'], Int[Array, 'num_residues 3'], Int[Array, 'num_residues'], Int[Array, 'num_residues'], Int[Array, 'num_residues'], Bool[Array, 'num_residues num_residues'] | None, Float[Array, 'n'] | None], Float[Array, 'num_residues num_classes']]
- Returns:
A function that computes conditional logits for sequence-structure pairs.
Example
>>> from prxteinmpnn.io.weights import load_model
>>> model = load_model()
>>> logits_fn = make_conditional_logits_fn(model)
>>> logits = logits_fn(key, coords, mask, res_idx, chain_idx, sequence)
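One way to turn conditional logits into a single "fit" score is the mean per-residue log-probability of the sequence. The sketch below is illustrative only: it uses synthetic logits in place of the output of `logits_fn`, assumes 21 classes, and the helper name `sequence_log_likelihood` is hypothetical rather than part of PrxteinMPNN.

```python
import jax
import jax.numpy as jnp


def sequence_log_likelihood(logits, sequence, mask=None):
    """Mean per-residue log-probability of `sequence` under `logits`.

    Hypothetical helper, not part of the library.
    logits: (num_residues, num_classes), sequence: (num_residues,) integer codes.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log-probability of the observed residue at each position.
    per_residue = jnp.take_along_axis(log_probs, sequence[:, None], axis=-1)[:, 0]
    if mask is None:
        mask = jnp.ones_like(per_residue)
    mask = mask.astype(per_residue.dtype)
    return jnp.sum(per_residue * mask) / jnp.maximum(jnp.sum(mask), 1.0)


# Synthetic stand-in for conditional logits: 4 residues, 21 classes (assumed).
preferred = jnp.array([3, 7, 7, 0])
logits = jnp.zeros((4, 21)).at[jnp.arange(4), preferred].set(5.0)

good = sequence_log_likelihood(logits, preferred)
bad = sequence_log_likelihood(logits, jnp.array([1, 1, 1, 1]))
```

A sequence that matches the high-logit classes scores closer to zero than one that does not, which is the sense in which the logits measure fit.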
- prxteinmpnn.sampling.make_encoding_conditional_logits_split_fn(model)
Create separate encoding and decoding functions for averaged encodings.
This splits the model into two parts:
1. Encoding: Structure -> Encoder features (node_features, edge_features, neighbor_indices)
2. Decoding: (Encoder features, Sequence) -> Logits
This separation allows:
- Averaging encoder features across multiple noise levels
- Efficient Jacobian computation by caching encoder output
- Reusing encoder output for multiple sequence evaluations
- Parameters:
model (PrxteinMPNN) – A PrxteinMPNN Equinox model instance.
- Returns:
Tuple of (encode_fn, decode_fn), where:
- encode_fn: Computes encoder features from structure.
- decode_fn: Computes logits from cached features and sequence.
Example
>>> encode_fn, decode_fn = make_encoding_conditional_logits_split_fn(model)
>>> # Encode once
>>> key = jax.random.key(0)
>>> encoding = encode_fn(key, coords, mask, res_idx, chain_idx, noise=0.1)
>>> # Decode multiple sequences using same encoding
>>> logits1 = decode_fn(encoding, sequence1)
>>> logits2 = decode_fn(encoding, sequence2)
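The encode-once, decode-many pattern can be sketched with toy stand-ins. Everything below is an assumption for illustration: `toy_encode` and `toy_decode` are hypothetical placeholders, not the functions returned by the library, and whether the real `decode_fn` can be passed to `jax.vmap` in this way is not confirmed here.

```python
import jax
import jax.numpy as jnp

NUM_CLASSES = 21  # assumed alphabet size, for illustration only


def toy_encode(coords):
    """Hypothetical stand-in for encode_fn: one feature vector per residue."""
    return coords.reshape(coords.shape[0], -1)  # (num_residues, num_atoms * 3)


def toy_decode(encoding, sequence, weights):
    """Hypothetical stand-in for decode_fn: logits from features and a sequence."""
    one_hot = jax.nn.one_hot(sequence, NUM_CLASSES)  # (num_residues, NUM_CLASSES)
    return encoding @ weights + one_hot


key_coords, key_weights = jax.random.split(jax.random.PRNGKey(0))
coords = jax.random.normal(key_coords, (5, 4, 3))
weights = jax.random.normal(key_weights, (12, NUM_CLASSES))

# The (comparatively expensive) encoding is computed once ...
encoding = toy_encode(coords)

# ... and reused for every candidate sequence by mapping over the sequence axis only.
sequences = jnp.array([[0, 1, 2, 3, 4], [4, 3, 2, 1, 0], [7, 7, 7, 7, 7]])
batched_decode = jax.vmap(toy_decode, in_axes=(None, 0, None))
all_logits = batched_decode(encoding, sequences, weights)  # (3, 5, NUM_CLASSES)
```

Setting `in_axes=(None, 0, None)` keeps the encoding and weights shared across the batch, so only the sequence-dependent part of the computation is repeated.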
- prxteinmpnn.sampling.make_encoding_sampling_split_fn(model_parameters, decoding_order_fn=None, sampling_strategy='temperature')
Create separate encoding and sampling functions for averaged encodings.
This splits the sampling process into two parts:
1. Encoding: Structure -> Encoder features (can be averaged across noise levels)
2. Sampling: (Encoder features, PRNGKey) -> Sequences
This separation allows:
- Averaging encoder features across multiple noise levels
- Efficient reuse of encoder output for multiple samples
- Lower memory usage when sampling many sequences
Supports tied positions: when tie_group_map is provided, positions in the same group will be sampled together and receive identical amino acids.
- Parameters:
model_parameters (PrxteinMPNN) – A PrxteinMPNN Equinox model instance.
decoding_order_fn (Callable[[Union[Key[Array, ''], UInt32[Array, '2']], int, Array | None, int | None], tuple[Int[Array, 'num_residues'], Union[Key[Array, ''], UInt32[Array, '2']]]] | None) – Function to generate decoding order (default: random).
sampling_strategy (Literal['temperature', 'straight_through']) – Sampling strategy: “temperature” or “straight_through”.
- Returns:
Tuple of (encode_fn, sample_fn), where:
- encode_fn: Computes encoder features from structure.
- sample_fn: Samples sequences from cached encoder features.
Example
>>> from prxteinmpnn.io.weights import load_model
>>> model = load_model()
>>> encode_fn, sample_fn = make_encoding_sampling_split_fn(
...     model, sampling_strategy="temperature"
... )
>>> # Encode once
>>> encoding = encode_fn(
...     key, coords, mask, res_idx, chain_idx,
...     k_neighbors=48, backbone_noise=0.1
... )
>>> # Sample multiple sequences using same encoding
>>> seq1 = sample_fn(key1, encoding, order1, temperature=0.1)
>>> seq2 = sample_fn(key2, encoding, order2, temperature=0.5)
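The tied-position behaviour described above (positions in one group receive identical amino acids) can be illustrated independently of the model. The sketch below is not the library's implementation: `sample_tied` is a hypothetical helper, and pooling logits within a group by summation is an assumption (averaging would be another reasonable choice).

```python
import jax
import jax.numpy as jnp


def sample_tied(key, logits, tie_group_map, num_groups, temperature=1.0):
    """Sample one class per tie group and broadcast it to every member position.

    Hypothetical helper. logits: (num_residues, num_classes);
    tie_group_map: (num_residues,) integer group id per position.
    """
    # Pool logits within each group (assumed rule: sum over member positions).
    group_logits = jax.ops.segment_sum(logits, tie_group_map, num_segments=num_groups)
    # Temperature sampling: lower temperature concentrates mass on the top class.
    group_choice = jax.random.categorical(key, group_logits / temperature, axis=-1)
    # Every position looks up the choice made for its group.
    return group_choice[tie_group_map]


logits_key, sample_key = jax.random.split(jax.random.PRNGKey(0))
logits = jax.random.normal(logits_key, (5, 21))  # synthetic logits, 21 classes assumed
tie_map = jnp.array([0, 0, 1, 1, 2])             # positions 0-1 tied, 2-3 tied, 4 alone
tied_sequence = sample_tied(sample_key, logits, tie_map, num_groups=3, temperature=0.1)
```

Because the draw happens once per group rather than once per position, tied positions cannot disagree, whatever the temperature.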
- prxteinmpnn.sampling.make_sample_sequences(model, decoding_order_fn=<PjitFunction of <function random_decoding_order>>, sampling_strategy='temperature', _num_encoder_layers=3, _num_decoder_layers=3)
Create a function to sample sequences from a structure using PrxteinMPNN.
- Parameters:
model (PrxteinMPNN) – A PrxteinMPNN Equinox model instance.
decoding_order_fn (Callable[[Union[Key[Array, ''], UInt32[Array, '2']], int, Array | None, int | None], tuple[Int[Array, 'num_residues'], Union[Key[Array, ''], UInt32[Array, '2']]]]) – Function to generate decoding order (default: random). Should accept (key, num_residues, tie_group_map, num_groups).
sampling_strategy (Literal['temperature', 'straight_through']) – Sampling strategy: “temperature” or “straight_through”.
_num_encoder_layers (int) – Deprecated, ignored (kept for API compatibility).
_num_decoder_layers (int) – Deprecated, ignored (kept for API compatibility).
- Return type:
Callable[..., tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]
- Returns:
A function that samples sequences from structures.
Example
>>> from prxteinmpnn.io.weights import load_model
>>> model = load_model()
>>> sample_fn = make_sample_sequences(model, sampling_strategy="temperature")
>>> seq, logits, order = sample_fn(key, coords, mask, res_idx, chain_idx)
>>>
>>> # With tied positions
>>> tie_map = jnp.array([0, 0, 1, 1, 2])  # Positions 0-1 tied, 2-3 tied
>>> seq, logits, order = sample_fn(
...     key, coords, mask, res_idx, chain_idx,
...     tie_group_map=tie_map, num_groups=3
... )
>>>
>>> # For optimization
>>> optimize_fn = make_sample_sequences(model, sampling_strategy="straight_through")
>>> seq, logits, order = optimize_fn(
...     key, coords, mask, res_idx, chain_idx,
...     iterations=100, learning_rate=0.01
... )
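A custom `decoding_order_fn` needs to accept (key, num_residues, tie_group_map, num_groups) and, per the documented type, return the order together with a key. The following is a minimal sketch of a function with that shape. The name `shuffled_decoding_order` is hypothetical, and the sketch ignores the tie arguments entirely; whether the library expects tie-aware ordering from a custom function is not stated here.

```python
import jax
import jax.numpy as jnp


def shuffled_decoding_order(key, num_residues, tie_group_map=None, num_groups=None):
    """Return a uniformly random decoding order and a fresh key.

    Hypothetical example. tie_group_map and num_groups are accepted to match
    the documented call signature but are ignored in this sketch.
    """
    next_key, perm_key = jax.random.split(key)
    # With an integer argument, permutation returns a shuffled arange(num_residues).
    order = jax.random.permutation(perm_key, num_residues)
    return order, next_key


order, next_key = shuffled_decoding_order(jax.random.PRNGKey(0), 6)
```

Returning a new key alongside the order lets the caller continue drawing random numbers without reusing the key consumed by the permutation.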
- prxteinmpnn.sampling.make_unconditional_logits_fn(model)
Create a function to compute unconditional logits from a structure.
Unconditional logits are computed without sequence input, predicting the most likely amino acids at each position based purely on structure.
- Parameters:
model (PrxteinMPNN) – A PrxteinMPNN Equinox model instance.
- Return type:
Callable[[Union[Key[Array, ''], UInt32[Array, '2']], Float[Array, 'num_residues num_atoms 3'], Int[Array, 'num_residues 3'], Int[Array, 'num_residues'], Int[Array, 'num_residues'], Bool[Array, 'num_residues num_residues'] | None, Float[Array, 'n'] | None], Float[Array, 'num_residues num_classes']]
- Returns:
A function that computes unconditional logits from structures.
Example
>>> from prxteinmpnn.io.weights import load_model
>>> model = load_model()
>>> logits_fn = make_unconditional_logits_fn(model)
>>> logits = logits_fn(key, coords, mask, res_idx, chain_idx)
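Given logits of shape (num_residues, num_classes), the most likely amino acid at each position and a per-position uncertainty can be read off with standard JAX operations. The sketch below uses small synthetic logits (3 classes, for brevity) rather than real model output, so the numbers carry no biological meaning.

```python
import jax
import jax.numpy as jnp

# Synthetic stand-in for unconditional logits: 2 residues x 3 classes.
logits = jnp.array([[2.0, 0.0, -1.0],
                    [0.0, 3.0, 0.0]])

# Per-position class probabilities.
probs = jax.nn.softmax(logits, axis=-1)

# Index of the highest-scoring class at each position.
most_likely = jnp.argmax(logits, axis=-1)

# Shannon entropy per position (in nats); lower means a more confident prediction.
entropy = -jnp.sum(probs * jnp.log(probs), axis=-1)
```

Mapping the integer indices in `most_likely` back to amino-acid letters depends on the model's alphabet ordering, which is not specified in this section.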
- prxteinmpnn.sampling.sample(prng_key, model, structure_coordinates, mask, residue_index, chain_index, **kwargs)
Sample sequences from a structure using the default temperature sampler.
This is a convenience wrapper around make_sample_sequences.
- Parameters:
prng_key (Array) – JAX random key.
model (PrxteinMPNN) – A PrxteinMPNN Equinox model instance.
structure_coordinates (Array) – Atomic coordinates (N, 4, 3).
mask (Array) – Alpha carbon mask indicating valid residues.
residue_index (Array) – Residue indices.
chain_index (Array) – Chain indices.
**kwargs (Any) – Additional keyword arguments for the sampler.
- Return type:
- Returns:
Tuple of (sampled sequence, logits, decoding order).