sampling_step

Defines the single-pass autoregressive sampling step.

prxteinmpnn.sampling.sampling_step.temperature_sample(decoder, sample_model_pass_fn, temperature, bias, prng_key)

Single autoregressive sampling step with temperature scaling.

Return type:

tuple[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes']]

Parameters:
  • decoder (Callable[[Unpack[tuple[Key[Array, ''] | UInt32[Array, '2'], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Float | None, Float[Array, 'num_residues num_classes'] | None]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]])

  • sample_model_pass_fn (partial[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Key[Array, ''] | UInt32[Array, '2']]])

  • temperature (Float | None)

  • bias (Float[Array, 'num_residues num_classes'] | None)

  • prng_key (Key[Array, ''] | UInt32[Array, '2'])
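The following is a minimal sketch of what one temperature-scaled sampling step computes. It takes precomputed logits and does not call decoder or sample_model_pass_fn, because their internals are not described here. The name temperature_sample_sketch is hypothetical. The order in which bias and temperature are applied is also an assumption and is not taken from the library.

```python
import jax
import jax.numpy as jnp


def temperature_sample_sketch(logits, temperature, bias, prng_key):
    """Hypothetical stand-in for one temperature-sampling step.

    The real temperature_sample obtains its logits from decoder /
    sample_model_pass_fn; this sketch receives them directly.
    """
    # Split the key: one half is consumed here, the other is returned
    # so the caller can use it for the next autoregressive step.
    prng_key, sample_key = jax.random.split(prng_key)
    # Assumption: the optional (num_residues, num_classes) bias is added
    # to the logits before temperature scaling.
    if bias is not None:
        logits = logits + bias
    # Temperature < 1 sharpens the distribution; > 1 flattens it.
    if temperature is not None:
        logits = logits / temperature
    # Draw one class index per residue (classes live on the last axis).
    sequence = jax.random.categorical(sample_key, logits, axis=-1)
    return prng_key, sequence, logits
```

With temperature=None the logits are used unscaled. Lower temperatures concentrate probability mass on the highest-scoring class for each residue.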

prxteinmpnn.sampling.sampling_step.preload_sampling_step_decoder(decoder, sample_model_pass_fn, sampling_strategy, temperature=None)

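Preloads the sampling-step decoder for the specified sampling_strategy and returns it as a partial. Calling the returned partial yields a tuple with the same types as the return value of temperature_sample.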

Return type:

partial[tuple[Union[Key[Array, ''], UInt32[Array, '2']], Int[Array, 'num_residues'], Float[Array, 'num_residues num_classes']]]

Parameters:
  • decoder (Callable[[Unpack[tuple[Key[Array, ''] | UInt32[Array, '2'], Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Float | None, Float[Array, 'num_residues num_classes'] | None]]], tuple[Float[Array, 'num_residues num_classes'], Float[Array, 'num_residues num_classes']]])

  • sample_model_pass_fn (partial[tuple[Int[Array, 'num_atoms num_features'], Float[Array, 'num_atoms num_neighbors num_features'], Int[Array, 'num_atoms num_neighbors'], Int[Array, 'num_residues num_atoms'], Bool[Array, 'num_residues num_residues'], Key[Array, ''] | UInt32[Array, '2']]])

  • sampling_strategy (Literal['temperature', 'straight_through'])

  • temperature (Float | None)
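Strategy-based preloading can be implemented with a lookup table and functools.partial. The sketch below shows one possible approach. It assumes that the step functions share temperature_sample's parameter order, and that decoder, sample_model_pass_fn and temperature are the arguments bound in advance. temperature_step, straight_through_step, STEP_FNS and preload_step_sketch are hypothetical placeholders, not the library's functions.

```python
from functools import partial
from typing import Callable, Literal


def temperature_step(decoder, sample_model_pass_fn, temperature, bias, prng_key):
    """Hypothetical placeholder for a temperature-sampling step."""
    return ("temperature", temperature, bias, prng_key)


def straight_through_step(decoder, sample_model_pass_fn, temperature, bias, prng_key):
    """Hypothetical placeholder for a straight-through sampling step."""
    return ("straight_through", temperature, bias, prng_key)


# Hypothetical dispatch table keyed by the sampling_strategy literal.
STEP_FNS: dict[str, Callable] = {
    "temperature": temperature_step,
    "straight_through": straight_through_step,
}


def preload_step_sketch(
    decoder,
    sample_model_pass_fn,
    sampling_strategy: Literal["temperature", "straight_through"],
    temperature=None,
):
    """Select a step function and bind the arguments fixed across steps."""
    try:
        step_fn = STEP_FNS[sampling_strategy]
    except KeyError:
        raise ValueError(f"Unknown sampling strategy: {sampling_strategy!r}") from None
    # Positional binding: the caller later supplies (bias, prng_key).
    return partial(step_fn, decoder, sample_model_pass_fn, temperature)
```

Binding the fixed arguments once means a decoding loop can call step(bias, prng_key) at every position without re-dispatching on the strategy string.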