Source code for prxteinmpnn.utils.decoding_order

"""Utilities for decoding order generation.

prxteinmpnn.utils.decoding_order
"""

from collections.abc import Callable
from functools import partial

import jax
import jax.numpy as jnp
from jaxtyping import Int, PRNGKeyArray

from .autoregression import get_decoding_step_map
from .types import (
  DecodingOrder,
)

DecodingOrderInputs = tuple[PRNGKeyArray, Int, jnp.ndarray | None]  # Added tie_group_map
DecodingOrderOutputs = tuple[DecodingOrder, PRNGKeyArray]
DecodingOrderFn = Callable[
  [PRNGKeyArray, int, jnp.ndarray | None, int | None],
  DecodingOrderOutputs,
]


[docs] @partial(jax.jit, static_argnames=("num_residues", "num_groups")) def random_decoding_order( prng_key: PRNGKeyArray, num_residues: int, tie_group_map: jnp.ndarray | None = None, num_groups: int | None = None, ) -> DecodingOrderOutputs: """Return a random decoding order, optionally respecting tied positions. Args: prng_key: PRNG key for randomness. num_residues: Total number of residues. tie_group_map: Optional (N,) array mapping each position to a group ID. Positions with the same group ID are tied and will be decoded together in the same step. num_groups: Number of unique groups in tie_group_map. Required if tie_group_map is provided. Should equal tie_group_map.max() + 1 when groups are normalized to [0, 1, ..., num_groups-1]. Returns: Tuple of (decoding_order, next_key) where decoding_order respects ties. Example: >>> key = jax.random.PRNGKey(0) >>> # Without ties: standard random order >>> order, key = random_decoding_order(key, 5) >>> >>> # With ties: positions in same group stay together >>> tie_map = jnp.array([0, 1, 0, 2, 1]) # Groups: {0: [0,2], 1: [1,4], 2: [3]} >>> order, key = random_decoding_order(key, 5, tie_map, num_groups=3) """ current_key, next_key = jax.random.split(prng_key) if tie_group_map is None: # Standard random order without ties decoding_order = jax.random.permutation(current_key, jnp.arange(0, num_residues)) return jnp.asarray(decoding_order, dtype=jnp.int32), next_key if num_groups is None: msg = "num_groups must be provided when tie_group_map is not None" raise ValueError(msg) # With tied positions: generate random order over groups (vectorized) group_order = jax.random.permutation(current_key, jnp.arange(num_groups)) # Map groups to decoding steps decoding_step_map = get_decoding_step_map(tie_group_map, group_order, num_groups) # Create decoding order: sort by step, then by position within step # This ensures positions in the same group are adjacent decoding_order = jnp.argsort(decoding_step_map) return jnp.asarray(decoding_order, dtype=jnp.int32), next_key
[docs] def single_decoding_order( key: PRNGKeyArray, num_residues: int, tie_group_map: jnp.ndarray | None = None, num_groups: int | None = None, ) -> DecodingOrderOutputs: """Generate a single decoding order (identity permutation). Args: key: Random key (unused, for API compatibility). num_residues: Number of residues. tie_group_map: Optional (N,) array mapping each position to a group ID. Currently ignored for single decoding order. num_groups: Number of unique groups (unused, for API compatibility). Returns: decoding_order: (N,) array, [0, 1, ..., N-1]. key: Same random key (unchanged). """ del tie_group_map, num_groups # Unused for single order decoding_order = jnp.arange(0, num_residues, dtype=jnp.int32) return decoding_order, key