# Source code for prxteinmpnn.sampling.sampling_step

"""Defines the single-pass autoregressive sampling step."""

from functools import partial
from typing import Literal

import jax.numpy as jnp
from jaxtyping import Float, PRNGKeyArray

from prxteinmpnn.model.decoding_signatures import RunAutoregressiveDecoderFn
from prxteinmpnn.utils.types import Logits, ProteinSequence

from .initialize import SamplingModelPassOutput

# Result of one sampling step: (PRNG key, sampled sequence, logits).
SamplingStepState = tuple[PRNGKeyArray, ProteinSequence, Logits]
# A functools.partial whose call yields a SamplingStepState
# (the generic parameter of `partial` is the wrapped callable's return type).
SamplingStepFn = partial[SamplingStepState]
# A functools.partial whose call yields a SamplingModelPassOutput.
SampleModelPassFn = partial[SamplingModelPassOutput]


def temperature_sample(
    decoder: RunAutoregressiveDecoderFn,
    sample_model_pass_fn: SampleModelPassFn,
    temperature: Float | None,
    bias: Logits | None,
    prng_key: PRNGKeyArray,
) -> SamplingStepState:
    """Run one autoregressive sampling pass with temperature scaling.

    Args:
        decoder: Autoregressive decoder. It is called with the decoding key,
            the model-pass features and masks, ``temperature`` and ``bias``.
            It must return a ``(one_hot_sequence, logits)`` pair.
        sample_model_pass_fn: Callable invoked as
            ``sample_model_pass_fn(prng_key=...)``. It must return six values:
            node features, edge features, neighbor indices, mask,
            autoregressive mask and the decoding key.
        temperature: Temperature forwarded unchanged to ``decoder``.
        bias: Logit bias forwarded unchanged to ``decoder``.
        prng_key: Key handed to ``sample_model_pass_fn``.

    Returns:
        A ``(prng_key, sequence, logits)`` tuple. ``sequence`` is the argmax
        of the decoder's one-hot output over its last axis, cast to int8.
        ``prng_key`` is the exact key that was passed in; it is not advanced.

    """
    model_pass = sample_model_pass_fn(prng_key=prng_key)
    (
        nodes,
        edges,
        neighbors,
        residue_mask,
        ar_mask,
        decode_key,
    ) = model_pass

    one_hot_seq, step_logits = decoder(
        decode_key,
        nodes,
        edges,
        neighbors,
        residue_mask,
        ar_mask,
        temperature,
        bias,
    )

    # Collapse the one-hot encoding back to integer token indices.
    token_ids = one_hot_seq.argmax(axis=-1)
    sampled = token_ids.astype(jnp.int8)

    # NOTE(review): the incoming key is echoed back, not split. If a caller
    # threads this state through repeated steps, it should split the key
    # itself — confirm against callers.
    return prng_key, sampled, step_logits
def preload_sampling_step_decoder(
    decoder: RunAutoregressiveDecoderFn,
    sample_model_pass_fn: SampleModelPassFn,
    sampling_strategy: Literal["temperature", "straight_through"],
    temperature: Float | None = None,
) -> SamplingStepFn:
    """Bind the fixed arguments of a sampling step for the chosen strategy.

    Args:
        decoder: Autoregressive decoder to bind into the step.
        sample_model_pass_fn: Model-pass callable to bind into the step.
        sampling_strategy: Which step to build. Only ``"temperature"`` is
            implemented.
        temperature: Temperature to bind; defaults to ``None``.

    Returns:
        ``functools.partial(temperature_sample, ...)`` with ``decoder``,
        ``sample_model_pass_fn`` and ``temperature`` bound as keyword
        arguments. The remaining ``bias`` and ``prng_key`` must be supplied
        by keyword. A positional argument would land on ``decoder``, which
        is already bound.

    Raises:
        NotImplementedError: For any strategy other than ``"temperature"``,
            including ``"straight_through"``, even though the type hint lists
            it.

    """
    if sampling_strategy != "temperature":
        # No other sampling strategies are supported in this simplified version
        msg = f"Unsupported sampling strategy: {sampling_strategy}"
        raise NotImplementedError(msg)

    bound_kwargs = {
        "decoder": decoder,
        "sample_model_pass_fn": sample_model_pass_fn,
        "temperature": temperature,
    }
    return partial(temperature_sample, **bound_kwargs)