# Source code for prxteinmpnn.run.scoring

"""Core user interface for the PrxteinMPNN package."""

from __future__ import annotations

import logging
import sys
from typing import Any

import h5py
import jax
import jax.numpy as jnp

from prxteinmpnn.scoring.score import make_score_sequence
from prxteinmpnn.utils.aa_convert import string_to_protein_sequence

from .prep import prep_protein_stream_and_model
from .specs import ScoringSpecification

# NOTE(review): ``force=True`` makes basicConfig remove any handlers already
# attached to the root logger, so importing this module reconfigures logging
# for the whole process (INFO level, output to stdout). Confirm that this
# import-time side effect is intended for a library module.
logging.basicConfig(stream=sys.stdout, level=logging.INFO, force=True)

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)


def score(
    spec: ScoringSpecification | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Score all provided sequences against all input structures.

    Structures are streamed in batches from ``prep_protein_stream_and_model``.
    Every sequence in ``spec.sequences_to_score`` is scored against every
    structure at every backbone-noise level.

    Args:
        spec: An optional ScoringSpecification object. If None, one is built
            from ``kwargs``. Options that can be set include:
            inputs: A single or sequence of inputs (files, PDB IDs, etc.).
            sequences_to_score: A non-empty list of protein sequences
                (strings) to score. All sequences must have the same length.
            chain_id: Specific chain(s) to parse from the structure.
            model: The model number to load. If None, all models are loaded.
            altloc: The alternate location identifier to use.
            model_version: The model version to use.
            model_weights: The model weights to use.
            foldcomp_database: The FoldComp database to use for FoldComp IDs.
            random_seed: Integer seed passed to ``jax.random.key``. The same
                key is reused for every batch.
            backbone_noise: Noise level(s) to add to the backbone. One result
                is produced per level; a scalar is treated as a single level.
            ar_mask: An optional array of shape (L, L) to mask out certain
                residue pairs. If None, each batch uses a mask that is 1
                everywhere except the diagonal.
            batch_size: The number of structures to process in a single batch.
            num_workers: Number of parallel workers for data loading.
            output_h5_path: If set, results are written to this HDF5 file by
                ``_score_streaming`` instead of being returned in memory.
        **kwargs: Forwarded to ``ScoringSpecification`` when ``spec`` is None.
            They are ignored otherwise.

    Returns:
        In-memory mode: a dict with keys ``"scores"``, ``"logits"`` and
        ``"metadata"``. Scores and logits have leading axes
        (structures, noise levels, sequences). Structures from all batches
        are concatenated along axis 0. ``"metadata"`` is
        ``{"specification": spec}``. If the stream yields no batches, all
        three values are None.
        Streaming mode (``output_h5_path`` set): the dict returned by
        ``_score_streaming``.

    Raises:
        ValueError: If ``sequences_to_score`` is empty, or its sequences do
            not all have the same length.

    """
    if spec is None:
        spec = ScoringSpecification(**kwargs)

    # Validate before dispatching, so the streaming and in-memory paths reject
    # an empty request with the same error.
    if not spec.sequences_to_score:
        msg = (
            "No sequences provided for scoring. "
            "`sequences_to_score` must be a non-empty list of strings."
        )
        raise ValueError(msg)

    if spec.output_h5_path:
        return _score_streaming(spec)

    # Promote each encoded sequence to a single row, so the batch is always
    # (num_sequences, L). This holds whether string_to_protein_sequence returns
    # a 1-D array or a single-row 2-D one (its return shape is not visible
    # here). A plain concatenate of 1-D rows would merge every sequence into
    # one long row, and the sequence vmap below would then map over residues.
    encoded = [
        jnp.atleast_2d(string_to_protein_sequence(s)) for s in spec.sequences_to_score
    ]
    lengths = sorted({row.shape[-1] for row in encoded})
    if len(lengths) > 1:
        msg = (
            "All sequences in `sequences_to_score` must have the same length; "
            f"got lengths {lengths}."
        )
        raise ValueError(msg)
    batched_sequences = jnp.concatenate(encoded, axis=0)

    protein_iterator, model_parameters = prep_protein_stream_and_model(spec)
    score_single_pair = make_score_sequence(model_parameters=model_parameters)

    # Positional arguments of score_single_pair, as passed below:
    # (key, sequence, coordinates, atom_mask, residue_index, chain_index,
    #  48, backbone_noise, ar_mask).
    # Innermost map: over sequences (arg 1).
    vmap_sequences = jax.vmap(
        score_single_pair,
        in_axes=(None, 0, None, None, None, None, None, None, None),
        out_axes=0,
    )
    # Middle map: over backbone-noise levels (arg 7).
    vmap_noises = jax.vmap(
        vmap_sequences,
        in_axes=(None, None, None, None, None, None, None, 0, None),
        out_axes=0,
    )
    # Outermost map: over structures in the batch (args 2-5).
    score_batch = jax.vmap(
        vmap_noises,
        in_axes=(None, None, 0, 0, 0, 0, None, None, None),
        out_axes=0,
    )

    # These inputs do not change between batches, so build them once.
    prng_key = jax.random.key(spec.random_seed)
    # Noise is vmapped over axis 0, so it needs at least rank 1.
    # atleast_1d lets a scalar act as a single noise level.
    noise_levels = jnp.atleast_1d(jnp.asarray(spec.backbone_noise, dtype=jnp.float32))
    user_ar_mask = None if spec.ar_mask is None else jnp.asarray(spec.ar_mask)

    all_scores, all_logits = [], []
    for batched_ensemble in protein_iterator:
        max_len = batched_ensemble.coordinates.shape[1]
        # Default mask: ones everywhere except the diagonal, sized to this batch.
        current_ar_mask = (
            1 - jnp.eye(max_len, dtype=jnp.bool_) if user_ar_mask is None else user_ar_mask
        )
        scores, logits, _ = score_batch(
            prng_key,
            batched_sequences,
            batched_ensemble.coordinates,
            batched_ensemble.atom_mask,
            batched_ensemble.residue_index,
            batched_ensemble.chain_index,
            # NOTE(review): unexplained constant; presumably the neighbor count
            # expected by make_score_sequence. Confirm before changing.
            48,
            noise_levels,
            current_ar_mask,
        )
        all_scores.append(scores)
        all_logits.append(logits)

    if not all_scores:
        return {"scores": None, "logits": None, "metadata": None}

    # NOTE(review): the concatenation assumes every batch yields the same
    # trailing shapes. If batches are padded to different max_len, logits may
    # not concatenate. Confirm how the stream pads batches.
    return {
        "scores": jnp.concatenate(all_scores, axis=0),
        "logits": jnp.concatenate(all_logits, axis=0),
        "metadata": {
            "specification": spec,
        },
    }
def _score_streaming(
    spec: ScoringSpecification,
) -> dict[str, str | dict[str, ScoringSpecification]]:
    """Score sequences against structures and stream results to an HDF5 file.

    The file at ``spec.output_h5_path`` is opened in ``"w"`` mode, so any
    existing file is truncated. It receives two float32 datasets:

    * ``"scores"``: 1-D. Each batch's scores are flattened in C order and
      appended.
    * ``"logits"``: same rank as a batch's logits array. Batches are appended
      along axis 0. The remaining axes grow to the largest size seen so far,
      and cells a batch does not cover keep the NaN fill value. If the stream
      yields no batches, an empty placeholder of shape (0, 0, 0) is written.

    Args:
        spec: Scoring options. ``output_h5_path`` and a non-empty
            ``sequences_to_score`` are required. See ``score`` for the
            available options.

    Returns:
        A dict with ``"output_h5_path"`` (the path as a string) and
        ``"metadata"`` (``{"specification": spec}``).

    Raises:
        ValueError: If ``output_h5_path`` is not set, ``sequences_to_score``
            is empty, or the sequences do not all have the same length.

    """
    if not spec.output_h5_path:
        msg = "output_h5_path must be provided for streaming."
        raise ValueError(msg)
    if not spec.sequences_to_score:
        msg = (
            "No sequences provided for scoring. "
            "`sequences_to_score` must be a non-empty list of strings."
        )
        raise ValueError(msg)

    # Promote each encoded sequence to a single row, so the batch is always
    # (num_sequences, L). This holds whether string_to_protein_sequence returns
    # a 1-D array or a single-row 2-D one (its return shape is not visible
    # here). A plain concatenate of 1-D rows would merge every sequence into
    # one long row, and the sequence vmap below would then map over residues.
    encoded = [
        jnp.atleast_2d(string_to_protein_sequence(s)) for s in spec.sequences_to_score
    ]
    lengths = sorted({row.shape[-1] for row in encoded})
    if len(lengths) > 1:
        msg = (
            "All sequences in `sequences_to_score` must have the same length; "
            f"got lengths {lengths}."
        )
        raise ValueError(msg)
    batched_sequences = jnp.concatenate(encoded, axis=0)

    protein_iterator, model_parameters = prep_protein_stream_and_model(spec)
    score_single_pair = make_score_sequence(model_parameters=model_parameters)

    # Positional arguments of score_single_pair, as passed below:
    # (key, sequence, coordinates, atom_mask, residue_index, chain_index,
    #  48, backbone_noise, ar_mask).
    # Innermost map: over sequences (arg 1).
    vmap_sequences = jax.vmap(
        score_single_pair,
        in_axes=(None, 0, None, None, None, None, None, None, None),
        out_axes=0,
    )
    # Middle map: over backbone-noise levels (arg 7).
    vmap_noises = jax.vmap(
        vmap_sequences,
        in_axes=(None, None, None, None, None, None, None, 0, None),
        out_axes=0,
    )
    # Outermost map: over structures in the batch (args 2-5).
    score_batch = jax.vmap(
        vmap_noises,
        in_axes=(None, None, 0, 0, 0, 0, None, None, None),
        out_axes=0,
    )

    # These inputs do not change between batches, so build them once.
    prng_key = jax.random.key(spec.random_seed)
    # Noise is vmapped over axis 0, so it needs at least rank 1.
    # atleast_1d lets a scalar act as a single noise level.
    noise_levels = jnp.atleast_1d(jnp.asarray(spec.backbone_noise, dtype=jnp.float32))
    user_ar_mask = None if spec.ar_mask is None else jnp.asarray(spec.ar_mask)

    with h5py.File(spec.output_h5_path, "w") as f:
        scores_ds = f.create_dataset(
            "scores",
            (0,),
            maxshape=(None,),
            chunks=True,
            dtype="f4",
        )
        # Created on the first batch, once the logits rank is known.
        logits_ds = None

        for batched_ensemble in protein_iterator:
            max_len = batched_ensemble.coordinates.shape[1]
            # Default mask: ones everywhere except the diagonal, sized to this batch.
            current_ar_mask = (
                1 - jnp.eye(max_len, dtype=jnp.bool_) if user_ar_mask is None else user_ar_mask
            )
            scores, logits, _ = score_batch(
                prng_key,
                batched_sequences,
                batched_ensemble.coordinates,
                batched_ensemble.atom_mask,
                batched_ensemble.residue_index,
                batched_ensemble.chain_index,
                # NOTE(review): unexplained constant; presumably the neighbor
                # count expected by make_score_sequence. Confirm before changing.
                48,
                noise_levels,
                current_ar_mask,
            )

            # Append flattened scores. Use an explicit start index rather than
            # ``ds[-n:]``, which selects the whole dataset when n == 0.
            flat_scores = jax.device_get(scores.astype(jnp.float32)).reshape(-1)
            score_start = scores_ds.shape[0]
            scores_ds.resize(score_start + flat_scores.size, axis=0)
            scores_ds[score_start:] = flat_scores

            host_logits = jax.device_get(logits.astype(jnp.float32))
            if logits_ds is None:
                logits_ds = f.create_dataset(
                    "logits",
                    shape=(0, *host_logits.shape[1:]),
                    maxshape=(None,) * host_logits.ndim,
                    chunks=True,
                    dtype="f4",
                    fillvalue=float("nan"),
                )
            # Grow axis 0 by this batch, and every trailing axis to fit the
            # largest batch seen so far (padded lengths may differ per batch).
            logit_start = logits_ds.shape[0]
            logit_stop = logit_start + host_logits.shape[0]
            logits_ds.resize(
                (
                    logit_stop,
                    *(
                        max(old, new)
                        for old, new in zip(logits_ds.shape[1:], host_logits.shape[1:])
                    ),
                ),
            )
            region = (
                slice(logit_start, logit_stop),
                *(slice(0, size) for size in host_logits.shape[1:]),
            )
            logits_ds[region] = host_logits
            f.flush()

        if logits_ds is None:
            # No batches were produced; keep the file layout stable.
            f.create_dataset(
                "logits",
                (0, 0, 0),
                maxshape=(None, None, None),
                chunks=True,
                dtype="f4",
            )

    return {
        "output_h5_path": str(spec.output_h5_path),
        "metadata": {
            "specification": spec,
        },
    }