sampling_step
Sample sequences from a structure using the ProteinMPNN model.
prxteinmpnn.sampling.sampling
- prxteinmpnn.sampling.sampling_step.temperature_sample(prng_key, decoder, sample_model_pass_fn_only_prng, temperature)
Single autoregressive sampling step with temperature scaling.
- Parameters:
prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – Random key for JAX operations.
decoder (Callable[[Unpack[tuple[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], float]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]]) – Decoder function to update node features. Preloaded autoregressive decoder.
sample_model_pass_fn_only_prng (Callable[[Union[Key[Array, ''], UInt32[Array, '2']]], tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues'], Bool[Array, 'num_residues num_residues'], Union[Key[Array, ''], UInt32[Array, '2']]]]) – Function to run a single pass through the model.
temperature (float) – Temperature for scaling logits.
- Return type:
tuple[Union[Key[Array, ''], UInt32[Array, '2']], Float[Array, 'num_residues num_neighbors num_features'], Int[Array, 'num_atoms num_features'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes']]
- Returns:
Updated carry state and None for scan output.
Example
    from functools import partial

    import jax

    # Loop state threaded through every sampling step.
    carry = (rng_key, edge_features, node_features, sequence, logits)
    sample_step = partial(
        sample_temperature_step,
        decoder=decoder,
        neighbor_indices=neighbor_indices,
        mask=mask,
        autoregressive_mask=autoregressive_mask,
        model_parameters=model_parameters,
        temperature=temperature,
    )
    # jax.lax.fori_loop returns only the final carry.
    final_carry = jax.lax.fori_loop(
        0, iterations, sample_step, carry,
    )
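The core of a temperature-scaled step can be illustrated in isolation. The following is a minimal sketch, not the library's implementation: `sample_with_temperature` is a hypothetical helper, and it assumes the logits are already available as a `(num_residues, num_classes)` array.

```python
import jax
import jax.numpy as jnp


def sample_with_temperature(prng_key, logits, temperature):
    """Draw one class index per residue from temperature-scaled logits.

    Hypothetical helper for illustration; not part of prxteinmpnn.
    """
    # Split the key so the caller receives a fresh key for the next step.
    next_key, sample_key = jax.random.split(prng_key)
    # Temperatures below 1 sharpen the distribution; above 1 flatten it.
    scaled_logits = logits / temperature
    # Sample along the last (class) axis.
    sequence = jax.random.categorical(sample_key, scaled_logits, axis=-1)
    return next_key, sequence


key = jax.random.PRNGKey(0)
# Toy logits with shape (num_residues, num_classes).
logits = jnp.array([[0.0, 10.0, 0.0], [10.0, 0.0, 0.0]])
next_key, sequence = sample_with_temperature(key, logits, temperature=0.01)
```

At a very low temperature, as in the call above, sampling is effectively an argmax over the classes; raising the temperature increases sequence diversity.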
- prxteinmpnn.sampling.sampling_step.ste_sample(initial_carry, decoder, sample_model_pass_fn_only_prng, model_parameters, learning_rate, target_logits, iterations)
Single autoregressive sampling step with straight-through estimator.
- Parameters:
initial_carry (tuple[Union[Key[Array, ''], UInt32[Array, '2']], Float[Array, 'num_residues num_neighbors num_features'], Int[Array, 'num_atoms num_features'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes']]) – Tuple containing initial state (rng_key, edge_features, node_features, sequence, logits).
decoder (Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Int[Array, 'num_residues']]]], Int[Array, 'num_atoms num_features']]) – Decoder function to update node features.
model_parameters (PyTree[str, 'P']) – Model parameters for the model.
sample_model_pass_fn_only_prng (Callable[[Union[Key[Array, ''], UInt32[Array, '2']]], tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues'], Bool[Array, 'num_residues num_residues'], Union[Key[Array, ''], UInt32[Array, '2']]]]) – Function to run a single pass through the model.
learning_rate (float) – Learning rate for updating logits.
target_logits (Float[Array, 'num_residues num_classes']) – Target logits for the straight-through estimator.
iterations (int) – Number of iterations for the straight-through estimator.
- Return type:
tuple[Union[Key[Array, ''], UInt32[Array, '2']], Float[Array, 'num_residues num_neighbors num_features'], Int[Array, 'num_atoms num_features'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes']]
- Returns:
Updated carry state and None for scan output.
Example
    from functools import partial

    import jax

    # Loop state threaded through every sampling step.
    carry = (rng_key, edge_features, node_features, sequence, logits)
    sample_step = partial(
        sample_straight_through_estimator_step,
        decoder=decoder,
        neighbor_indices=neighbor_indices,
        mask=mask,
        autoregressive_mask=autoregressive_mask,
        model_parameters=model_parameters,
        learning_rate=learning_rate,
    )
    # jax.lax.fori_loop returns only the final carry.
    final_carry = jax.lax.fori_loop(
        0, iterations, sample_step, carry,
    )
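The straight-through estimator itself can be sketched independently of the model. The example below is illustrative only: this reference does not show the library's objective or update rule, so the cross-entropy loss against `target_logits` and the plain gradient-descent update are assumptions, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def straight_through_one_hot(logits):
    """One-hot argmax in the forward pass, softmax gradient in the backward pass."""
    probs = jax.nn.softmax(logits, axis=-1)
    hard = jax.nn.one_hot(
        jnp.argmax(probs, axis=-1), logits.shape[-1], dtype=probs.dtype
    )
    # stop_gradient hides (hard - probs) from autodiff: the value equals `hard`
    # while the gradient is that of `probs`.
    return probs + jax.lax.stop_gradient(hard - probs)


def loss_fn(logits, target_logits):
    # Assumed objective: cross-entropy of the hard sequence under the targets.
    one_hot = straight_through_one_hot(logits)
    target_log_probs = jax.nn.log_softmax(target_logits, axis=-1)
    return -jnp.sum(one_hot * target_log_probs)


def ste_update(logits, target_logits, learning_rate):
    # Assumed update rule: one step of gradient descent on the logits.
    grads = jax.grad(loss_fn)(logits, target_logits)
    return logits - learning_rate * grads


logits = jnp.array([[2.0, 0.5, -1.0]])         # current argmax is class 0
target_logits = jnp.array([[0.0, 3.0, 0.0]])   # targets prefer class 1
updated = ste_update(logits, target_logits, learning_rate=0.1)
```

Although the forward pass uses a discrete one-hot sequence, the gradient still flows through the softmax, so each update nudges the logits toward the class favoured by the targets.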
- prxteinmpnn.sampling.sampling_step.preload_sampling_step_decoder(decoder, sample_model_pass_fn_only_prng, model_parameters, sampling_config)
Preload the sampling step decoder.
- Return type:
Callable[[Unpack[tuple[Any, ...]]], tuple[Union[Key[Array, ''], UInt32[Array, '2']], Float[Array, 'num_residues num_neighbors num_features'], Int[Array, 'num_atoms num_features'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes']]]
- Parameters:
decoder (Callable[[Unpack[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Int[Array, 'num_residues']]]], Int[Array, 'num_atoms num_features']] | Callable[[Unpack[tuple[Key[Array, ''] | UInt32[Array, '2'], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], float]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]])
sample_model_pass_fn_only_prng (Callable[[Key[Array, ''] | UInt32[Array, '2']], tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues'], Bool[Array, 'num_residues num_residues'], Key[Array, ''] | UInt32[Array, '2']]])
model_parameters (PyTree[str, 'P'])
sampling_config (SamplingConfig)
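The general "preload" pattern, binding fixed settings once and returning a step function that a loop can call, might look like the sketch below. Everything here is a hypothetical stand-in: the fields of the real `SamplingConfig` are not documented in this reference, so `DemoSamplingConfig`, `demo_step`, and `preload_demo_step` are assumptions, and the carry is reduced to three elements for brevity.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class DemoSamplingConfig:
    """Hypothetical stand-in for SamplingConfig (real fields unknown)."""

    temperature: float = 1.0


def demo_step(i, carry, *, temperature):
    """Toy loop body: resample every position from fixed logits."""
    prng_key, sequence, logits = carry
    prng_key, sample_key = jax.random.split(prng_key)
    sequence = jax.random.categorical(sample_key, logits / temperature, axis=-1)
    return prng_key, sequence, logits


def preload_demo_step(config):
    # Bind the static settings once so the loop body only sees (i, carry).
    return partial(demo_step, temperature=config.temperature)


step = preload_demo_step(DemoSamplingConfig(temperature=0.01))
logits = jnp.array([[0.0, 5.0], [5.0, 0.0], [0.0, 5.0]])
carry = (jax.random.PRNGKey(0), jnp.zeros(3, dtype=jnp.int32), logits)
# fori_loop calls step(i, carry) and returns the final carry.
final_key, final_sequence, _ = jax.lax.fori_loop(0, 4, step, carry)
```

Binding the configuration outside the loop keeps the body's signature compatible with `jax.lax.fori_loop`, which passes only the iteration index and the carry.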