Source code for prxteinmpnn.utils.autoregression

"""Autoregression utilities."""

from __future__ import annotations

from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING

import jax
import jax.numpy as jnp

if TYPE_CHECKING:
  from collections.abc import Sequence

  from prxteinmpnn.run.specs import RunSpecification
  from prxteinmpnn.utils.data_structures import Protein

  from .types import AutoRegressiveMask, DecodingOrder


def get_decoding_step_map(
  tie_group_map: jnp.ndarray,
  group_decoding_order: jnp.ndarray,
  num_groups: int | None = None,
) -> jnp.ndarray:
  """Translate per-residue group ids into per-residue decoding steps.

  A lookup table ``group id -> step`` is built by scattering the step indices
  ``0 .. M-1`` into the slots named by ``group_decoding_order``; the table is
  then gathered with ``tie_group_map``. Table slots that never appear in
  ``group_decoding_order`` keep their initial value of 0.

  Args:
    tie_group_map: (N,) array of group ids in range [0, num_groups-1].
    group_decoding_order: (M,) array, permutation of group ids [0, ..., M-1].
    num_groups: Size of the lookup table. When omitted, the length of
      ``group_decoding_order`` is used.

  Returns:
    decoding_step_map: (N,) array, decoding step for each residue.

  """
  order_length = len(group_decoding_order)
  table_size = order_length if num_groups is None else num_groups

  # step_ids[k] is the step at which group ``group_decoding_order[k]`` is decoded.
  step_ids = jnp.arange(order_length)
  step_of_group = jnp.zeros(table_size, dtype=jnp.int32)
  step_of_group = step_of_group.at[group_decoding_order].set(step_ids)

  return step_of_group[tie_group_map]
def make_autoregressive_mask(decoding_step_map: jnp.ndarray) -> jnp.ndarray:
  """Build the (N, N) attention mask implied by per-residue decoding steps.

  Entry ``[i, j]`` is True exactly when the step of residue ``i`` is greater
  than or equal to the step of residue ``j``. Residues sharing a step (i.e.
  belonging to the same tied group) therefore see each other.

  Args:
    decoding_step_map: (N,) array, decoding step for each residue.

  Returns:
    mask: (N, N) boolean array where mask[i,j]=True means position i can
      attend to position j.

  """
  # Column vector of "query" steps against a row vector of "key" steps;
  # broadcasting produces the full pairwise comparison.
  query_steps = jnp.expand_dims(decoding_step_map, axis=1)
  key_steps = jnp.expand_dims(decoding_step_map, axis=0)
  return jnp.greater_equal(query_steps, key_steps)
def resolve_tie_groups(  # noqa: C901, PLR0912
  spec: RunSpecification,
  combined_protein: Protein,
  structure_mappings: Sequence[dict] | None = None,
  num_structures: int | None = None,
) -> jnp.ndarray:
  """Resolve tie groups for tied_positions modes.

  Depending on ``spec.tied_positions`` the function behaves as follows:

  * ``None``: every position is its own group (``arange(n)``).
  * ``"direct"``: the sequence is treated as ``num_inputs`` equally long
    structures laid end to end; position ``i`` of every structure shares
    group ``i``.
  * ``"auto"``: each entry of ``structure_mappings`` holding more than one
    position ties those positions together.
  * anything else: treated as an explicit sequence of groups, where each group
    is either a single ``(chain_idx, res_idx)`` pair or a list/tuple of such
    pairs. Pairs that match no position are ignored; when several positions
    match, only the first is used.

  For the ``"auto"`` and explicit modes the resulting ids are renumbered to
  consecutive integers via ``jnp.unique(..., return_inverse=True)``.

  Args:
    spec: RunSpecification with tied_positions.
    combined_protein: Protein dataclass with batch_dim=1 (1, seq_len, ...).
      Its ``chain_index``, ``residue_index`` and (for 'direct' mode without
      ``num_structures``) ``mapping`` fields are read.
    structure_mappings: Optional, for 'auto' mode.
    num_structures: Optional, number of structures concatenated (for 'direct' mode).

  Returns:
    tie_group_map: jnp.ndarray of shape (n,) with group ids.

  Raises:
    ValueError: In 'direct' mode when the number of structures cannot be
      determined, is smaller than one, or does not evenly divide the total
      length; in 'auto' mode when ``structure_mappings`` is missing.

  """
  chain_ids = combined_protein.chain_index[0]
  residue_indices = combined_protein.residue_index[0]
  n = chain_ids.shape[0]
  tie_group_map = jnp.arange(n, dtype=jnp.int32)
  tied_positions = spec.tied_positions

  if tied_positions is None:
    return tie_group_map

  if tied_positions == "direct":
    if num_structures is not None:
      num_inputs = num_structures
    elif combined_protein.mapping is not None:
      structure_indices = combined_protein.mapping[0]  # Remove batch dim
      num_inputs = int(jnp.max(structure_indices)) + 1
    else:
      msg = (
        "Cannot determine number of structures for 'direct' mode. "
        "The concatenated protein should have a 'mapping' field with structure indices, "
        "or pass num_structures explicitly."
      )
      raise ValueError(msg)

    if num_inputs < 1:
      msg = f"Number of structures must be at least 1 for 'direct' mode, got {num_inputs}."
      raise ValueError(msg)
    # The total length must split evenly across the structures. Checking
    # against the per-structure length instead (n % (n // num_inputs)) lets
    # uneven splits such as n=10 with 4 structures slip through.
    if n % num_inputs != 0:
      msg = (
        f"Inputs must be same length for 'direct' mode. "
        f"Total length {n} is not divisible by {num_inputs} structures."
      )
      raise ValueError(msg)
    structure_length = n // num_inputs
    return jnp.tile(jnp.arange(structure_length, dtype=jnp.int32), num_inputs)

  if tied_positions == "auto":
    if structure_mappings is None:
      msg = "structure_mappings required for 'auto' mode."
      raise ValueError(msg)
    for seq_pos, struct_pos_list in enumerate(structure_mappings):
      if len(struct_pos_list) > 1:
        # Offset by n so new ids cannot collide with the initial arange ids.
        group_id = n + seq_pos
        tie_group_map = tie_group_map.at[jnp.array(struct_pos_list)].set(group_id)
    _, tie_group_map = jnp.unique(tie_group_map, return_inverse=True)
    return tie_group_map

  def _collect_group_indices(
    groups: Sequence[tuple[int, int]],
    chain_id_arr: jnp.ndarray,
    residue_idx_arr: jnp.ndarray,
  ) -> list[tuple[int, list[int]]]:
    """Resolve each group's (chain, residue) pairs to flat position indices."""
    group_map = defaultdict(list)
    for group_idx, group in enumerate(groups):
      # A group is either one (chain, residue) pair or a collection of pairs.
      if isinstance(group[0], (list, tuple)):
        group_map[group_idx].extend(group)
      else:
        group_map[group_idx].append(group)

    group_indices = []
    for group_idx, tuples in group_map.items():
      indices = []
      for chain_idx, res_idx in tuples:
        mask = (chain_id_arr == chain_idx) & (residue_idx_arr == res_idx)
        idx = jnp.where(mask)[0]
        if idx.size > 0:
          indices.append(idx[0])
      group_indices.append((group_idx, indices))
    return group_indices

  group_indices = _collect_group_indices(
    tied_positions,
    chain_ids,
    residue_indices,
  )
  for group_idx, indices in group_indices:
    if indices:
      # Offset by n so new ids cannot collide with the initial arange ids.
      group_id = n + group_idx
      tie_group_map = tie_group_map.at[jnp.array(indices)].set(group_id)
  _, tie_group_map = jnp.unique(tie_group_map, return_inverse=True)
  return tie_group_map
