sampling module

Sampling utilities for PrXteinMPNN.

prxteinmpnn.sampling.make_conditional_logits_fn(model)

Create a function to compute conditional logits for a given sequence.

Conditional logits measure how well a sequence fits a structure; they are computed by running the model with the sequence as input.

Parameters:

model (PrxteinMPNN) – A PrxteinMPNN Equinox model instance.

Return type:

Callable[[Union[Key[Array, ''], UInt32[Array, '2']], Float[Array, 'num_residues num_atoms 3'], Int[Array, 'num_residues 3'], Int[Array, 'num_residues'], Int[Array, 'num_residues'], Int[Array, 'num_residues'], Bool[Array, 'num_residues num_residues'] | None, Float[Array, 'n'] | None], Float[Array, 'num_residues num_classes']]

Returns:

A function that computes conditional logits for sequence-structure pairs.

Example

>>> from prxteinmpnn.io.weights import load_model
>>> model = load_model()
>>> logits_fn = make_conditional_logits_fn(model)
>>> logits = logits_fn(key, coords, mask, res_idx, chain_idx, sequence)
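
One common way to turn conditional logits into a single fit score is the mean log-probability of the observed residues. The sketch below is illustrative only: `sequence_log_likelihood` is a hypothetical helper, not part of prxteinmpnn, and the stand-in `logits` array (with an assumed 21 classes) takes the place of a real `logits_fn` output.

```python
import jax
import jax.numpy as jnp


def sequence_log_likelihood(logits, sequence, mask):
    """Mean per-residue log-probability of `sequence` under `logits`.

    Hypothetical helper for illustration; not a prxteinmpnn function.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)  # (num_residues, num_classes)
    # Pick the log-probability assigned to the observed residue at each position.
    picked = jnp.take_along_axis(log_probs, sequence[:, None], axis=-1)[:, 0]
    mask = mask.astype(log_probs.dtype)
    # Average over valid (unmasked) positions only.
    return jnp.sum(picked * mask) / jnp.maximum(jnp.sum(mask), 1.0)


# Stand-in data: uniform logits over an assumed 21 classes.
logits = jnp.zeros((5, 21))
sequence = jnp.array([0, 3, 7, 20, 1])
mask = jnp.array([1, 1, 1, 1, 0])
score = sequence_log_likelihood(logits, sequence, mask)
```

Higher (less negative) scores indicate a sequence the model considers more compatible with the structure.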
prxteinmpnn.sampling.make_encoding_conditional_logits_split_fn(model)

Create separate encoding and decoding functions for averaged encodings.

This splits the model into two parts:

  1. Encoding: Structure -> Encoder features (node_features, edge_features, neighbor_indices)

  2. Decoding: (Encoder features, Sequence) -> Logits

This separation allows:

  • Averaging encoder features across multiple noise levels

  • Efficient Jacobian computation by caching encoder output

  • Reusing encoder output for multiple sequence evaluations

Parameters:

model (PrxteinMPNN) – A PrxteinMPNN Equinox model instance.

Returns:

Tuple of (encode_fn, decode_fn) where:

  • encode_fn: Computes encoder features from structure

  • decode_fn: Computes logits from cached features and sequence

Example

>>> import jax
>>> from prxteinmpnn.io.weights import load_model
>>> model = load_model()
>>> encode_fn, decode_fn = make_encoding_conditional_logits_split_fn(model)
>>> # Encode once
>>> key = jax.random.key(0)
>>> encoding = encode_fn(key, coords, mask, res_idx, chain_idx, noise=0.1)
>>> # Decode multiple sequences using same encoding
>>> logits1 = decode_fn(encoding, sequence1)
>>> logits2 = decode_fn(encoding, sequence2)
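
Averaging encoder features across noise levels could look roughly like the sketch below. It is not the library's implementation: `toy_encode` is a hypothetical stand-in for `encode_fn` that returns only floating-point features, and `average_encodings` is an illustrative helper. Integer outputs such as neighbor indices cannot be meaningfully averaged and would need separate handling.

```python
import jax
import jax.numpy as jnp


def toy_encode(key, coords, noise):
    """Hypothetical stand-in for encode_fn: perturb coordinates, return float features."""
    noisy = coords + noise * jax.random.normal(key, coords.shape)
    node_features = noisy.mean(axis=1)  # (num_residues, 3)
    # Pairwise distances between per-residue centroids as toy edge features.
    diffs = node_features[:, None, :] - node_features[None, :, :]
    edge_features = jnp.linalg.norm(diffs, axis=-1)  # (num_residues, num_residues)
    return node_features, edge_features


def average_encodings(encodings):
    """Element-wise mean over a list of identically structured pytrees."""
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.mean(jnp.stack(leaves), axis=0), *encodings
    )


coords = jnp.ones((6, 4, 3))
noise_levels = [0.0, 0.1, 0.2]
keys = jax.random.split(jax.random.PRNGKey(0), len(noise_levels))
encodings = [toy_encode(k, coords, s) for k, s in zip(keys, noise_levels)]
avg_nodes, avg_edges = average_encodings(encodings)
```

The averaged features would then be passed to the decoding function once, instead of decoding separately for each noise level.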
prxteinmpnn.sampling.make_encoding_sampling_split_fn(model_parameters, decoding_order_fn=None, sampling_strategy='temperature')

Create separate encoding and sampling functions for averaged encodings.

This splits the sampling process into two parts:

  1. Encoding: Structure -> Encoder features (can be averaged across noise levels)

  2. Sampling: (Encoder features, PRNGKey) -> Sequences

This separation allows:

  • Averaging encoder features across multiple noise levels

  • Efficient reuse of encoder output for multiple samples

  • Lower memory usage when sampling many sequences

Supports tied positions: when tie_group_map is provided, positions in the same group will be sampled together and receive identical amino acids.
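
As a rough illustration of what tying means, the sketch below pools logits within each tie group, samples one residue per group, and broadcasts it back to every member position. This is an assumption about one plausible scheme, not the library's actual algorithm (which decodes autoregressively); `sample_tied` is a hypothetical helper and the 21-class logits are stand-in data.

```python
import jax
import jax.numpy as jnp


def sample_tied(key, logits, tie_group_map, num_groups, temperature=1.0):
    """Sample one residue per tie group and broadcast it to all group members.

    Hypothetical helper for illustration; not a prxteinmpnn function.
    """
    # Sum logits and count members per group, then average within each group.
    summed = jax.ops.segment_sum(logits, tie_group_map, num_segments=num_groups)
    counts = jax.ops.segment_sum(
        jnp.ones(logits.shape[0]), tie_group_map, num_segments=num_groups
    )
    group_logits = summed / jnp.maximum(counts, 1.0)[:, None]  # (num_groups, num_classes)
    # One categorical draw per group.
    group_tokens = jax.random.categorical(key, group_logits / temperature, axis=-1)
    # Map each position to the residue sampled for its group.
    return group_tokens[tie_group_map]  # (num_residues,)


logits = jax.random.normal(jax.random.PRNGKey(1), (5, 21))
tie_map = jnp.array([0, 0, 1, 1, 2])  # positions 0-1 tied, 2-3 tied, 4 alone
tied_seq = sample_tied(jax.random.PRNGKey(0), logits, tie_map, num_groups=3)
```

Positions that share a group index always receive the same residue, which is the property the tied-position support guarantees.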

Parameters:
  • model_parameters (PrxteinMPNN) – A PrxteinMPNN Equinox model instance.

  • decoding_order_fn (Callable[[Union[Key[Array, ''], UInt32[Array, '2']], int, Array | None, int | None], tuple[Int[Array, 'num_residues'], Union[Key[Array, ''], UInt32[Array, '2']]]] | None) – Function to generate decoding order (default: random).

  • sampling_strategy (Literal['temperature', 'straight_through']) – Sampling strategy - “temperature” or “straight_through”.

Returns:

Tuple of (encode_fn, sample_fn) where:

  • encode_fn: Computes encoder features from structure

  • sample_fn: Samples sequences from cached encoder features

Example

>>> from prxteinmpnn.io.weights import load_model
>>> model = load_model()
>>> encode_fn, sample_fn = make_encoding_sampling_split_fn(
...     model, sampling_strategy="temperature"
... )
>>> # Encode once
>>> encoding = encode_fn(
...     key, coords, mask, res_idx, chain_idx,
...     k_neighbors=48, backbone_noise=0.1
... )
>>> # Sample multiple sequences using same encoding
>>> seq1 = sample_fn(key1, encoding, order1, temperature=0.1)
>>> seq2 = sample_fn(key2, encoding, order2, temperature=0.5)
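
The effect of the `temperature` argument can be illustrated with a standalone sketch of temperature-scaled categorical sampling. This is a generic formulation and may differ from how the library applies temperature internally; `sample_with_temperature` is a hypothetical helper and the logits are toy values.

```python
import jax
import jax.numpy as jnp


def sample_with_temperature(key, logits, temperature):
    """Draw one class per row after dividing the logits by the temperature."""
    return jax.random.categorical(key, logits / temperature, axis=-1)


# Toy logits: class 0 is favored in row 0, class 1 in row 1.
logits = jnp.array([[4.0, 0.0, 0.0], [0.0, 4.0, 0.0]])
favored = jnp.array([0, 1])
keys = jax.random.split(jax.random.PRNGKey(0), 200)

# Draw 200 samples at a low and a high temperature.
cold = jax.vmap(lambda k: sample_with_temperature(k, logits, 0.1))(keys)
hot = jax.vmap(lambda k: sample_with_temperature(k, logits, 5.0))(keys)

# Fraction of draws that pick the favored class.
cold_match = jnp.mean(cold == favored)
hot_match = jnp.mean(hot == favored)
```

Low temperatures concentrate samples on the highest-scoring residue, while high temperatures spread them across more classes, so `cold_match` is larger than `hot_match` here.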
prxteinmpnn.sampling.make_sample_sequences(model, decoding_order_fn=random_decoding_order, sampling_strategy='temperature', _num_encoder_layers=3, _num_decoder_layers=3)

Create a function to sample sequences from a structure using PrxteinMPNN.

Parameters:
  • model (PrxteinMPNN) – A PrxteinMPNN Equinox model instance.

  • decoding_order_fn (Callable[[Union[Key[Array, ''], UInt32[Array, '2']], int, Array | None, int | None], tuple[Int[Array, 'num_residues'], Union[Key[Array, ''], UInt32[Array, '2']]]]) – Function to generate decoding order (default: random). Should accept (key, num_residues, tie_group_map, num_groups).

  • sampling_strategy (Literal['temperature', 'straight_through']) – Sampling strategy - “temperature” or “straight_through”.

  • _num_encoder_layers (int) – Deprecated, ignored (kept for API compatibility).

  • _num_decoder_layers (int) – Deprecated, ignored (kept for API compatibility).
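
A custom `decoding_order_fn` only needs to follow the contract described above: accept (key, num_residues, tie_group_map, num_groups) and return a decoding order together with a key. The sketch below is a hypothetical left-to-right order that ignores the tie arguments; how the library expects tied groups to be ordered is not covered here and is left as an assumption.

```python
import jax
import jax.numpy as jnp


def left_to_right_order(key, num_residues, tie_group_map=None, num_groups=None):
    """Decode residues in index order (0, 1, 2, ...).

    Hypothetical example; the tie arguments are accepted but ignored.
    """
    # Split so the caller receives a fresh key, mirroring the (order, key) contract.
    next_key, _ = jax.random.split(key)
    order = jnp.arange(num_residues)
    return order, next_key


order, next_key = left_to_right_order(jax.random.PRNGKey(0), 8)
```

Such a function could then be passed as `decoding_order_fn` in place of the default random order.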

Return type:

Callable[..., tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]

Returns:

A function that samples sequences from structures.

Example

>>> from prxteinmpnn.io.weights import load_model
>>> model = load_model()
>>> sample_fn = make_sample_sequences(model, sampling_strategy="temperature")
>>> seq, logits, order = sample_fn(key, coords, mask, res_idx, chain_idx)
>>>
>>> # With tied positions
>>> import jax.numpy as jnp
>>> tie_map = jnp.array([0, 0, 1, 1, 2])  # Positions 0-1 tied, 2-3 tied, 4 untied
>>> seq, logits, order = sample_fn(
...     key, coords, mask, res_idx, chain_idx,
...     tie_group_map=tie_map, num_groups=3
... )
>>>
>>> # For optimization
>>> optimize_fn = make_sample_sequences(model, sampling_strategy="straight_through")
>>> seq, logits, order = optimize_fn(
...     key, coords, mask, res_idx, chain_idx,
...     iterations=100, learning_rate=0.01
... )
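
For readers unfamiliar with the term, a straight-through estimator returns a discrete one-hot choice in the forward pass while letting gradients flow through the underlying softmax. The sketch below shows the generic trick only; whether the "straight_through" strategy is implemented exactly this way is an assumption, and `straight_through_one_hot` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def straight_through_one_hot(logits):
    """One-hot argmax in the forward pass, softmax gradient in the backward pass."""
    probs = jax.nn.softmax(logits, axis=-1)
    hard = jax.nn.one_hot(jnp.argmax(probs, axis=-1), logits.shape[-1], dtype=probs.dtype)
    # Forward value equals `hard`; the gradient comes from `probs` only.
    return probs + jax.lax.stop_gradient(hard - probs)


logits = jnp.array([[2.0, 0.5, -1.0], [0.0, 0.1, 3.0]])
one_hot = straight_through_one_hot(logits)

# Gradient of a toy objective that rewards class 1 in row 0 and class 2 in row 1.
target = jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
grads = jax.grad(lambda l: jnp.sum(straight_through_one_hot(l) * target))(logits)
```

Because the gradient is nonzero even though the output is discrete, the logits can be updated by gradient-based optimization.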
prxteinmpnn.sampling.make_unconditional_logits_fn(model)

Create a function to compute unconditional logits from a structure.

Unconditional logits are computed without sequence input, predicting the most likely amino acids at each position based purely on structure.

Parameters:

model (PrxteinMPNN) – A PrxteinMPNN Equinox model instance.

Return type:

Callable[[Union[Key[Array, ''], UInt32[Array, '2']], Float[Array, 'num_residues num_atoms 3'], Int[Array, 'num_residues 3'], Int[Array, 'num_residues'], Int[Array, 'num_residues'], Bool[Array, 'num_residues num_residues'] | None, Float[Array, 'n'] | None], Float[Array, 'num_residues num_classes']]

Returns:

A function that computes unconditional logits from structures.

Example

>>> from prxteinmpnn.io.weights import load_model
>>> model = load_model()
>>> logits_fn = make_unconditional_logits_fn(model)
>>> logits = logits_fn(key, coords, mask, res_idx, chain_idx)
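
Unconditional logits can be post-processed into a most-likely residue and a per-position uncertainty. The sketch below uses a small stand-in array with three classes rather than real model output, and `summarize_logits` is a hypothetical helper, not part of prxteinmpnn.

```python
import jax
import jax.numpy as jnp


def summarize_logits(logits):
    """Return the argmax class and the entropy (in nats) for each position."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    probs = jnp.exp(log_probs)
    most_likely = jnp.argmax(logits, axis=-1)
    entropy = -jnp.sum(probs * log_probs, axis=-1)
    return most_likely, entropy


# Stand-in logits: position 0 is confident, position 1 is uniform.
logits = jnp.array([[3.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
most_likely, entropy = summarize_logits(logits)
```

Positions with low entropy are ones where the structure strongly constrains the residue identity; high-entropy positions tolerate more variation.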
prxteinmpnn.sampling.sample(prng_key, model, structure_coordinates, mask, residue_index, chain_index, **kwargs)

Sample sequences from a structure using the default temperature sampler.

This is a convenience wrapper around make_sample_sequences.

Parameters:
  • prng_key (Array) – JAX random key.

  • model (PrxteinMPNN) – A PrxteinMPNN Equinox model instance.

  • structure_coordinates (Array) – Atomic coordinates with shape (N, 4, 3).

  • mask (Array) – Alpha carbon mask indicating valid residues.

  • residue_index (Array) – Residue indices.

  • chain_index (Array) – Chain indices.

  • **kwargs (Any) – Additional keyword arguments for the sampler.

Return type:

tuple[Array, Array, Array]

Returns:

Tuple of (sampled sequence, logits, decoding order).