# Source code for prxteinmpnn.run.specs
"""Core user interface for the PrxteinMPNN package."""
from __future__ import annotations
import logging
import multiprocessing as mp
import sys
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
# Set the process-wide default multiprocessing start method to "spawn".
# `force=True` overrides any start method that was already chosen, and because
# this runs at import time it affects every program that imports this module.
# NOTE(review): presumably chosen because forked children are unsafe with
# accelerator runtimes (e.g. JAX/XLA) — confirm, and consider moving this out
# of import-time module scope into an application entry point.
mp.set_start_method("spawn", force=True)
if TYPE_CHECKING:
  from io import StringIO
  from jaxtyping import ArrayLike
  from prxteinmpnn.ensemble.dbscan import ConformationalStates
  from prxteinmpnn.mpnn import ModelVersion, ModelWeights
  from prxteinmpnn.utils.catjac import CombineCatJacPairFn
  from prxteinmpnn.utils.decoding_order import DecodingOrderFn
  from prxteinmpnn.utils.foldcomp_utils import FoldCompDatabase
# Type alias naming the supported alignment strategies.
AlignmentStrategy = Literal["sequence", "structure"]
logger = logging.getLogger(__name__)
# NOTE(review): this configures the *root* logger at import time. With
# `force=True`, any handlers already attached to the root logger are removed
# before an INFO-level handler writing to stdout is installed, which overrides
# the host application's logging setup. Consider leaving this to applications.
logging.basicConfig(level=logging.INFO, stream=sys.stdout, force=True)
def _loader_inputs(inputs: Sequence[str | StringIO] | str | StringIO) -> Sequence[str | StringIO]:
  return (inputs,) if not isinstance(inputs, Sequence) else inputs
# [docs]
@dataclass
class RunSpecification:
  """Configuration for running the model.

  Attributes:
      inputs: A sequence of input file paths or StringIO objects, or a single input.
      topology: An optional path to a topology file (default is None).
      model_weights: The model weights to use (default is "original").
      model_version: The model version to use (default is "v_48_020.pkl").
      batch_size: The batch size to use (default is 32).
      backbone_noise: The backbone noise levels to use (default is (0.0,)).
                      Can be a single number or a sequence of floats; a single
                      ``int``/``float`` is normalized to a one-element tuple.
      foldcomp_database: An optional FoldComp database (default is None).
      num_workers: The number of worker processes for data loading (default is 0).
      ar_mask: An optional array-like mask for autoregressive positions (default is None).
      random_seed: The random seed to use (default is 42).
      chain_id: An optional chain ID, or sequence of chain IDs, to use (default is None).
      model: An optional model ID to use (default is None).
      altloc: The alternate location to use, "first" or "all" (default is "first").
      decoding_order_fn: An optional function to generate the decoding order (default is None).
      conformational_states: ConformationalStates to use for coarse graining the inference
                             (default is None).
  """
  inputs: Sequence[str | StringIO] | str | StringIO
  topology: str | Path | None = None
  model_weights: ModelWeights = "original"
  model_version: ModelVersion = "v_48_020.pkl"
  batch_size: int = 32
  backbone_noise: Sequence[float] | float = (0.0,)
  foldcomp_database: FoldCompDatabase | None = None
  num_workers: int = 0
  ar_mask: None | ArrayLike = None
  random_seed: int = 42
  chain_id: Sequence[str] | str | None = None
  model: int | None = None
  altloc: Literal["first", "all"] = "first"
  decoding_order_fn: DecodingOrderFn | None = None
  conformational_states: ConformationalStates | None = None

  def __post_init__(self) -> None:
    """Normalize a scalar ``backbone_noise`` into a one-element tuple.

    ``int`` is accepted alongside ``float``, so ``backbone_noise=0`` is treated
    the same as ``backbone_noise=0.0``. Otherwise a bare scalar would be left
    where a sequence of noise levels is expected.
    """
    if isinstance(self.backbone_noise, (int, float)):
      # object.__setattr__ keeps this working even if the dataclass is later frozen.
      object.__setattr__(self, "backbone_noise", (self.backbone_noise,))
# [docs]
@dataclass
class ScoringSpecification(RunSpecification):
  """Configuration for scoring sequences.

  Attributes:
      sequences_to_score: A sequence of amino acid sequences to score. A single
          non-empty string is wrapped into a one-element tuple.
      temperature: The temperature for scoring (default is 1.0).
      return_logits: Whether to return the raw logits (default is False).
      return_decoding_orders: Whether to return decoding orders (default is False).
      return_all_scores: Whether to return scores for all sequences (default is False).
      score_batch_size: The batch size for scoring sequences (default is 16).
      output_h5_path: Optional path to an HDF5 file for streaming output; a
          non-empty string is converted to a ``Path``.
  """
  sequences_to_score: Sequence[str] = ()
  temperature: float = 1.0
  return_logits: bool = False
  return_decoding_orders: bool = False
  return_all_scores: bool = False
  score_batch_size: int = 16
  output_h5_path: str | Path | None = None

  def __post_init__(self) -> None:
    """Validate ``sequences_to_score`` and normalize ``output_h5_path``.

    Raises:
        ValueError: If ``sequences_to_score`` is empty.

    """
    super().__post_init__()
    if not self.sequences_to_score:
      msg = (
        "No sequences provided for scoring. "
        "`sequences_to_score` must be a non-empty list of strings."
      )
      raise ValueError(msg)
    # A bare string is itself a Sequence[str]. Iterating it would yield one
    # entry per character, so wrap it as a single sequence. This runs after the
    # emptiness check so that "" is still rejected.
    if isinstance(self.sequences_to_score, str):
      object.__setattr__(self, "sequences_to_score", (self.sequences_to_score,))
    if self.output_h5_path and isinstance(self.output_h5_path, str):
      object.__setattr__(self, "output_h5_path", Path(self.output_h5_path))