@partial(jax.jit, static_argnames=("num_groups",))
def generate_ar_mask(
  decoding_order: DecodingOrder,
  chain_idx: jnp.ndarray | None = None,
  tie_group_map: jnp.ndarray | None = None,
  num_groups: int | None = None,
) -> AutoRegressiveMask:
  """Build the autoregressive attention mask for a decoding order.

  Without ``tie_group_map`` the mask is the pairwise comparison
  ``decoding_order[i] >= decoding_order[j]``. With ``tie_group_map`` the
  groups are ordered by where each first appears in
  ``tie_group_map[decoding_order]``; every position gets the step of its group
  and the mask is built from those steps, so positions in the same group can
  attend to each other (tied sampling).

  Args:
    decoding_order: The order in which atoms are decoded.
    chain_idx: Optional chain indices. If provided, mask entries between
      positions whose chain indices differ are zeroed, so attention is
      restricted to the same chain.
    tie_group_map: Optional (N,) array mapping each position to a group ID.
      When provided, positions in the same group are decoded simultaneously
      and can attend to each other.
    num_groups: Number of unique groups in tie_group_map. Required if
      tie_group_map is provided. Should equal tie_group_map.max() + 1 when
      groups are normalized to [0, 1, ..., num_groups-1]. Static under jit.

  Returns:
    An autoregressive mask (N, N) where mask[i,j]=1 means position i can
    attend to position j during decoding.

  Raises:
    ValueError: If ``tie_group_map`` is given without ``num_groups``.

  Example:
    >>> # Standard AR mask
    >>> order = jnp.array([0, 1, 2])
    >>> mask = generate_ar_mask(order)
    >>> # mask = [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
    >>>
    >>> # With tied positions
    >>> tie_map = jnp.array([0, 1, 0])  # Positions 0 and 2 are tied
    >>> mask = generate_ar_mask(order, tie_group_map=tie_map, num_groups=2)
    >>> # Positions 0 and 2 can attend to each other

  """
  has_ties = tie_group_map is not None
  if has_ties and num_groups is None:
    msg = "num_groups must be provided when tie_group_map is not None"
    raise ValueError(msg)

  if has_ties:
    # Row g marks the slots of the decoding order occupied by group g.
    groups_in_order = tie_group_map[decoding_order]
    membership = groups_in_order[None, :] == jnp.arange(num_groups)[:, None]
    # argmax over a boolean row yields the first True, i.e. the group's first slot.
    first_slot = jnp.argmax(membership, axis=1)
    group_order = jnp.argsort(first_slot)
    step_of_position = get_decoding_step_map(tie_group_map, group_order, num_groups)
    visible = make_autoregressive_mask(step_of_position)
  else:
    visible = decoding_order[:, None] >= decoding_order[None, :]

  ar_mask = visible.astype(int)
  if chain_idx is None:
    return ar_mask

  chain_match = (chain_idx[:, None] == chain_idx[None, :]).astype(int)
  return ar_mask * chain_match