sampling_step

Sample sequences from a structure using the ProteinMPNN model.

prxteinmpnn.sampling.sampling

prxteinmpnn.sampling.sampling_step.temperature_sample(prng_key, decoder, sample_model_pass_fn_only_prng, temperature)

Single autoregressive sampling step with temperature scaling.

Parameters:
  • prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – Random key for JAX operations.

  • decoder (Callable[[Unpack[tuple[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], float]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]]) – Preloaded autoregressive decoder; returns two arrays of shape (num_residues, num_classes).

  • sample_model_pass_fn_only_prng (Callable[[Union[Key[Array, ''], UInt32[Array, '2']]], tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues'], Bool[Array, 'num_residues num_residues'], Union[Key[Array, ''], UInt32[Array, '2']]]]) – Function to run a single pass through the model.

  • temperature (float) – Temperature for scaling logits.

Return type:

tuple[Union[Key[Array, ''], UInt32[Array, '2']], Float[Array, 'num_residues num_neighbors num_features'], Int[Array, 'num_atoms num_features'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes']]

Returns:

Updated carry state (rng_key, edge_features, node_features, sequence, logits).

Example

    from functools import partial

    from prxteinmpnn.sampling.sampling_step import temperature_sample

    # decoder and sample_model_pass_fn_only_prng are assumed to be built beforehand.
    sample_step = partial(
        temperature_sample,
        decoder=decoder,
        sample_model_pass_fn_only_prng=sample_model_pass_fn_only_prng,
        temperature=temperature,
    )
    rng_key, edge_features, node_features, sequence, logits = sample_step(prng_key)
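Temperature scaling as a general technique can be sketched in plain JAX. This is a minimal sketch, not the prxteinmpnn implementation: it assumes the usual convention of sampling from softmax(logits / temperature), and the helper names temperature_scaled_sample and repeated_samples are hypothetical. It also shows how a PRNG key is threaded through a jax.lax.fori_loop carry, whose body takes (i, carry) and returns only the new carry.

```python
import jax
import jax.numpy as jnp


def temperature_scaled_sample(key, logits, temperature):
    """Hypothetical helper: one class index per residue from softmax(logits / temperature)."""
    # jax.random.categorical applies the softmax internally (Gumbel-max trick),
    # so dividing the logits by the temperature is the only scaling needed.
    return jax.random.categorical(key, logits / temperature, axis=-1)


def repeated_samples(key, logits, temperature, num_samples):
    """Hypothetical helper: draw num_samples sequences, threading the key through a carry."""
    num_residues = logits.shape[0]

    def body(i, carry):
        key, samples = carry
        key, subkey = jax.random.split(key)  # fresh subkey on every iteration
        row = temperature_scaled_sample(subkey, logits, temperature)
        samples = samples.at[i].set(row)
        return key, samples  # fori_loop bodies return only the new carry

    init = (key, jnp.zeros((num_samples, num_residues), dtype=jnp.int32))
    _, samples = jax.lax.fori_loop(0, num_samples, body, init)
    return samples


# Usage: 5 residues, 21 classes; residue i strongly prefers class i.
logits = 10.0 * jnp.eye(5, 21)
samples = repeated_samples(jax.random.PRNGKey(0), logits, 1e-3, num_samples=4)
```

With a temperature near zero the draw collapses onto the argmax, so every row of samples is [0, 1, 2, 3, 4]; larger temperatures spread probability mass over the other classes.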

prxteinmpnn.sampling.sampling_step.ste_sample(initial_carry, decoder, sample_model_pass_fn_only_prng, model_parameters, learning_rate, target_logits, iterations)

Single autoregressive sampling step with a straight-through estimator.

Parameters:
  • initial_carry (tuple[Union[Key[Array, ''], UInt32[Array, '2']], Float[Array, 'num_residues num_neighbors num_features'], Int[Array, 'num_atoms num_features'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes']]) – Tuple containing initial state (rng_key, edge_features, node_features, sequence, logits).

  • decoder (Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Int[Array, 'num_residues']]]], Int[Array, 'num_atoms num_features']]) – Decoder function to update node features.

  • sample_model_pass_fn_only_prng (Callable[[Union[Key[Array, ''], UInt32[Array, '2']]], tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues'], Bool[Array, 'num_residues num_residues'], Union[Key[Array, ''], UInt32[Array, '2']]]]) – Function to run a single pass through the model.

  • model_parameters (PyTree[str, 'P']) – ProteinMPNN model parameters.

  • learning_rate (float) – Learning rate for updating logits.

  • target_logits (Float[Array, 'num_residues num_classes']) – Target logits for the straight-through estimator.

  • iterations (int) – Number of iterations for the straight-through estimator.

Return type:

tuple[Union[Key[Array, ''], UInt32[Array, '2']], Float[Array, 'num_residues num_neighbors num_features'], Int[Array, 'num_atoms num_features'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes']]

Returns:

Updated carry state (rng_key, edge_features, node_features, sequence, logits).

Example

    from prxteinmpnn.sampling.sampling_step import ste_sample

    initial_carry = (rng_key, edge_features, node_features, sequence, logits)
    # ste_sample performs `iterations` estimator updates itself, so no outer loop is needed.
    final_carry = ste_sample(
        initial_carry,
        decoder=decoder,
        sample_model_pass_fn_only_prng=sample_model_pass_fn_only_prng,
        model_parameters=model_parameters,
        learning_rate=learning_rate,
        target_logits=target_logits,
        iterations=iterations,
    )
    rng_key, edge_features, node_features, sequence, logits = final_carry
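A straight-through estimator uses a hard one-hot sequence in the forward pass while passing the softmax gradient backward. The sketch below illustrates that idea in plain JAX under stated assumptions: the loss (negative log-likelihood of the hard sequence under softmax(target_logits)) and the plain gradient-descent update are guesses at what learning_rate, target_logits and iterations control, and the helper names are hypothetical rather than part of prxteinmpnn.

```python
import jax
import jax.numpy as jnp


def straight_through_one_hot(logits):
    """Hard one-hot in the forward pass, softmax gradient in the backward pass."""
    soft = jax.nn.softmax(logits, axis=-1)
    hard = jax.nn.one_hot(jnp.argmax(logits, axis=-1), logits.shape[-1], dtype=soft.dtype)
    # Forward value equals `hard`; gradients flow only through `soft`.
    return hard + soft - jax.lax.stop_gradient(soft)


def ste_optimize(logits, target_logits, learning_rate, iterations):
    """Hypothetical helper: gradient descent on logits toward sequences favoured by target_logits."""
    target_log_probs = jax.nn.log_softmax(target_logits, axis=-1)

    def loss_fn(current_logits):
        one_hot = straight_through_one_hot(current_logits)
        # Assumed loss: negative log-likelihood of the discrete sequence under the target.
        return -jnp.sum(one_hot * target_log_probs)

    grad_fn = jax.grad(loss_fn)

    def body(i, current_logits):
        return current_logits - learning_rate * grad_fn(current_logits)

    return jax.lax.fori_loop(0, iterations, body, logits)


# Usage: 2 residues, 4 classes, starting from uniform logits.
target_logits = jnp.array([[0.0, 3.0, 1.0, 0.0], [2.0, 0.0, 0.0, 1.0]])
optimized = ste_optimize(jnp.zeros_like(target_logits), target_logits, learning_rate=1.0, iterations=50)
sequence = jnp.argmax(optimized, axis=-1)
```

Starting from uniform logits, the updates raise the logit of each residue's most likely target class fastest, so sequence ends up as [1, 0].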

prxteinmpnn.sampling.sampling_step.preload_sampling_step_decoder(decoder, sample_model_pass_fn_only_prng, model_parameters, sampling_config)

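Preload the sampling step decoder: bind decoder, sample_model_pass_fn_only_prng, model_parameters and sampling_config, and return a sampling-step function that produces the updated carry state (rng_key, edge_features, node_features, sequence, logits).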

Parameters:
  • decoder (Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Int[Array, 'num_residues']]]], Int[Array, 'num_atoms num_features']] | Callable[[Unpack[tuple[Key[Array, ''] | UInt32[Array, '2'], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], float]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]]) – Decoder function; the accepted types match the decoder of ste_sample and the preloaded autoregressive decoder of temperature_sample.

  • sample_model_pass_fn_only_prng (Callable[[Key[Array, ''] | UInt32[Array, '2']], tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues'], Bool[Array, 'num_residues num_residues'], Key[Array, ''] | UInt32[Array, '2']]]) – Function to run a single pass through the model.

  • model_parameters (PyTree[str, 'P']) – ProteinMPNN model parameters.

  • sampling_config (SamplingConfig) – Sampling configuration.

Return type:

Callable[[Unpack[tuple[Any, ...]]], tuple[Union[Key[Array, ''], UInt32[Array, '2']], Float[Array, 'num_residues num_neighbors num_features'], Int[Array, 'num_atoms num_features'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes']]]
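Preloading follows a partial-application pattern: fixed inputs are bound once so the returned callable only needs the per-call state. The sketch below assumes, without confirmation from this page, that sampling_config chooses between a temperature step and a straight-through step (the union type of decoder is consistent with that). ToySamplingConfig, its fields, and the toy step functions are hypothetical stand-ins for SamplingConfig, temperature_sample and ste_sample.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable


@dataclass(frozen=True)
class ToySamplingConfig:
    """Hypothetical config; the real SamplingConfig fields are not documented here."""

    strategy: str = "temperature"
    temperature: float = 1.0
    learning_rate: float = 0.1
    iterations: int = 10


def toy_temperature_step(prng_key: Any, decoder: Callable, model_pass: Callable, temperature: float):
    # Stand-in for temperature_sample: only records what it was called with.
    return ("temperature", prng_key, temperature)


def toy_ste_step(initial_carry: Any, decoder: Callable, model_pass: Callable,
                 model_parameters: Any, learning_rate: float, iterations: int):
    # Stand-in for ste_sample: only records what it was called with.
    return ("straight_through", initial_carry, learning_rate, iterations)


def toy_preload(decoder: Callable, model_pass: Callable, model_parameters: Any,
                config: ToySamplingConfig) -> Callable:
    """Bind everything except the per-call state and pick a step from config.strategy."""
    if config.strategy == "temperature":
        return partial(toy_temperature_step, decoder=decoder, model_pass=model_pass,
                       temperature=config.temperature)
    if config.strategy == "straight_through":
        return partial(toy_ste_step, decoder=decoder, model_pass=model_pass,
                       model_parameters=model_parameters,
                       learning_rate=config.learning_rate, iterations=config.iterations)
    raise ValueError(f"unknown sampling strategy: {config.strategy!r}")


def placeholder(*args):
    # Hypothetical stand-in for a real decoder or model-pass function.
    return args


# Usage: each preloaded step is called with only the per-call state.
temperature_step = toy_preload(placeholder, placeholder, {}, ToySamplingConfig(temperature=0.5))
ste_step = toy_preload(placeholder, placeholder, {},
                       ToySamplingConfig(strategy="straight_through", iterations=3))
```

Binding with functools.partial keeps the per-call signature small, which is convenient when the step is later handed to a loop primitive that supplies only the changing state.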