Source code for prxteinmpnn.sampling
"""Sampling utilities for PrXteinMPNN."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from prxteinmpnn.sampling.conditional_logits import (
make_conditional_logits_fn,
make_encoding_conditional_logits_split_fn,
)
from prxteinmpnn.sampling.sample import make_encoding_sampling_split_fn, make_sample_sequences
from prxteinmpnn.sampling.unconditional_logits import make_unconditional_logits_fn
from prxteinmpnn.utils import ste
if TYPE_CHECKING:
import jax
from prxteinmpnn.model import PrxteinMPNN
# Public API of the sampling package: the factory functions re-exported from
# the submodules, the `sample` convenience wrapper defined below, and the
# `ste` utility module. Kept alphabetical.
__all__ = [
    "make_conditional_logits_fn",
    "make_encoding_conditional_logits_split_fn",
    "make_encoding_sampling_split_fn",
    "make_sample_sequences",
    "make_unconditional_logits_fn",
    "sample",
    "ste",
]
[docs]
def sample(
    prng_key: jax.Array,
    model: PrxteinMPNN,
    structure_coordinates: jax.Array,
    mask: jax.Array,
    residue_index: jax.Array,
    chain_index: jax.Array,
    **kwargs: Any,  # noqa: ANN401
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Draw sequences for a structure with the temperature sampling strategy.

    Shortcut for building a sampler with
    ``make_sample_sequences(model, sampling_strategy="temperature")`` and
    invoking it once on the given structure.

    Args:
        prng_key: JAX random key.
        model: A PrxteinMPNN Equinox model instance.
        structure_coordinates: Atomic coordinates (N, 4, 3).
        mask: Alpha carbon mask indicating valid residues.
        residue_index: Residue indices.
        chain_index: Chain indices.
        **kwargs: Extra keyword arguments forwarded unchanged to the sampler.

    Returns:
        Tuple of (sampled sequence, logits, decoding order), as produced by
        the temperature sampler.

    """
    # NOTE(review): a fresh sampler is built on every call. If
    # `make_sample_sequences` returns a jitted function, repeated calls may
    # retrace; for loops, consider building the sampler once and reusing it.
    temperature_sampler = make_sample_sequences(model, sampling_strategy="temperature")
    # Positional order must match the sampler's signature exactly.
    positional_args = (
        prng_key,
        structure_coordinates,
        mask,
        residue_index,
        chain_index,
    )
    return temperature_sampler(*positional_args, **kwargs)