"""Core user interface for the PrxteinMPNN package."""
from __future__ import annotations
import logging
import sys
from functools import partial
from typing import TYPE_CHECKING, Any
import h5py
import jax
import jax.numpy as jnp
if TYPE_CHECKING:
from collections.abc import Generator
from grain.python import IterDataset
from prxteinmpnn.utils.types import (
AtomMask,
ChainIndex,
OneHotProteinSequence,
ResidueIndex,
StructureAtomicCoordinates,
)
from prxteinmpnn.sampling.conditional_logits import (
    ConditionalLogitsFn,
    make_conditional_logits_fn,
)
from prxteinmpnn.utils.apc import apc_corrected_frobenius_norm
from prxteinmpnn.utils.catjac import (
combine_jacobians_h5_stream,
make_combine_jac,
)
from .prep import prep_protein_stream_and_model
from .specs import JacobianSpecification
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, stream=sys.stdout, force=True)
def categorical_jacobian(
spec: JacobianSpecification | None = None,
**kwargs: Any, # noqa: ANN401
) -> dict[
str,
jax.Array
| dict[
str,
JacobianSpecification,
]
| None,
]:
"""Compute the Jacobian of the model's logits with respect to the input sequence.
Args:
spec: An optional JacobianConfig object. If None, a default will be created using
kwargs, options are provided as keyword arguments. The following options can be set:
inputs: A single or sequence of inputs (files, PDB IDs, etc.).
chain_id: Specific chain(s) to parse from the structure.
model: The model number to load. If None, all models are loaded.
altloc: The alternate location identifier to use.
model_version: The model version to use.
model_weights: The model weights to use.
foldcomp_database: The FoldComp database to use for FoldComp IDs.
        random_seed: The integer seed used to build the random number generator key.
backbone_noise: The amount of noise to add to the backbone.
batch_size: The number of structures to process in a single batch.
noise_batch_size: Batch size for noise levels in Jacobian computation.
jacobian_batch_size: Inner batch size for Jacobian computation.
combine_batch_size: Batch size for combining Jacobians.
num_workers: Number of parallel workers for data loading.
combine_fn: Function or string specifying how to combine Jacobian pairs (e.g., "add",
"subtract").
combine_fn_kwargs: Optional dictionary of keyword arguments for the combine function.
combine_weights: Optional weights to use when combining Jacobians.
combine: Whether to combine Jacobians across samples.
output_h5_path: Optional path to an HDF5 file for streaming output.
compute_apc: Whether to compute APC-corrected Frobenius norm.
        **kwargs: Additional keyword arguments for structure loading.

    Returns:
        A dictionary containing the Jacobian tensors and metadata, or, in streaming
        mode, the output HDF5 path and metadata.
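
    Example:
        A minimal sketch; the input path and option values are illustrative::

            result = categorical_jacobian(
                inputs="1ubq.pdb",
                backbone_noise=[0.0, 0.1],
            )
            jacs = result["categorical_jacobians"]  # (structures, noise levels, L, 21, L, 21)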
"""
if spec is None:
spec = JacobianSpecification(**kwargs)
protein_iterator, model_parameters = prep_protein_stream_and_model(spec)
conditional_logits_fn = make_conditional_logits_fn(model_parameters=model_parameters)
if spec.output_h5_path:
result = _categorical_jacobian_streaming(spec, protein_iterator, conditional_logits_fn)
if spec.combine:
            if spec.combine_weights is None:
                msg = "combine_weights must be provided when combining Jacobians."
                raise ValueError(msg)
combine_fn = spec.combine_fn
combine_jacobians_h5_stream(
h5_path=spec.output_h5_path,
combine_fn=combine_fn, # pyright: ignore[reportArgumentType]
fn_kwargs=spec.combine_fn_kwargs or {},
batch_size=spec.combine_batch_size,
weights=jnp.asarray(spec.combine_weights),
)
return result
return _categorical_jacobian_in_memory(spec, protein_iterator, conditional_logits_fn)
def _compute_jacobian_batches(
spec: JacobianSpecification,
protein_iterator: IterDataset,
conditional_logits_fn: ConditionalLogitsFn,
) -> Generator[tuple[jax.Array, jax.Array], None, None]:
"""Generate and yield Jacobian batches."""
for batched_ensemble in protein_iterator:
def compute_jacobian_for_structure(
coords: jax.Array,
atom_mask: jax.Array,
residue_ix: jax.Array,
chain_ix: jax.Array,
one_hot_sequence: jax.Array,
noise: jax.Array,
) -> jax.Array:
length = one_hot_sequence.shape[0]
residue_mask = atom_mask[:, 0]
one_hot_flat = one_hot_sequence.flatten()
input_dim = one_hot_flat.shape[0]
            def logit_fn(one_hot_flat: jax.Array) -> jax.Array:
                one_hot_2d = one_hot_flat.reshape(length, 21)  # 21 = 20 amino acids + unknown token
logits, _, _ = conditional_logits_fn(
jax.random.key(spec.random_seed),
coords,
one_hot_2d,
residue_mask,
residue_ix,
chain_ix,
None,
                    48,  # presumably the neighbor count for the k-NN graph (ProteinMPNN's default is 48)
noise,
)
return logits.flatten()
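
            # Forward-mode trick: a JVP of logit_fn with a one-hot tangent e_i
            # yields column i of the Jacobian. Mapping over all input indices
            # with jax.lax.map (rather than jax.jacfwd) keeps peak memory
            # bounded by jacobian_batch_size.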
def jvp_fn(tangent: jax.Array) -> jax.Array:
return jax.jvp(logit_fn, (one_hot_flat,), (tangent,))[1]
def chunked_jacobian(idx: jax.Array) -> jax.Array:
tangent = jax.nn.one_hot(idx, num_classes=input_dim, dtype=one_hot_flat.dtype)
return jvp_fn(tangent)
jacobian_flat = jax.lax.map(
chunked_jacobian,
jnp.arange(input_dim),
batch_size=spec.jacobian_batch_size,
)
return jacobian_flat.reshape(length, 21, length, 21)
def mapped_fn(
coords: StructureAtomicCoordinates,
atom_mask: AtomMask,
residue_ix: ResidueIndex,
chain_ix: ChainIndex,
one_hot_sequence: OneHotProteinSequence,
) -> jax.Array:
"""Compute Jacobians for a single structure across multiple noise levels."""
return jax.lax.map(
partial(
compute_jacobian_for_structure,
coords,
atom_mask,
residue_ix,
chain_ix,
one_hot_sequence,
),
jnp.asarray(spec.backbone_noise, dtype=jnp.float32),
batch_size=spec.noise_batch_size,
)
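
        # vmap batches over structures; lax.map inside mapped_fn iterates over
        # noise levels, so each batch has shape (batch, num_noise, L, 21, L, 21).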
jacobians_batch = jax.vmap(mapped_fn)(
batched_ensemble.coordinates,
batched_ensemble.atom_mask,
batched_ensemble.residue_index,
batched_ensemble.chain_index,
batched_ensemble.one_hot_sequence,
)
yield jacobians_batch, batched_ensemble.one_hot_sequence
def _categorical_jacobian_in_memory(
spec: JacobianSpecification,
protein_iterator: IterDataset,
    conditional_logits_fn: ConditionalLogitsFn,
) -> dict[str, jax.Array | dict[str, JacobianSpecification] | None]:
"""Compute Jacobians and store them in memory."""
all_jacobians, all_sequences = [], []
for jacobians_batch, one_hot_sequence_batch in _compute_jacobian_batches(
spec,
protein_iterator,
conditional_logits_fn,
):
all_jacobians.append(jacobians_batch)
all_sequences.append(one_hot_sequence_batch)
if not all_jacobians:
return {"categorical_jacobians": None, "metadata": None}
jacobians = jnp.concatenate(all_jacobians, axis=0)
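
    # jacobians: (num_structures, num_noise_levels, L, 21, L, 21). The double
    # vmap applies the APC-corrected Frobenius norm per structure and noise level.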
apc_jacobians = (
jax.vmap(jax.vmap(apc_corrected_frobenius_norm))(jacobians) if spec.compute_apc else None
)
combine_jacs_fn = make_combine_jac(
combine_fn=spec.combine_fn,
fn_kwargs=spec.combine_fn_kwargs,
batch_size=spec.combine_batch_size,
)
combined_jacs, mapping = (
combine_jacs_fn(
jacobians,
jnp.concatenate(all_sequences, axis=0),
jnp.asarray(spec.combine_weights, dtype=jnp.float32),
)
if spec.combine
else (None, None)
)
return {
"categorical_jacobians": jacobians,
"apc_corrected_jacobians": apc_jacobians,
"combined": combined_jacs,
"mapping": mapping,
"metadata": {
"spec": spec,
},
}
def _compute_and_write_jacobians_streaming(
f: h5py.File,
spec: JacobianSpecification,
protein_iterator: IterDataset,
conditional_logits_fn: ConditionalLogitsFn,
) -> None:
"""Compute Jacobians and stream them to an HDF5 file."""
jac_ds = f.create_dataset(
"categorical_jacobians",
(0, 0, 0, 0, 0, 0),
maxshape=(None, None, None, None, None, None),
chunks=True,
)
seq_ds = f.create_dataset(
"one_hot_sequences",
(0, 0, 0),
maxshape=(None, None, None),
chunks=True,
)
apc_ds = (
f.create_dataset(
"apc_corrected_jacobians",
(0, 0, 0, 0),
maxshape=(None, None, None, None),
chunks=True,
)
if spec.compute_apc
else None
)
for jacobians_batch, one_hot_sequence_batch in _compute_jacobian_batches(
spec,
protein_iterator,
conditional_logits_fn,
):
current_size = jac_ds.shape[0]
new_size = current_size + jacobians_batch.shape[0]
jac_ds.resize((new_size, *jacobians_batch.shape[1:]))
seq_ds.resize((new_size, *one_hot_sequence_batch.shape[1:]))
jac_ds[current_size:new_size] = jacobians_batch
seq_ds[current_size:new_size] = one_hot_sequence_batch
        if apc_ds is not None:
apc_func = partial(
apc_corrected_frobenius_norm,
residue_batch_size=spec.apc_residue_batch_size,
)
apc_jacobians = jax.vmap(
lambda noise_jac, apc_func=apc_func: jax.lax.map(
apc_func,
noise_jac,
batch_size=spec.apc_batch_size,
),
)(jacobians_batch)
apc_ds.resize((new_size, *apc_jacobians.shape[1:]))
apc_ds[current_size:new_size] = apc_jacobians
        f.flush()  # persist each completed batch to disk
def _categorical_jacobian_streaming(
spec: JacobianSpecification,
protein_iterator: IterDataset,
conditional_logits_fn: ConditionalLogitsFn,
) -> dict[str, Any]:
"""Compute Jacobians and stream them to an HDF5 file."""
if not spec.output_h5_path:
msg = "output_h5_path must be provided for streaming."
raise ValueError(msg)
with h5py.File(spec.output_h5_path, "w") as f:
_compute_and_write_jacobians_streaming(f, spec, protein_iterator, conditional_logits_fn)
return {
"output_h5_path": str(spec.output_h5_path),
"metadata": {
"spec": spec,
},
}