sample
Factory for creating sequence sampling functions in ProteinMPNN.
prxteinmpnn.sampling.factory
- prxteinmpnn.sampling.sample.make_sample_sequences(model_parameters, decoding_order_fn, config, num_encoder_layers=3, num_decoder_layers=3, model_inputs=None, sampling_inputs=None)
Create a function to sample sequences from a structure using ProteinMPNN.
- Parameters:
model_parameters (ModelParameters) – Pre-trained ProteinMPNN model parameters.
decoding_order_fn (DecodingOrderFn) – Function to generate decoding order.
config (SamplingConfig) – Configuration for sampling, including strategy and hyperparameters.
num_encoder_layers (int) – Number of encoder layers. Default is 3.
num_decoder_layers (int) – Number of decoder layers. Default is 3.
model_inputs (ModelInputs | None) – Optional model inputs for sampling. If provided, the returned function takes prng_key, bias, k_neighbors, augment_eps, hyperparameters, and iterations.
sampling_inputs (SamplingInputs | None) – Optional sampling inputs for sequence sampling. If provided, the returned function takes no arguments and reads everything it needs from the attributes of sampling_inputs.
optimizer (optax.GradientTransformation | None) – Optional optimizer for straight-through estimator. If provided, the sampling function will optimize the logits using this optimizer. If not provided, the sampling function will not perform optimization.
If both model_inputs and sampling_inputs are provided, sampling_inputs will be used.
- Return type:
Callable[[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_residues'], Float[Array, 'num_residues num_atoms 3'], Int[Array, 'num_residues num_atoms'], Int[Array, 'num_residues'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'] | None, int, float, tuple[float | int | Array | GradientTransformation, ...], int], tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]
| Callable[[Union[Key[Array, ''], UInt32[Array, '2']], tuple[float | int | Array | GradientTransformation, ...], int], tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]
| Callable[[], tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]
- Returns:
A function that samples sequences given structural inputs.
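The three alternatives in the return type suggest that the factory returns a closure whose arity depends on which optional inputs were bound up front. The following is a minimal, library-free sketch of that pattern, not the prxteinmpnn implementation: ToyModelInputs, ToySamplingInputs, make_toy_sampler, and _toy_sample are hypothetical stand-ins, the argument lists are heavily simplified, and the "sampler" is a deterministic placeholder rather than a model call. It does mirror the documented precedence rule: when both inputs are given, sampling_inputs is used.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


@dataclass(frozen=True)
class ToyModelInputs:
    """Hypothetical stand-in for ModelInputs: only the structure size."""
    num_residues: int


@dataclass(frozen=True)
class ToySamplingInputs:
    """Hypothetical stand-in for SamplingInputs: everything needed to sample."""
    num_residues: int
    seed: int
    iterations: int


def _toy_sample(num_residues: int, seed: int, iterations: int) -> Tuple[int, ...]:
    # Placeholder for the real sampler: returns one integer class per residue.
    return tuple((seed + iterations + i) % 21 for i in range(num_residues))


def make_toy_sampler(
    model_inputs: Optional[ToyModelInputs] = None,
    sampling_inputs: Optional[ToySamplingInputs] = None,
) -> Callable[..., Tuple[int, ...]]:
    """Return a closure whose signature depends on what was bound here."""
    # sampling_inputs takes precedence when both are provided.
    if sampling_inputs is not None:
        def sample_bound() -> Tuple[int, ...]:
            return _toy_sample(
                sampling_inputs.num_residues,
                sampling_inputs.seed,
                sampling_inputs.iterations,
            )
        return sample_bound

    if model_inputs is not None:
        def sample_structure_bound(seed: int, iterations: int) -> Tuple[int, ...]:
            return _toy_sample(model_inputs.num_residues, seed, iterations)
        return sample_structure_bound

    def sample_full(seed: int, num_residues: int, iterations: int) -> Tuple[int, ...]:
        return _toy_sample(num_residues, seed, iterations)
    return sample_full


# Usage: the same factory yields three call shapes.
full = make_toy_sampler()
partial = make_toy_sampler(model_inputs=ToyModelInputs(num_residues=4))
bound = make_toy_sampler(
    model_inputs=ToyModelInputs(num_residues=9),
    sampling_inputs=ToySamplingInputs(num_residues=4, seed=1, iterations=2),
)
```

Binding inputs at construction time keeps the returned function small at the call site, which tends to be convenient when the same structure is sampled many times with different keys.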