Source code for prxteinmpnn.run.sampling

"""Core user interface for the PrxteinMPNN package."""

from __future__ import annotations

import logging
import sys
from typing import Any

import h5py
import jax
import jax.numpy as jnp

from prxteinmpnn.sampling.sample import make_sample_sequences
from prxteinmpnn.utils.decoding_order import random_decoding_order
from prxteinmpnn.utils.residue_constants import atom_order

from .prep import prep_protein_stream_and_model
from .specs import SamplingSpecification

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, stream=sys.stdout, force=True)
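# Both samplers below vectorize a single-structure sampling function over three
# batch axes by nesting jax.vmap: the innermost map runs over PRNG keys (one per
# sample), the middle map over backbone-noise levels, and the outermost map over
# structures in the batch. A minimal toy sketch of the same pattern (the
# function `f` and its arguments are hypothetical, not part of this module):
#
#   def f(key, x, noise):
#       return x + noise * jax.random.normal(key)
#
#   over_keys = jax.vmap(f, in_axes=(0, None, None))           # -> (samples,)
#   over_noise = jax.vmap(over_keys, in_axes=(None, None, 0))  # -> (noises, samples)
#   over_structs = jax.vmap(over_noise, in_axes=(None, 0, None))
#   # over_structs(keys, xs, noises) has shape (structures, noises, samples),
#   # because each outer vmap prepends its mapped axis to the output.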


def sample(
  spec: SamplingSpecification | None = None,
  **kwargs: Any,
) -> dict[str, Any]:
  """Sample new sequences for the given input structures.

  This function uses a high-performance Grain pipeline to load and process
  structures, then samples new sequences for each structure.

  Args:
    spec: An optional SamplingSpecification object. If None, a default is
      created from the keyword arguments. The following options can be set:
        inputs: A single input or sequence of inputs (files, PDB IDs, etc.).
        chain_id: Specific chain(s) to parse from the structure.
        model: The model number to load. If None, all models are loaded.
        altloc: The alternate location identifier to use.
        model_version: The model version to use.
        model_weights: The model weights to use.
        foldcomp_database: The FoldComp database to use for FoldComp IDs.
        random_seed: The random number generator seed.
        backbone_noise: The amount of noise to add to the backbone.
        num_samples: The number of sequences to sample per structure/noise level.
        sampling_strategy: The sampling strategy to use.
        temperature: The sampling temperature.
        bias: An optional array to bias the logits.
        fixed_positions: An optional array of residue indices to keep fixed.
        iterations: Number of optimization iterations for "straight_through" sampling.
        learning_rate: Learning rate for "straight_through" sampling.
        batch_size: The number of structures to process in a single batch.
        num_workers: Number of parallel workers for data loading.
    **kwargs: Additional keyword arguments for structure loading.

  Returns:
    A dictionary containing sampled sequences, logits, and metadata.

  """
  if spec is None:
    spec = SamplingSpecification(**kwargs)

  if spec.output_h5_path:
    return _sample_streaming(spec)

  protein_iterator, model_parameters = prep_protein_stream_and_model(spec)

  sampler_fn = make_sample_sequences(
    model_parameters=model_parameters,
    decoding_order_fn=random_decoding_order,
    sampling_strategy=spec.sampling_strategy,
  )

  all_sequences, all_logits = [], []
  for batched_ensemble in protein_iterator:
    # A residue is considered present if its C-alpha atom is resolved.
    residue_mask = batched_ensemble.atom_mask[:, :, atom_order["CA"]]
    keys = jax.random.split(jax.random.key(spec.random_seed), spec.num_samples)

    # Nest vmap three deep: over PRNG keys (samples), then backbone-noise
    # levels, then structures in the batch.
    vmap_samples = jax.vmap(
      sampler_fn,
      in_axes=(0, None, None, None, None, None, None, None, None, None, None, None),
      out_axes=0,
    )
    vmap_noises = jax.vmap(
      vmap_samples,
      in_axes=(None, None, None, None, None, None, None, None, 0, None, None, None),
      out_axes=0,
    )
    vmap_structures = jax.vmap(
      vmap_noises,
      in_axes=(
        None,  # keys
        0,  # coordinates
        0,  # residue_mask
        0,  # residue_index
        0,  # chain_index
        None,  # k_neighbors
        None,  # bias
        None,  # fixed_positions
        None,  # backbone_noise
        None,  # iterations
        None,  # learning_rate
        None,  # temperature
      ),
      out_axes=0,
    )

    sampled_sequences, logits, _ = vmap_structures(
      keys,
      batched_ensemble.coordinates,
      residue_mask,
      batched_ensemble.residue_index,
      batched_ensemble.chain_index,
      48,  # k_neighbors
      jnp.asarray(spec.bias, dtype=jnp.float32) if spec.bias is not None else None,
      jnp.asarray(spec.fixed_positions, dtype=jnp.int32)
      if spec.fixed_positions is not None
      else None,
      jnp.asarray(spec.backbone_noise, dtype=jnp.float32)
      if spec.backbone_noise is not None
      else None,
      spec.iterations,
      spec.learning_rate,
      spec.temperature,
    )
    all_sequences.append(sampled_sequences)
    all_logits.append(logits)

  if not all_sequences:
    return {"sequences": None, "logits": None, "metadata": None}

  return {
    "sequences": jnp.concatenate(all_sequences, axis=0),
    "logits": jnp.concatenate(all_logits, axis=0),
    "metadata": {
      "specification": spec,
    },
  }
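# Example usage (an illustrative sketch, not part of the module): sampling a
# handful of sequences for a single structure via keyword arguments. The file
# name "1ubq.pdb" and the option values are hypothetical placeholders.
#
#   results = sample(
#       inputs="1ubq.pdb",
#       num_samples=8,
#       temperature=0.1,
#       random_seed=0,
#   )
#   # Axes follow the nested vmap order:
#   # (structures, noise levels, samples, sequence length)
#   sequences = results["sequences"]
#   logits = results["logits"]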
def _sample_streaming(
  spec: SamplingSpecification,
) -> dict[str, str | dict[str, SamplingSpecification]]:
  """Sample new sequences and stream results to an HDF5 file."""
  if not spec.output_h5_path:
    msg = "output_h5_path must be provided for streaming."
    raise ValueError(msg)

  protein_iterator, model_parameters = prep_protein_stream_and_model(spec)
  sampler_fn = make_sample_sequences(
    model_parameters=model_parameters,
    decoding_order_fn=random_decoding_order,
    sampling_strategy=spec.sampling_strategy,
  )

  with h5py.File(spec.output_h5_path, "w") as f:
    # Create empty, resizable datasets; batches are appended along axis 0.
    seq_ds = f.create_dataset(
      "sequences",
      (0, 0, 0),
      maxshape=(None, None, None),
      chunks=True,
      dtype="i4",
    )
    logits_ds = f.create_dataset(
      "logits",
      (0, 0, 0, 0),
      maxshape=(None, None, None, None),
      chunks=True,
      dtype="f4",
    )

    for batched_ensemble in protein_iterator:
      residue_mask = batched_ensemble.atom_mask[:, :, atom_order["CA"]]
      keys = jax.random.split(jax.random.key(spec.random_seed), spec.num_samples)

      # Same nested vmap as in sample(): keys, then noise levels, then structures.
      vmap_samples = jax.vmap(
        sampler_fn,
        in_axes=(0, None, None, None, None, None, None, None, None, None, None, None),
        out_axes=0,
      )
      vmap_noises = jax.vmap(
        vmap_samples,
        in_axes=(None, None, None, None, None, None, None, None, 0, None, None, None),
        out_axes=0,
      )
      vmap_structures = jax.vmap(
        vmap_noises,
        in_axes=(
          None,  # keys
          0,  # coordinates
          0,  # residue_mask
          0,  # residue_index
          0,  # chain_index
          None,  # k_neighbors
          None,  # bias
          None,  # fixed_positions
          None,  # backbone_noise
          None,  # iterations
          None,  # learning_rate
          None,  # temperature
        ),
        out_axes=0,
      )

      sampled_sequences, logits, _ = vmap_structures(
        keys,
        batched_ensemble.coordinates,
        residue_mask,
        batched_ensemble.residue_index,
        batched_ensemble.chain_index,
        48,  # k_neighbors
        jnp.asarray(spec.bias, dtype=jnp.float32) if spec.bias is not None else None,
        jnp.asarray(spec.fixed_positions, dtype=jnp.int32)
        if spec.fixed_positions is not None
        else None,
        jnp.asarray(spec.backbone_noise, dtype=jnp.float32)
        if spec.backbone_noise is not None
        else None,
        spec.iterations,
        spec.learning_rate,
        spec.temperature,
      )

      if seq_ds.shape[0] == 0:
        # First batch: the trailing axes were created with size 0 and must be
        # sized to match the sampled output before any slab can be written.
        for axis in range(1, seq_ds.ndim):
          seq_ds.resize(sampled_sequences.shape[axis], axis=axis)
        for axis in range(1, logits_ds.ndim):
          logits_ds.resize(logits.shape[axis], axis=axis)

      # Append this batch along axis 0 and flush so results survive a crash.
      seq_ds.resize(seq_ds.shape[0] + sampled_sequences.shape[0], axis=0)
      seq_ds[-sampled_sequences.shape[0] :, :, :] = sampled_sequences
      logits_ds.resize(logits_ds.shape[0] + logits.shape[0], axis=0)
      logits_ds[-logits.shape[0] :, :, :, :] = logits
      f.flush()

  return {
    "output_h5_path": str(spec.output_h5_path),
    "metadata": {
      "specification": spec,
    },
  }
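# Example (an illustrative sketch): reading streamed results back from the HDF5
# file written by _sample_streaming. The path "samples.h5" stands in for
# whatever spec.output_h5_path was set to.
#
#   import h5py
#
#   with h5py.File("samples.h5", "r") as f:
#       sequences = f["sequences"][:]  # batches are appended along axis 0
#       logits = f["logits"][:]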