Source code for prxteinmpnn.mpnn
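"""Core module for the PrxteinMPNN model (``prxteinmpnn.mpnn``).

Defines the available ProteinMPNN model configurations and weight sets, and
a loader that returns the pretrained parameters as a JAX PyTree.
"""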
import enum
import pathlib
import jax
import jax.numpy as jnp
import joblib
from jaxtyping import PyTree
[docs]
@enum.unique
class ProteinMPNNModelVersion(enum.Enum):
    """Selectable ProteinMPNN model configurations.

    Each member's value is the checkpoint file name that ``get_mpnn_model``
    looks up inside the chosen weights directory.
    """

    # NOTE(review): the numeric parts of the names presumably encode training
    # hyper-parameters (e.g. neighbour count / noise level) — confirm upstream.
    V_48_002 = "v_48_002.pkl"
    V_48_010 = "v_48_010.pkl"
    V_48_020 = "v_48_020.pkl"
    V_48_030 = "v_48_030.pkl"
[docs]
@enum.unique
class ModelWeights(enum.Enum):
    """Selectable sets of pretrained ProteinMPNN weights.

    Each member's value names the subdirectory of ``model/`` from which
    ``get_mpnn_model`` reads checkpoint files.
    """

    DEFAULT = "original"  # directory "original"
    SOLUBLE = "soluble"  # directory "soluble"
[docs]
def get_mpnn_model(
    model_version: ProteinMPNNModelVersion = ProteinMPNNModelVersion.V_48_020,
    model_weights: ModelWeights = ModelWeights.DEFAULT,
) -> PyTree:
    """Load pretrained ProteinMPNN parameters as a JAX PyTree.

    The checkpoint is read from
    ``<package dir>/model/<model_weights.value>/<model_version.value>``.

    Args:
        model_version: The model configuration to use. The member's string
            value (e.g. ``"v_48_020.pkl"``) is also accepted.
        model_weights: The set of weights to load. The member's string value
            (e.g. ``"soluble"``) is also accepted.

    Returns:
        A PyTree mirroring the checkpoint's ``"model_state_dict"`` entry, with
        every leaf converted via ``jnp.array``.

    Raises:
        ValueError: If an argument does not match any member of its enum.
        FileNotFoundError: If the model file does not exist.
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.

    Example:
        >>> params = get_mpnn_model()

    """
    # Normalise to enum members: a member is returned unchanged, a matching
    # value string is looked up, and anything else raises ValueError. This
    # also keeps the path below confined to the bundled checkpoint files.
    model_version = ProteinMPNNModelVersion(model_version)
    model_weights = ModelWeights(model_weights)

    base_dir = pathlib.Path(__file__).parent
    model_path = base_dir / "model" / model_weights.value / model_version.value
    # is_file() rather than exists(): a directory at this path should be
    # reported as a missing model, not fail later inside joblib.load.
    if not model_path.is_file():
        msg = f"Model file not found: {model_path}"
        raise FileNotFoundError(msg)

    # joblib.load unpickles the file, which can execute arbitrary code. That is
    # acceptable only because the path is restricted to the package's own
    # checkpoints (see the enum coercion above); never load untrusted files.
    checkpoint_state = joblib.load(model_path)
    try:
        model_state = checkpoint_state["model_state_dict"]
    except KeyError:
        msg = f"Checkpoint {model_path} has no 'model_state_dict' entry."
        raise KeyError(msg) from None
    return jax.tree_util.tree_map(jnp.array, model_state)