Source code for prxteinmpnn.run.sampling

"""Core user interface for the PrxteinMPNN package."""

from __future__ import annotations

import logging
import sys
from functools import partial
from typing import TYPE_CHECKING, Any

import h5py
import jax
import jax.numpy as jnp

from prxteinmpnn.sampling.sample import make_sample_sequences
from prxteinmpnn.utils.autoregression import resolve_tie_groups
from prxteinmpnn.utils.decoding_order import random_decoding_order

from .prep import prep_protein_stream_and_model
from .specs import SamplingSpecification

if TYPE_CHECKING:
  from jaxtyping import PRNGKeyArray

  from prxteinmpnn.utils.types import (
    AlphaCarbonMask,
    BackboneCoordinates,
    BackboneNoise,
    ChainIndex,
    DecodingOrder,
    Logits,
    ProteinSequence,
    ResidueIndex,
  )

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, stream=sys.stdout, force=True)


[docs] def sample( spec: SamplingSpecification | None = None, **kwargs: Any, # noqa: ANN401 ) -> dict[str, Any]: """Sample new sequences for the given input structures. This function uses a high-performance Grain pipeline to load and process structures, then samples new sequences for each structure. Args: spec: An optional SamplingSpecification object. If None, a default will be created using kwargs, options are provided as keyword arguments. The following options can be set: inputs: A single or sequence of inputs (files, PDB IDs, etc.). chain_id: Specific chain(s) to parse from the structure. model: The model number to load. If None, all models are loaded. altloc: The alternate location identifier to use. model_version: The model version to use. model_weights: The model weights to use. foldcomp_database: The FoldComp database to use for FoldComp IDs. random_seed: The random number generator key. backbone_noise: The amount of noise to add to the backbone. num_samples: The number of sequences to sample per structure/noise level. sampling_strategy: The sampling strategy to use. temperature: The sampling temperature. bias: An optional array to bias the logits. fixed_positions: An optional array of residue indices to keep fixed. iterations: Number of optimization iterations for "straight_through" sampling. learning_rate: Learning rate for "straight_through" sampling. batch_size: The number of structures to process in a single batch. **kwargs: Additional keyword arguments for structure loading. Returns: A dictionary containing sampled sequences, logits, and metadata. """ if spec is None: spec = SamplingSpecification(**kwargs) if spec.output_h5_path: return _sample_streaming(spec) protein_iterator, model = prep_protein_stream_and_model(spec) sampler_fn = make_sample_sequences( model=model, decoding_order_fn=random_decoding_order, sampling_strategy=spec.sampling_strategy, ) all_sequences, all_logits = [], [] for batched_ensemble in protein_iterator: keys = jax.random.split(jax.random.key(spec.random_seed), spec.num_samples) # Compute tie groups if needed tie_group_map = None num_groups = None # Handle inter-chain mode with tied positions if spec.pass_mode == "inter" and spec.tied_positions is not None: # noqa: S105 # For inter mode, loader has already concatenated structures with sequential chain IDs # Just resolve the tie groups on the concatenated protein tie_group_map = resolve_tie_groups(spec, batched_ensemble) num_groups = int(jnp.max(tie_group_map)) + 1 vmap_samples = jax.vmap( sampler_fn, in_axes=( 0, # keys None, # coordinates None, # mask None, # residue_index None, # chain_index None, # k_neighbors None, # bias None, # fixed_positions None, # backbone_noise None, # iterations None, # learning_rate None, # temperature None, # tie_group_map None, # num_groups None, # multi_state_strategy None, # multi_state_alpha ), out_axes=0, ) vmap_noises = jax.vmap( vmap_samples, in_axes=( None, # keys None, # coordinates None, # mask None, # residue_index None, # chain_index None, # k_neighbors None, # bias None, # fixed_positions 0, # backbone_noise None, # iterations None, # learning_rate None, # temperature None, # tie_group_map None, # num_groups None, # multi_state_strategy None, # multi_state_alpha ), out_axes=0, ) vmap_structures = jax.vmap( vmap_noises, in_axes=( None, # keys 0, # coordinates 0, # residue_mask 0, # residue_index 0, # chain_index None, # k_neighbors None, # bias None, # fixed_positions None, # backbone_noise None, # iterations None, # learning_rate None, # temperature None, # tie_group_map None, # num_groups None, # multi_state_strategy None, # multi_state_alpha ), out_axes=0, ) sampled_sequences, logits, _ = vmap_structures( keys, batched_ensemble.coordinates, batched_ensemble.mask, batched_ensemble.residue_index, batched_ensemble.chain_index, 48, jnp.asarray(spec.bias, dtype=jnp.float32) if spec.bias is not None else None, jnp.asarray(spec.fixed_positions, dtype=jnp.int32) if spec.fixed_positions is not None else None, jnp.asarray(spec.backbone_noise, dtype=jnp.float32) if spec.backbone_noise is not None else None, spec.iterations, spec.learning_rate, spec.temperature, tie_group_map, num_groups, spec.multi_state_strategy, spec.multi_state_alpha, ) all_sequences.append(sampled_sequences) all_logits.append(logits) if not all_sequences: return {"sampled_sequences": None, "logits": None, "metadata": None} return { "sequences": jnp.concatenate(all_sequences, axis=0), "logits": jnp.concatenate(all_logits, axis=0), "metadata": { "specification": spec, }, }
def _sample_streaming( spec: SamplingSpecification, ) -> dict[str, str | dict[str, SamplingSpecification]]: """Sample new sequences and stream results to an HDF5 file.""" if not spec.output_h5_path: msg = "output_h5_path must be provided for streaming." raise ValueError(msg) protein_iterator, model = prep_protein_stream_and_model(spec) if spec.average_encodings: msg = "average_encodings feature is temporarily disabled during Equinox migration" raise NotImplementedError(msg) sampler_fn = make_sample_sequences( model=model, decoding_order_fn=random_decoding_order, sampling_strategy="temperature", ) with h5py.File(spec.output_h5_path, "w") as f: structure_idx = 0 for batched_ensemble in protein_iterator: keys = jax.random.split(jax.random.key(spec.random_seed), spec.num_samples) # Compute tie groups if needed tie_group_map = None num_groups = None # Handle inter-chain mode with tied positions if spec.pass_mode == "inter" and spec.tied_positions is not None: # noqa: S105 # For inter mode, loader has already concatenated structures with sequential chain IDs # Just resolve the tie groups on the concatenated protein tie_group_map = resolve_tie_groups(spec, batched_ensemble) num_groups = int(jnp.max(tie_group_map)) + 1 noise_array = ( jnp.asarray(spec.backbone_noise, dtype=jnp.float32) if spec.backbone_noise is not None else jnp.zeros(1) ) # Create the sample function with bound parameters to avoid closure issues sample_fn_with_params = partial( sampler_fn, _k_neighbors=48, bias=jnp.asarray(spec.bias, dtype=jnp.float32) if spec.bias is not None else None, fixed_positions=( jnp.asarray(spec.fixed_positions, dtype=jnp.int32) if spec.fixed_positions is not None else None ), _iterations=spec.iterations, _learning_rate=spec.learning_rate, temperature=spec.temperature, tie_group_map=tie_group_map, num_groups=num_groups, multi_state_strategy=spec.multi_state_strategy, multi_state_alpha=spec.multi_state_alpha, ) def sample_single_noise( key: PRNGKeyArray, coords: BackboneCoordinates, mask: AlphaCarbonMask, residue_ix: ResidueIndex, chain_ix: ChainIndex, noise: BackboneNoise, _sampler: partial = sample_fn_with_params, # Bind to avoid closure issues ) -> tuple[ProteinSequence, Logits, DecodingOrder]: """Sample one sequence for one structure at one noise level.""" return _sampler( key, coords, mask, residue_ix, chain_ix, backbone_noise=noise, ) def mapped_fn_noise( key: PRNGKeyArray, coords: BackboneCoordinates, mask: AlphaCarbonMask, residue_ix: ResidueIndex, chain_ix: ChainIndex, noise_arr: BackboneNoise, ) -> tuple[ProteinSequence, Logits, DecodingOrder]: """Compute samples across all noise levels for a single structure/sample.""" return jax.lax.map( partial( sample_single_noise, key, coords, mask, residue_ix, chain_ix, ), noise_arr, batch_size=spec.noise_batch_size, ) def internal_sample( coords: BackboneCoordinates, mask: AlphaCarbonMask, residue_ix: ResidueIndex, chain_ix: ChainIndex, keys_arr: PRNGKeyArray, noise_arr: BackboneNoise, ) -> tuple[ProteinSequence, Logits, DecodingOrder]: """Sample mapping over keys and noise.""" noise_map_fn = partial( mapped_fn_noise, coords=coords, mask=mask, residue_ix=residue_ix, chain_ix=chain_ix, noise_arr=noise_arr, ) return jax.lax.map( noise_map_fn, keys_arr, batch_size=spec.samples_batch_size, ) vmap_structures = jax.vmap(internal_sample, in_axes=(0, 0, 0, 0, None, None)) sampled_sequences, sampled_logits, _ = vmap_structures( batched_ensemble.coordinates, batched_ensemble.mask, batched_ensemble.residue_index, batched_ensemble.chain_index, keys, noise_array, ) # Store each structure in its own group to handle variable lengths for i in range(sampled_sequences.shape[0]): grp = f.create_group(f"structure_{structure_idx}") grp.create_dataset("sequences", data=sampled_sequences[i], dtype="i4") grp.create_dataset("logits", data=sampled_logits[i], dtype="f4") # Store metadata about the structure grp.attrs["structure_index"] = structure_idx grp.attrs["num_samples"] = sampled_sequences.shape[1] grp.attrs["num_noise_levels"] = sampled_sequences.shape[2] grp.attrs["sequence_length"] = sampled_sequences.shape[3] structure_idx += 1 f.flush() return { "output_h5_path": str(spec.output_h5_path), "metadata": { "specification": spec, }, } def _sample_streaming_averaged( _spec: SamplingSpecification, _protein_iterator: Any, # noqa: ANN401 _model: Any, # noqa: ANN401 ) -> dict[str, str | dict[str, SamplingSpecification]]: """Sample with averaged encodings across noise levels. TODO: Re-implement this feature after Equinox migration is complete. This requires a way to separate encoding and sampling steps in the new architecture. See issue: TBD Args: _spec: Sampling specification (unused). _protein_iterator: Protein data iterator (unused). _model: PrxteinMPNN model (unused). Returns: Nothing, always raises NotImplementedError. Raises: NotImplementedError: This feature is temporarily disabled. """ msg = "average_encodings feature is temporarily disabled during Equinox migration" raise NotImplementedError(msg)