# Source code for prxteinmpnn.sampling.config

"""Configuration for sequence sampling in the PrxteinMPNN project."""

import enum

from flax.struct import dataclass

from prxteinmpnn.utils.types import Logits


class SamplingEnum(enum.Enum):
    """Closed set of sequence-sampling strategy identifiers.

    Every member's value is simply its name in lowercase, which lets a
    strategy be recovered from a plain string, e.g. ``SamplingEnum("top_k")``.
    Iteration yields the members in the declaration order below.
    """

    # Value convention: lowercase copy of the member name.
    GREEDY = "greedy"
    TOP_K = "top_k"
    TOP_P = "top_p"
    TEMPERATURE = "temperature"
    BEAM_SEARCH = "beam_search"
    STRAIGHT_THROUGH = "straight_through"
@dataclass(frozen=True)
class SamplingConfig:
    """Configuration for sequence sampling.

    Built with ``flax.struct.dataclass`` and ``frozen=True``, so instances are
    immutable; derive modified copies (e.g. via the ``.replace`` method flax
    dataclasses provide) instead of assigning to fields. Field order is part
    of the interface: it fixes the positional constructor arguments.

    Attributes:
        sampling_strategy: The ``SamplingEnum`` member selecting the sampling
            procedure. Required (no default).
        iterations: Iteration count; defaults to 1. NOTE(review): what is
            iterated (sampling steps, optimization steps, ...) is not visible
            here — confirm against the sampler implementation.
        temperature: Temperature value; defaults to 1.0. Presumably scales
            logits before sampling — confirm with the sampler.
        target_logits: Optional ``Logits``; defaults to ``None``. Presumably
            only consumed by strategies that steer toward a target
            distribution — confirm with callers.
        learning_rate: Step size; defaults to 0.1. Presumably only used by
            gradient-based strategies (e.g. ``STRAIGHT_THROUGH``) — confirm.
    """

    # NOTE(review): the original comment here called these "static parameters
    # that control the computation graph". flax.struct.dataclass treats every
    # field as a pytree node unless it is declared with
    # flax.struct.field(pytree_node=False), and none are, so to JAX these
    # are NOT static. If they must shape the traced graph, the config
    # presumably has to be passed as a static argument or consumed outside
    # jit — confirm with callers before relying on it.
    sampling_strategy: SamplingEnum
    iterations: int = 1
    temperature: float = 1.0
    target_logits: Logits | None = None
    learning_rate: float = 0.1