"""Core user interface for the PrxteinMPNN package."""
from __future__ import annotations
import logging
import multiprocessing as mp
import sys
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from prxteinmpnn.io.weights import MODEL_VERSION, MODEL_WEIGHTS
mp.set_start_method("spawn", force=True)
if TYPE_CHECKING:
from io import StringIO
from jaxtyping import ArrayLike
from prxteinmpnn.ensemble.dbscan import ConformationalStates
from prxteinmpnn.utils.catjac import CombineCatJacPairFn
from prxteinmpnn.utils.decoding_order import DecodingOrderFn
from prxteinmpnn.utils.foldcomp_utils import FoldCompDatabase
# Type aliases for convenience
ModelWeights = MODEL_WEIGHTS
ModelVersion = MODEL_VERSION
AlignmentStrategy = Literal["sequence", "structure"]
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, stream=sys.stdout, force=True)
def _loader_inputs(inputs: Sequence[str | StringIO] | str | StringIO) -> Sequence[str | StringIO]:
    """Normalize a single input (path or StringIO) to a one-element tuple."""
    return (inputs,) if isinstance(inputs, str) or not isinstance(inputs, Sequence) else inputs
@dataclass
class RunSpecification:
"""Configuration for running the model.
Attributes:
inputs: A sequence of input file paths or StringIO objects, or a single input.
model_weights: The model weights to use (default is "original").
model_version: The model version to use (default is "v_48_020.pkl").
batch_size: The batch size to use (default is 32).
backbone_noise: The backbone noise levels to use (default is (0.0,)).
Can be a single float or a sequence of floats.
foldcomp_database: An optional path to a FoldComp database (default is None).
ar_mask: An optional array-like mask for autoregressive positions (default is None).
random_seed: The random seed to use (default is 42).
chain_id: An optional chain ID to use (default is None).
model: An optional model ID to use (default is None).
altloc: The alternate location to use (default is "first").
decoding_order_fn: An optional function to generate the decoding order (default is None).
conformational_states: ConformationalStates to use for coarse graining the inference.
"""
inputs: Sequence[str | StringIO] | str | StringIO
topology: str | Path | None = None
model_weights: ModelWeights = "original"
model_version: ModelVersion = "v_48_020"
batch_size: int = 32
backbone_noise: Sequence[float] | float = (0.0,)
foldcomp_database: FoldCompDatabase | None = None
ar_mask: None | ArrayLike = None
random_seed: int = 42
chain_id: Sequence[str] | str | None = None
model: int | None = None
altloc: Literal["first", "all"] = "first"
decoding_order_fn: DecodingOrderFn | None = None
conformational_states: ConformationalStates | None = None
cache_path: str | Path | None = None
overwrite_cache: bool = False
output_path: str | Path | None = None
# Tied-position logit averaging fields
tied_positions: Sequence[tuple[int, int]] | Literal["auto", "direct"] | None = None
pass_mode: Literal["inter", "intra"] = "intra" # noqa: S105 (not a password)
def __post_init__(self) -> None:
"""Post-initialization processing and validation for tied-position logit averaging."""
if isinstance(self.backbone_noise, float):
object.__setattr__(self, "backbone_noise", (self.backbone_noise,))
if self.cache_path and isinstance(self.cache_path, str):
object.__setattr__(self, "cache_path", Path(self.cache_path))
# Validation for tied_positions and pass_mode
if self.tied_positions in ("auto", "direct") and self.pass_mode != "inter": # noqa: S105
msg = (
f"If tied_positions is '{self.tied_positions}', pass_mode must be 'inter'. "
f"Got pass_mode='{self.pass_mode}'."
)
raise ValueError(msg)
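# Illustrative usage sketch (the file path and chain ID below are hypothetical):
# a bare float for `backbone_noise` is normalized to a one-element tuple in
# `__post_init__`, and "auto"/"direct" tied positions require pass_mode="inter".
def _example_run_spec() -> RunSpecification:
    """Build a minimal RunSpecification for a single hypothetical structure file."""
    return RunSpecification(
        inputs="example.pdb",  # hypothetical local structure file
        backbone_noise=0.1,  # normalized to (0.1,) in __post_init__
        chain_id="A",
        random_seed=0,
    )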
@dataclass
class ScoringSpecification(RunSpecification):
"""Configuration for scoring sequences.
Attributes:
sequences_to_score: A sequence of amino acid sequences to score.
temperature: The temperature for scoring (default is 1.0).
return_logits: Whether to return the raw logits (default is False).
return_decoding_orders: Whether to return decoding orders (default is False).
return_all_scores: Whether to return scores for all sequences (default is False).
score_batch_size: The batch size for scoring sequences (default is 16).
output_h5_path: Optional path to an HDF5 file for streaming output.
"""
sequences_to_score: Sequence[str] = ()
temperature: float = 1.0
return_logits: bool = False
return_decoding_orders: bool = False
return_all_scores: bool = False
score_batch_size: int = 16
output_h5_path: str | Path | None = None
def __post_init__(self) -> None:
"""Post-initialization processing."""
super().__post_init__()
if not self.sequences_to_score:
            msg = (
                "No sequences provided for scoring. "
                "`sequences_to_score` must be a non-empty list of strings."
            )
raise ValueError(msg)
if self.output_h5_path and isinstance(self.output_h5_path, str):
object.__setattr__(self, "output_h5_path", Path(self.output_h5_path))
@dataclass
class SamplingSpecification(RunSpecification):
"""Configuration for sampling sequences."""
num_samples: int = 1
sampling_strategy: Literal["temperature", "straight_through"] = "temperature"
temperature: float = 0.1
bias: ArrayLike | None = None
fixed_positions: ArrayLike | None = None
iterations: int | None = None
learning_rate: float | None = None
output_h5_path: str | Path | None = None
samples_batch_size: int = 16
noise_batch_size: int = 4
average_encodings: bool = False
average_logits: None | Literal["structures", "noise", "both"] = None
multi_state_strategy: Literal["mean", "min", "product", "max_min"] = "mean"
multi_state_alpha: float = 0.5
def __post_init__(self) -> None:
"""Post-initialization processing."""
super().__post_init__()
if self.sampling_strategy == "straight_through" and (
self.iterations is None or self.learning_rate is None
):
msg = "For 'straight_through' sampling, 'iterations' and 'learning_rate' must be provided."
raise ValueError(msg)
if self.output_h5_path and isinstance(self.output_h5_path, str):
object.__setattr__(self, "output_h5_path", Path(self.output_h5_path))
@dataclass
class JacobianSpecification(RunSpecification):
"""Configuration for computing categorical Jacobians."""
noise_batch_size: int = 1
jacobian_batch_size: int = 16
average_encodings: bool = True
average_logits: None | Literal["structures", "noise", "both"] = None
combine: bool = False
combine_batch_size: int = 8
combine_noise_batch_size: int = 1
combine_weights: ArrayLike | None = None
combine_fn: CombineCatJacPairFn | None = None
combine_fn_kwargs: dict[str, Any] | None = None
output_h5_path: str | Path | None = None
compute_apc: bool = True
apc_batch_size: int = 8
apc_residue_batch_size: int = 1000
def __post_init__(self) -> None:
"""Post-initialization processing."""
super().__post_init__()
if self.output_h5_path and isinstance(self.output_h5_path, str):
object.__setattr__(self, "output_h5_path", Path(self.output_h5_path))
@dataclass
class ConformationalInferenceSpecification(RunSpecification):
"""Configuration for deriving states from a protein ensemble.
Attributes:
output_h5_path: Optional path to an HDF5 file for streaming output.
batch_size: The batch size for processing proteins (default is 8).
"""
output_h5_path: str | Path | None = None
batch_size: int = 8
inference_strategy: Literal["unconditional", "conditional", "vmm"] = "unconditional"
inference_features: Sequence[Literal["logits", "node_features", "edge_features"]] = ("logits",)
mode: Literal["global", "per"] = "global"
covariance_type: Literal["full", "diag"] = "diag"
gmm_n_components: int = 100
eps_std_scale: float = 1.0
min_cluster_weight: float = 0.01
preprocessing_mode: Literal["pca"] | None = None
gmm_init: Literal["kmeans", "random"] = "kmeans"
gmm_max_iters: int = 100
kmeans_max_iters: int = 200
pca_n_components: int = 20
pca_solver: Literal["full", "randomized"] = "full"
pca_rng_seed: int = 0
gmm_min_iters: int = 10
covariance_regularization: float = 1e-3
reference_sequence: str | None = None
def __post_init__(self) -> None:
"""Post-initialization processing."""
super().__post_init__()
if self.output_h5_path and isinstance(self.output_h5_path, str):
object.__setattr__(self, "output_h5_path", Path(self.output_h5_path))
MIN_PAIR = 2  # Minimum number of inputs required for cross-input similarity checks.
@dataclass
class InspectionSpecification(RunSpecification):
"""Configuration for inspecting model encodings and features."""
output_h5_path: str | Path | None = None
inspection_features: Sequence[
Literal[
"unconditional_logits",
"encoded_node_features",
"edge_features",
"decoded_node_features",
"conditional_logits",
]
] = ("unconditional_logits",)
distance_matrix: bool = False
distance_matrix_method: Literal["ca", "cb", "backbone_average", "closest_atom"] = "ca"
cross_input_similarity: bool = False
similarity_metric: Literal[
"rmsd",
"tm-score",
"gdt_ts",
"gdt_ha",
"cosine",
] = "rmsd"
def __post_init__(self) -> None:
"""Post-initialization processing."""
super().__post_init__()
if self.output_h5_path and isinstance(self.output_h5_path, str):
object.__setattr__(self, "output_h5_path", Path(self.output_h5_path))
if self.cross_input_similarity and len(_loader_inputs(self.inputs)) < MIN_PAIR:
msg = f"Cross-input similarity requires at least {MIN_PAIR} input structures."
raise ValueError(msg)
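# Illustrative usage sketch (both structure paths are hypothetical):
# cross-input similarity requires at least MIN_PAIR inputs, which
# __post_init__ enforces.
def _example_inspection_spec() -> InspectionSpecification:
    """Build an InspectionSpecification comparing two hypothetical structures."""
    return InspectionSpecification(
        inputs=("structure_a.pdb", "structure_b.pdb"),  # hypothetical paths
        inspection_features=("unconditional_logits", "edge_features"),
        cross_input_similarity=True,
        similarity_metric="tm-score",
    )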
Specs = (
    RunSpecification
    | ScoringSpecification
    | SamplingSpecification
    | JacobianSpecification
    | ConformationalInferenceSpecification
    | InspectionSpecification
)
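# Illustrative usage sketch: `Specs` can annotate code that accepts any
# specification variant (the helper below is hypothetical, not part of the API).
def _describe_spec(spec: Specs) -> str:
    """Return the class name of a specification instance."""
    return type(spec).__name__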