sampling module
Sampling utilities for PrXteinMPNN.
- class prxteinmpnn.sampling.SamplingConfig(sampling_strategy, iterations=1, temperature=1.0, target_logits=None, learning_rate=0.1)
Bases: object
Configuration for sequence sampling.
- Parameters:
sampling_strategy (SamplingEnum)
iterations (int)
temperature (float)
target_logits (Float[Array, 'num_residues num_classes'] | None)
learning_rate (float)
- replace(**updates)
Returns a new object with the specified fields replaced by new values.
- sampling_strategy: SamplingEnum
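A minimal sketch of how replace(**updates) behaves, using a stand-in class rather than the library's own. The stand-in SamplingConfig and SamplingEnum below are hypothetical: they copy only the field names and defaults from the signature above, and they assume the config behaves like an immutable dataclass, which the source does not state.

```python
import dataclasses
from enum import Enum
from typing import Optional, Sequence


class SamplingEnum(Enum):
    # Stand-in with two of the documented members.
    GREEDY = "greedy"
    TEMPERATURE = "temperature"


@dataclasses.dataclass(frozen=True)
class SamplingConfig:
    # Hypothetical stand-in: fields and defaults copied from the documented signature.
    sampling_strategy: SamplingEnum
    iterations: int = 1
    temperature: float = 1.0
    target_logits: Optional[Sequence[Sequence[float]]] = None
    learning_rate: float = 0.1

    def replace(self, **updates):
        """Return a new config with the given fields replaced."""
        return dataclasses.replace(self, **updates)


base = SamplingConfig(sampling_strategy=SamplingEnum.TEMPERATURE)
# The original object is left untouched; only the copy carries the new values.
cooler = base.replace(temperature=0.5, iterations=4)
```

Deriving variants this way keeps one base configuration reusable across runs that differ only in a hyperparameter or two.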
- class prxteinmpnn.sampling.SamplingEnum(value)
Bases: Enum
Enum for different sampling strategies.
- GREEDY = 'greedy'
- TOP_K = 'top_k'
- TOP_P = 'top_p'
- TEMPERATURE = 'temperature'
- BEAM_SEARCH = 'beam_search'
- STRAIGHT_THROUGH = 'straight_through'
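To make two of these strategies concrete, here is a rough, standard-library-only sketch of greedy and temperature-based selection for a single residue. It is not the library's implementation: the SamplingEnum class is a stand-in copy of the members listed above, and softmax and pick_class are hypothetical helpers written for this illustration.

```python
import math
import random
from enum import Enum


class SamplingEnum(Enum):
    # Stand-in copy of the documented members.
    GREEDY = "greedy"
    TOP_K = "top_k"
    TOP_P = "top_p"
    TEMPERATURE = "temperature"
    BEAM_SEARCH = "beam_search"
    STRAIGHT_THROUGH = "straight_through"


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of floats."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def pick_class(logits, strategy, temperature=1.0, rng=None):
    """Pick one class index for a single residue (illustrative only)."""
    if strategy is SamplingEnum.GREEDY:
        # Greedy: always take the highest-scoring class.
        return max(range(len(logits)), key=lambda i: logits[i])
    if strategy is SamplingEnum.TEMPERATURE:
        # Temperature: draw from the softmax; lower values sharpen the distribution.
        rng = rng or random.Random()
        probs = softmax(logits, temperature)
        return rng.choices(range(len(logits)), weights=probs, k=1)[0]
    raise NotImplementedError(f"{strategy.value} is not covered by this sketch")


strategy = SamplingEnum("greedy")  # look a member up by its string value
logits = [0.1, 2.0, -1.0]
greedy_choice = pick_class(logits, strategy)
sampled_choice = pick_class(
    logits, SamplingEnum.TEMPERATURE, temperature=0.5, rng=random.Random(0)
)
```

Lowering the temperature concentrates probability on the top-scoring class, so temperature sampling approaches greedy selection as the temperature goes toward zero.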
- prxteinmpnn.sampling.make_sample_sequences(model_parameters, decoding_order_fn, config, num_encoder_layers=3, num_decoder_layers=3, model_inputs=None, sampling_inputs=None)
Create a function to sample sequences from a structure using ProteinMPNN.
- Parameters:
model_parameters (ModelParameters) – Pre-trained ProteinMPNN model parameters.
decoding_order_fn (DecodingOrderFn) – Function to generate decoding order.
config (SamplingConfig) – Configuration for sampling, including strategy and hyperparameters.
num_encoder_layers (int) – Number of encoder layers. Default is 3.
num_decoder_layers (int) – Number of decoder layers. Default is 3.
model_inputs (ModelInputs | None) – Optional model inputs for sampling. When provided, the returned function requires prng_key, bias, k_neighbors, augment_eps, hyperparameters, and iterations.
sampling_inputs (SamplingInputs | None) – Optional sampling inputs for sequence sampling. When provided, the returned function takes no arguments, because it reads everything from the attributes of sampling_inputs.
optimizer (optax.GradientTransformation | None) – Optional optimizer for straight-through estimator. If provided, the sampling function will optimize the logits using this optimizer. If not provided, the sampling function will not perform optimization.
If both model_inputs and sampling_inputs are provided, sampling_inputs will be used.
- Return type:
Callable[[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_residues'], Float[Array, 'num_residues num_atoms 3'], Int[Array, 'num_residues num_atoms'], Int[Array, 'num_residues'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'] | None, int, float, tuple[float | int | Array | GradientTransformation, ...], int], tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]
| Callable[[Union[Key[Array, ''], UInt32[Array, '2']], tuple[float | int | Array | GradientTransformation, ...], int], tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]
| Callable[[], tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]
- Returns:
A function that samples sequences given structural inputs.
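The three callable shapes in the return type can be read as a factory that binds more or fewer arguments up front. The sketch below shows one way such a dispatch could look; it is an assumption for illustration, not the library's code. make_sampler, run, and the dict-based inputs are hypothetical, and the reduced signature (prng_key, hyperparameters, iterations) is taken from the second alternative in the return type.

```python
from typing import Any, Callable, Optional


def make_sampler(
    model_inputs: Optional[dict] = None,
    sampling_inputs: Optional[dict] = None,
) -> Callable[..., Any]:
    """Hypothetical factory; the returned closure's arity depends on the inputs."""

    def run(**kwargs):
        # Placeholder for the real sampling step: report which arguments arrived.
        return sorted(kwargs)

    if sampling_inputs is not None:
        # Takes precedence over model_inputs; nothing is left to pass at call time.
        return lambda: run(**sampling_inputs)
    if model_inputs is not None:
        # Structure is bound up front; the caller still supplies run-time arguments.
        return lambda prng_key, hyperparameters, iterations: run(
            prng_key=prng_key,
            hyperparameters=hyperparameters,
            iterations=iterations,
            **model_inputs,
        )
    # Nothing bound: the caller passes every argument.
    return lambda **everything: run(**everything)


# Both given: sampling_inputs wins and the result takes no arguments.
zero_arg = make_sampler(
    model_inputs={"structure": 1},
    sampling_inputs={"structure": 2, "prng_key": 0},
)
# Only model_inputs given: the result still needs key, hyperparameters, iterations.
partial = make_sampler(model_inputs={"structure": 1})
```

Binding inputs at construction time means the returned function can be called repeatedly without re-passing the structure.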