# Source code for prxteinmpnn.mpnn

"""Core module for the PrxteinMPNN model.

prxteinmpnn.mpnn
"""

import pathlib
from typing import Literal

import jax
import jax.numpy as jnp
import joblib
from jaxtyping import PyTree

# Checkpoint file names accepted by ``get_mpnn_model``. Each one is a file
# under ``model/<weights>/``. The numeric part presumably encodes the
# training noise level of the checkpoint -- TODO confirm.
ModelVersion = Literal[
  "v_48_002.pkl", "v_48_010.pkl", "v_48_020.pkl", "v_48_030.pkl",
]

# Weight-set sub-directories under ``model/`` that can be loaded.
ModelWeights = Literal[
  "original", "soluble",
]


def get_mpnn_model(
  model_version: ModelVersion = "v_48_020.pkl",
  model_weights: ModelWeights = "original",
) -> PyTree:
  """Create a ProteinMPNN model with specified configuration and weights.

  The checkpoint is read from ``model/<model_weights>/<model_version>``,
  relative to the directory containing this module. The ``Literal`` types on
  the arguments are not enforced at runtime; any string is joined into the
  path as-is.

  Args:
    model_version: The model configuration (checkpoint file name) to use.
    model_weights: The weight set (sub-directory of ``model/``) to load from.

  Returns:
    A PyTree with the structure of the checkpoint's ``"model_state_dict"``
    entry, with every leaf converted by ``jax.numpy.array``.

  Raises:
    FileNotFoundError: If the model file does not exist or is not a regular
      file.
    KeyError: If the loaded checkpoint has no ``"model_state_dict"`` entry
      (assuming the checkpoint is a mapping).

  Example:
    >>> params = get_mpnn_model()

  """
  base_dir = pathlib.Path(__file__).parent
  model_path = base_dir / "model" / model_weights / model_version
  # Use is_file() rather than exists(). A directory at this path would
  # otherwise pass the check and then fail inside joblib with
  # IsADirectoryError instead of the documented FileNotFoundError.
  if not model_path.is_file():
    msg = f"Model file not found: {model_path}"
    raise FileNotFoundError(msg)
  # SECURITY: joblib.load unpickles the file, which can execute arbitrary
  # code. Only load checkpoints bundled with the package, never a
  # user-controlled path.
  checkpoint_state = joblib.load(model_path)
  state_dict = checkpoint_state["model_state_dict"]
  # Convert every leaf of the nested state dict into a JAX array.
  return jax.tree_util.tree_map(jnp.array, state_dict)