sampling_step
Defines the single-pass autoregressive sampling step.
- prxteinmpnn.sampling.sampling_step.temperature_sample(decoder, sample_model_pass_fn, temperature, bias, prng_key)
- Single autoregressive sampling step with temperature scaling.
- Return type:
- tuple[Key[Array, ''] | UInt32[Array, '2'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes']]
- Parameters:
- decoder (Callable[[Unpack[tuple[Key[Array, ''] | UInt32[Array, '2'], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Float | None, Float[Array, 'num_residues num_classes'] | None]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]]) 
- sample_model_pass_fn (partial[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Key[Array, ''] | UInt32[Array, '2']]]) 
- temperature (Float | None) 
- bias (Float[Array, 'num_residues num_classes'] | None) 
- prng_key (Key[Array, ''] | UInt32[Array, '2']) 
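To give a rough idea of what temperature scaling does within a single step, here is a minimal sketch that operates directly on a logits array instead of calling `decoder`. The name `temperature_sample_sketch` is hypothetical. Adding the bias before dividing by the temperature, treating `temperature=None` as "no scaling", and using `jax.random.categorical` are all assumptions. None of this describes the library's internals. The return tuple has the same shape as the documented return type: a fresh key, one class index per residue, and a `num_residues x num_classes` logits array.

```python
import jax
import jax.numpy as jnp


def temperature_sample_sketch(logits, temperature, bias, prng_key):
    """Sample one class index per residue from temperature-scaled logits.

    Hypothetical helper: `logits` stands in for the decoder output, which the
    real temperature_sample obtains by calling `decoder` itself.
    """
    # Split the key so the caller gets a fresh key for the next step.
    prng_key, sample_key = jax.random.split(prng_key)
    if bias is not None:
        # Additive per-residue, per-class bias (ordering is an assumption).
        logits = logits + bias
    if temperature is not None:
        # T < 1 sharpens the distribution, T > 1 flattens it.
        logits = logits / temperature
    # Draw one index per residue (row) from softmax(logits).
    sampled = jax.random.categorical(sample_key, logits, axis=-1)
    return prng_key, sampled, logits
```

The returned key is the one to thread into the next call. Reusing `prng_key` across steps would make every residue draw from the same random stream.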
 
 
- prxteinmpnn.sampling.sampling_step.preload_sampling_step_decoder(decoder, sample_model_pass_fn, sampling_strategy, temperature=None)
- Preloads the sampling-step decoder for the specified strategy.
- Return type:
- partial[tuple[Key[Array, ''] | UInt32[Array, '2'], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes']]]
- Parameters:
- decoder (Callable[[Unpack[tuple[Key[Array, ''] | UInt32[Array, '2'], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Float | None, Float[Array, 'num_residues num_classes'] | None]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]]) 
- sample_model_pass_fn (partial[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Key[Array, ''] | UInt32[Array, '2']]]) 
- sampling_strategy (Literal['temperature', 'straight_through']) 
- temperature (Float | None)
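Here is a minimal sketch of the preloading pattern. It assumes that preloading comes down to picking a step function by `sampling_strategy` and binding its configuration with `functools.partial`, which fits the `partial[...]` return type. Every name below is hypothetical, and the decoder and model-pass arguments are left out for brevity. The `straight_through` branch uses the standard straight-through estimator: the forward value is a hard one-hot of the argmax, and gradients flow through the softmax. The library may implement that strategy differently.

```python
import functools

import jax
import jax.numpy as jnp


def _scale(logits, temperature):
    # Treat temperature=None as T = 1 (assumption).
    return logits / (1.0 if temperature is None else temperature)


def _temperature_step(temperature, logits, prng_key):
    # Hypothetical "temperature" strategy: stochastic categorical draw.
    prng_key, sample_key = jax.random.split(prng_key)
    scaled = _scale(logits, temperature)
    return prng_key, jax.random.categorical(sample_key, scaled, axis=-1), scaled


def _straight_through_step(temperature, logits, prng_key):
    # Hypothetical "straight_through" strategy: the forward value is a hard
    # one-hot of the argmax, while gradients flow through the softmax.
    soft = jax.nn.softmax(_scale(logits, temperature), axis=-1)
    hard = jax.nn.one_hot(jnp.argmax(soft, axis=-1), logits.shape[-1])
    one_hot_st = soft + jax.lax.stop_gradient(hard - soft)
    return prng_key, jnp.argmax(one_hot_st, axis=-1), one_hot_st


_STEPS = {
    "temperature": _temperature_step,
    "straight_through": _straight_through_step,
}


def preload_step_sketch(sampling_strategy, temperature=None):
    """Select a step function and bind its configuration once."""
    if sampling_strategy not in _STEPS:
        raise ValueError(f"Unknown sampling_strategy: {sampling_strategy!r}")
    # The returned partial only needs the per-step inputs (logits, key).
    return functools.partial(_STEPS[sampling_strategy], temperature)
```

After the configuration is bound, the caller holds a callable that takes only per-step inputs. That callable is easy to close over in a decoding loop, and the strategy check runs once instead of on every step.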