sample

Factory for creating sequence sampling functions in ProteinMPNN.

prxteinmpnn.sampling.factory

prxteinmpnn.sampling.sample.make_sample_sequences(model_parameters, decoding_order_fn, config, num_encoder_layers=3, num_decoder_layers=3, model_inputs=None, sampling_inputs=None)

Create a function to sample sequences from a structure using ProteinMPNN.

Parameters:
  • model_parameters (ModelParameters) – Pre-trained ProteinMPNN model parameters.

  • decoding_order_fn (DecodingOrderFn) – Function to generate decoding order.

  • config (SamplingConfig) – Configuration for sampling, including strategy and hyperparameters.

  • num_encoder_layers (int) – Number of encoder layers. Default is 3.

  • num_decoder_layers (int) – Number of decoder layers. Default is 3.

  • model_inputs (ModelInputs | None) – Optional model inputs for sampling. When provided, the returned function requires prng_key, bias, k_neighbors, augment_eps, hyperparameters, and iterations as arguments.

  • sampling_inputs (SamplingInputs | None) – Optional sampling inputs for sequence sampling. When provided, the returned function takes no arguments, because it reads everything it needs from the attributes of sampling_inputs.

  • optimizer (optax.GradientTransformation | None) – Optional optimizer for the straight-through estimator. If provided, the sampling function optimizes the logits with this optimizer; otherwise, no optimization is performed.

If both model_inputs and sampling_inputs are provided, sampling_inputs will be used.

Return type:

Callable[[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_residues'], Float[Array, 'num_residues num_atoms 3'], Int[Array, 'num_residues num_atoms'], Int[Array, 'num_residues'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'] | None, int, float, tuple[float | int | Array | GradientTransformation, ...], int], tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]

| Callable[[Union[Key[Array, ''], UInt32[Array, '2']], tuple[float | int | Array | GradientTransformation, ...], int], tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]

| Callable[[], tuple[Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes'], Int[Array, 'num_residues']]]

Returns:

A function that samples sequences given structural inputs.
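
The return type lists three callable shapes: one that takes every input, one that takes only a few, and one that takes none. The sketch below illustrates that general pattern with a toy factory. It is not the prxteinmpnn implementation: ToyModelInputs, ToySamplingInputs, toy_sample, and make_toy_sampler are hypothetical names, the argument lists are heavily simplified, and Python's standard random module stands in for a JAX PRNG key. The only behavior taken from the documentation above is that sampling_inputs takes precedence when both optional inputs are supplied.

```python
import random
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class ToyModelInputs:
    """Hypothetical stand-in for ModelInputs: structure-related data only."""
    num_residues: int


@dataclass
class ToySamplingInputs:
    """Hypothetical stand-in for SamplingInputs: everything needed for one call."""
    num_residues: int
    seed: int
    iterations: int


def toy_sample(num_residues: int, seed: int, iterations: int) -> list[int]:
    """Toy sampler: redraw a random sequence `iterations` times, keep the last."""
    rng = random.Random(seed)  # assumption: stands in for a JAX PRNG key
    sequence = [0] * num_residues
    for _ in range(iterations):
        sequence = [rng.randrange(20) for _ in range(num_residues)]
    return sequence


def make_toy_sampler(
    model_inputs: Optional[ToyModelInputs] = None,
    sampling_inputs: Optional[ToySamplingInputs] = None,
) -> Callable[..., list[int]]:
    """Return a sampler whose signature depends on which inputs were bound."""
    if sampling_inputs is not None:
        # Checked first, so sampling_inputs wins when both are provided.
        def sample_bound() -> list[int]:
            return toy_sample(
                sampling_inputs.num_residues,
                sampling_inputs.seed,
                sampling_inputs.iterations,
            )
        return sample_bound

    if model_inputs is not None:
        # Structure is fixed; the caller still supplies seed and iterations.
        def sample_partial(seed: int, iterations: int) -> list[int]:
            return toy_sample(model_inputs.num_residues, seed, iterations)
        return sample_partial

    # Nothing bound: the caller supplies every argument.
    return toy_sample


full = make_toy_sampler()
partial = make_toy_sampler(model_inputs=ToyModelInputs(num_residues=5))
bound = make_toy_sampler(
    model_inputs=ToyModelInputs(num_residues=3),
    sampling_inputs=ToySamplingInputs(num_residues=5, seed=0, iterations=1),
)

# All three calls draw the same sequence because they share seed and length.
print(full(5, 0, 1) == partial(0, 1) == bound())
```

Binding inputs at factory time keeps the call site short when the same structure is sampled repeatedly, while the unbound form stays available for callers that vary the structure between calls.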