sampling module

Sampling utilities for PrXteinMPNN.

class prxteinmpnn.sampling.SamplingConfig(sampling_strategy, iterations=1, temperature=1.0, target_logits=None, learning_rate=0.1)

Bases: object

Configuration for sequence sampling.

Parameters:
  • sampling_strategy (SamplingEnum)

  • iterations (int)

  • temperature (float)

  • target_logits (Float[Array, 'num_residues num_classes'] | None)

  • learning_rate (float)

iterations: int = 1
learning_rate: float = 0.1
replace(**updates)

Returns a new object replacing the specified fields with new values.

target_logits: Float[Array, 'num_residues num_classes'] | None = None
temperature: float = 1.0
sampling_strategy: SamplingEnum
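A minimal sketch of how a frozen configuration object with a replace method behaves, using only the standard library. SamplingConfigSketch and the trimmed local SamplingEnum are hypothetical mirrors of the documented classes, not imports from prxteinmpnn; target_logits is simplified to nested lists instead of a JAX array, and replace is assumed to behave like dataclasses.replace, as its docstring suggests.

```python
import dataclasses
from enum import Enum
from typing import Optional, Sequence


class SamplingEnum(Enum):
    # Hypothetical local mirror of a few documented enum values.
    GREEDY = "greedy"
    TEMPERATURE = "temperature"
    STRAIGHT_THROUGH = "straight_through"


@dataclasses.dataclass(frozen=True)
class SamplingConfigSketch:
    # Field names and defaults follow the documented SamplingConfig signature.
    sampling_strategy: SamplingEnum
    iterations: int = 1
    temperature: float = 1.0
    target_logits: Optional[Sequence[Sequence[float]]] = None
    learning_rate: float = 0.1

    def replace(self, **updates):
        # Returns a new object; the original stays unchanged (frozen).
        return dataclasses.replace(self, **updates)


base = SamplingConfigSketch(SamplingEnum.TEMPERATURE)
hotter = base.replace(temperature=2.0)
```

Because the object is frozen, changing a hyperparameter means deriving a new config with replace rather than mutating the existing one, so a config shared between samplers cannot be altered behind their backs.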
class prxteinmpnn.sampling.SamplingEnum(value)

Bases: Enum

Enum for different sampling strategies.

GREEDY = 'greedy'
TOP_K = 'top_k'
TOP_P = 'top_p'
TEMPERATURE = 'temperature'
STRAIGHT_THROUGH = 'straight_through'
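To illustrate what the GREEDY, TOP_K, TOP_P, and TEMPERATURE strategies conventionally mean, here is a minimal single-residue sketch in plain Python. It is not PrXteinMPNN's implementation: SamplingEnum is a local mirror, and sample_one, k, and p are hypothetical (the documented SamplingConfig has no top-k or top-p field). STRAIGHT_THROUGH is rejected here because it refines logits by optimization rather than making a single draw.

```python
import math
import random
from enum import Enum


class SamplingEnum(Enum):
    # Hypothetical local mirror of the documented enum values.
    GREEDY = "greedy"
    TOP_K = "top_k"
    TOP_P = "top_p"
    TEMPERATURE = "temperature"
    STRAIGHT_THROUGH = "straight_through"


def softmax(logits, temperature=1.0):
    # Temperature-scaled softmax, shifted by the max for numerical stability.
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_one(logits, strategy, rng, temperature=1.0, k=3, p=0.9):
    """Pick one class index from a single residue's logits."""
    if strategy is SamplingEnum.GREEDY:
        # Deterministic: the highest-scoring class.
        return max(range(len(logits)), key=logits.__getitem__)
    probs = softmax(logits, temperature)
    candidates = list(range(len(logits)))
    if strategy is SamplingEnum.TOP_K:
        # Keep only the k most probable classes.
        candidates = sorted(candidates, key=probs.__getitem__, reverse=True)[:k]
    elif strategy is SamplingEnum.TOP_P:
        # Keep the smallest prefix of ranked classes whose mass reaches p.
        ranked = sorted(candidates, key=probs.__getitem__, reverse=True)
        kept, cumulative = [], 0.0
        for i in ranked:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= p:
                break
        candidates = kept
    elif strategy is not SamplingEnum.TEMPERATURE:
        raise ValueError(f"unsupported strategy for a single draw: {strategy}")
    # random.choices treats weights as relative, so the kept subset is renormalized.
    weights = [probs[i] for i in candidates]
    return rng.choices(candidates, weights=weights, k=1)[0]


rng = random.Random(0)
choice = sample_one([0.1, 2.0, -1.0], SamplingEnum.TOP_K, rng, temperature=0.5, k=2)
```

Lower temperatures sharpen the distribution toward the greedy choice, while top-k and top-p trim unlikely classes before drawing.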
prxteinmpnn.sampling.make_sample_sequences(model_parameters, decoding_order_fn, config, num_encoder_layers=3, num_decoder_layers=3, model_inputs=None, sampling_inputs=None)

Create a function to sample sequences from a structure using ProteinMPNN.

Parameters:
  • model_parameters (ModelParameters) – Pre-trained ProteinMPNN model parameters.

  • decoding_order_fn (DecodingOrderFn) – Function to generate decoding order.

  • config (SamplingConfig) – Configuration for sampling, including strategy and hyperparameters.

  • num_encoder_layers (int) – Number of encoder layers. Default is 3.

  • num_decoder_layers (int) – Number of decoder layers. Default is 3.

  • model_inputs (ModelInputs | None) – Optional model inputs for sampling. If provided, the returned function requires prng_key, bias, k_neighbors, augment_eps, hyperparameters, and iterations.

  • sampling_inputs (SamplingInputs | None) – Optional sampling inputs for sequence sampling. If provided, the returned function takes no arguments; it reads everything it needs from the attributes of sampling_inputs.

  • optimizer (optax.GradientTransformation | None) – Optional optimizer for straight-through estimator. If provided, the sampling function will optimize the logits using this optimizer. If not provided, the sampling function will not perform optimization.

If both model_inputs and sampling_inputs are provided, sampling_inputs will be used.
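The STRAIGHT_THROUGH strategy, together with the target_logits, learning_rate, and iterations fields of SamplingConfig and the optional optimizer, points to gradient-based refinement of logits. The sketch below illustrates the generic straight-through estimator on a single residue, under stated assumptions: the objective (agreement between the hard argmax choice and softmax(target_logits)) is hypothetical, iterations is taken as the number of update steps and learning_rate as the step size, and plain gradient ascent stands in for an optax optimizer. It is not PrXteinMPNN's actual objective.

```python
import math


def softmax(z):
    # Numerically stable softmax over one residue's logits.
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def hard_one_hot(z):
    # Forward pass: a discrete argmax choice, which has no useful gradient.
    k = max(range(len(z)), key=z.__getitem__)
    return [1.0 if i == k else 0.0 for i in range(len(z))]


def straight_through_optimize(logits, target_logits, learning_rate=0.1, iterations=1):
    """Ascend score = <hard_one_hot(logits), softmax(target_logits)>.

    The backward pass pretends the forward pass was softmax(logits), so the
    gradient of the score with respect to logit j is p_j * (t_j - <p, t>).
    """
    t = softmax(target_logits)
    z = list(logits)
    for _ in range(iterations):
        p = softmax(z)
        expected = sum(pi * ti for pi, ti in zip(p, t))
        grad = [pi * (ti - expected) for pi, ti in zip(p, t)]
        # Plain gradient ascent stands in for an optax optimizer here.
        z = [zi + learning_rate * gi for zi, gi in zip(z, grad)]
    return z, hard_one_hot(z)


refined_logits, choice = straight_through_optimize(
    [2.0, 0.0, 0.0], [0.0, 0.0, 5.0], learning_rate=1.0, iterations=200
)
```

The point of the estimator is that the discrete choice stays hard in the forward pass while the softmax surrogate supplies a usable gradient, so repeated updates can move the argmax toward classes favored by the target.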

Return type:

One of:

  • Callable[[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_residues'], Float[Array, 'num_residues num_atoms 3'], Int[Array, 'num_residues num_atoms'], Int[Array, 'num_residues'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'] | None, int, float, tuple[float | int | Array | GradientTransformation, ...], int], tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]

  • Callable[[Union[Key[Array, ''], UInt32[Array, '2']], tuple[float | int | Array | GradientTransformation, ...], int], tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]

  • Callable[[], tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]

Returns:

A function that samples sequences given structural inputs.
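The three return shapes follow from which inputs are bound when the sampler is built. The sketch below illustrates that closure pattern and the documented precedence of sampling_inputs over model_inputs. make_sampler_sketch, HypotheticalSamplingInputs, and _sample_core are hypothetical placeholders, and the argument lists are simplified rather than matching the real signatures.

```python
from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional


@dataclass(frozen=True)
class HypotheticalSamplingInputs:
    # Hypothetical bundle of everything the sampler needs.
    prng_key: int
    structure: str
    hyperparameters: tuple
    iterations: int


def _sample_core(prng_key, structure, hyperparameters, iterations):
    # Placeholder for the real sampler: it just echoes its arguments.
    return (prng_key, structure, hyperparameters, iterations)


def make_sampler_sketch(
    model_inputs: Optional[str] = None,
    sampling_inputs: Optional[HypotheticalSamplingInputs] = None,
) -> Callable:
    if sampling_inputs is not None:
        # sampling_inputs wins: every argument is bound, so the caller passes none.
        return partial(
            _sample_core,
            sampling_inputs.prng_key,
            sampling_inputs.structure,
            sampling_inputs.hyperparameters,
            sampling_inputs.iterations,
        )
    if model_inputs is not None:
        # Structure is bound; the caller still supplies the remaining arguments.
        def sample_with_model_inputs(prng_key, hyperparameters, iterations):
            return _sample_core(prng_key, model_inputs, hyperparameters, iterations)

        return sample_with_model_inputs
    # Nothing bound: the caller passes the structural inputs as well.
    return _sample_core


bundle = HypotheticalSamplingInputs(0, "from_sampling_inputs", (1.0,), 2)
zero_arg_sampler = make_sampler_sketch(model_inputs="from_model_inputs", sampling_inputs=bundle)
```

Binding inputs up front trades flexibility for a simpler call site: the fully bound variant can be invoked repeatedly with no arguments, while the unbound variant lets one sampler serve many structures.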