# [docs]
@dataclass
class SamplingSpecification(RunSpecification):
  """Configuration for sampling sequences.

  Attributes:
      num_samples: Number of sequences to sample (default is 1).
      sampling_strategy: Either "temperature" or "straight_through"
          (default is "temperature").
      temperature: Sampling temperature (default is 0.1).
      bias: Optional array-like bias (default is None).
      fixed_positions: Optional array-like of fixed positions (default is None).
      iterations: Number of iterations; required when ``sampling_strategy`` is
          "straight_through" (default is None).
      learning_rate: Learning rate; required when ``sampling_strategy`` is
          "straight_through" (default is None).
      output_h5_path: Optional path to an HDF5 file for streaming output; a
          non-empty string is converted to a ``Path``.
  """
  num_samples: int = 1
  sampling_strategy: Literal["temperature", "straight_through"] = "temperature"
  temperature: float = 0.1
  bias: ArrayLike | None = None
  fixed_positions: ArrayLike | None = None
  iterations: int | None = None
  learning_rate: float | None = None
  output_h5_path: str | Path | None = None

  def __post_init__(self) -> None:
    """Check strategy-specific arguments and coerce ``output_h5_path`` to ``Path``.

    Raises:
        ValueError: If ``sampling_strategy`` is "straight_through" but
            ``iterations`` or ``learning_rate`` is missing.

    """
    super().__post_init__()
    if self.sampling_strategy == "straight_through":
      needs_params = self.iterations is None or self.learning_rate is None
      if needs_params:
        msg = "For 'straight_through' sampling, 'iterations' and 'learning_rate' must be provided."
        raise ValueError(msg)
    h5_path = self.output_h5_path
    if h5_path and isinstance(h5_path, str):
      object.__setattr__(self, "output_h5_path", Path(h5_path))
# [docs]
@dataclass
class JacobianSpecification(RunSpecification):
  """Configuration for computing categorical Jacobians.

  Attributes:
      noise_batch_size: Batch size over noise levels (default is 1).
      jacobian_batch_size: Batch size for Jacobian computation (default is 16).
      combine: Whether to run the combine step (default is False).
      combine_batch_size: Batch size for the combine step (default is 8).
      combine_noise_batch_size: Noise batch size for the combine step (default is 1).
      combine_weights: Optional array-like weights for the combine step (default is None).
      combine_fn: Optional ``CombineCatJacPairFn`` used by the combine step
          (default is None).
      combine_fn_kwargs: Optional keyword arguments for ``combine_fn`` (default is None).
      output_h5_path: Optional path to an HDF5 file for streaming output; a
          non-empty string is converted to a ``Path``.
      compute_apc: Whether to compute APC (default is True).
          NOTE(review): presumably the average product correction — confirm.
      apc_batch_size: Batch size for the APC step (default is 8).
      apc_residue_batch_size: Residue batch size for the APC step (default is 1000).
  """
  noise_batch_size: int = 1
  jacobian_batch_size: int = 16
  combine: bool = False
  combine_batch_size: int = 8
  combine_noise_batch_size: int = 1
  combine_weights: ArrayLike | None = None
  combine_fn: CombineCatJacPairFn | None = None
  combine_fn_kwargs: dict[str, Any] | None = None
  output_h5_path: str | Path | None = None
  compute_apc: bool = True
  apc_batch_size: int = 8
  apc_residue_batch_size: int = 1000

  def __post_init__(self) -> None:
    """Run base normalization, then coerce a string ``output_h5_path`` to ``Path``."""
    super().__post_init__()
    h5_path = self.output_h5_path
    if h5_path and isinstance(h5_path, str):
      object.__setattr__(self, "output_h5_path", Path(h5_path))
@dataclass
class ConformationalInferenceSpecification(RunSpecification):
  """Configuration for deriving states from a protein ensemble.

  Attributes:
      output_h5_path: Optional path to an HDF5 file for streaming output; a
          non-empty string is converted to a ``Path``.
      batch_size: The batch size for processing proteins (default is 8).
      inference_strategy: One of "unconditional", "conditional", "vmm" or
          "coordinates" (default is "unconditional").
      inference_features: Features to use; any of "logits", "node_features",
          "edge_features", "backbone_coordinates", "full_coordinates"
          (default is ("logits",)).
      mode: Either "global" or "per" (default is "global").
      gmm_n_components: Number of GMM components (default is 100).
      eps_std_scale: Scale factor applied to an eps/std value (default is 1.0).
          NOTE(review): exact role defined by the consumer — confirm.
      min_cluster_weight: Minimum cluster weight (default is 0.01).
      gmm_init: GMM initialization, "kmeans" or "random" (default is "kmeans").
      reference_sequence: Optional reference sequence (default is None).
  """
  output_h5_path: str | Path | None = None
  batch_size: int = 8
  inference_strategy: Literal["unconditional", "conditional", "vmm", "coordinates"] = (
    "unconditional"
  )
  inference_features: Sequence[
    Literal["logits", "node_features", "edge_features", "backbone_coordinates", "full_coordinates"]
  ] = ("logits",)
  mode: Literal["global", "per"] = "global"
  gmm_n_components: int = 100
  eps_std_scale: float = 1.0
  min_cluster_weight: float = 0.01
  gmm_init: Literal["kmeans", "random"] = "kmeans"
  reference_sequence: str | None = None

  def __post_init__(self) -> None:
    """Run base normalization, then coerce a string ``output_h5_path`` to ``Path``."""
    super().__post_init__()
    h5_path = self.output_h5_path
    if h5_path and isinstance(h5_path, str):
      object.__setattr__(self, "output_h5_path", Path(h5_path))