decoding_order
Utilities for decoding order generation.
prxteinmpnn.utils.decoding_order
- prxteinmpnn.utils.decoding_order.random_decoding_order(prng_key, num_residues, tie_group_map=None, num_groups=None)
Return a random decoding order, optionally respecting tied positions.
- Parameters:
  - prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – PRNG key for randomness.
  - num_residues (int) – Total number of residues, N.
  - tie_group_map (Array | None) – Optional (N,) array mapping each position to a group ID. Positions with the same group ID are tied and are decoded together in the same step.
  - num_groups (int | None) – Number of unique groups in tie_group_map. Required if tie_group_map is provided. Should equal tie_group_map.max() + 1 when group IDs are normalized to [0, 1, …, num_groups-1].
- Return type:
  tuple[Int[Array, 'num_residues'], Union[Key[Array, ''], UInt32[Array, '2']]]
- Returns:
  Tuple of (decoding_order, next_key), where decoding_order respects ties: positions that share a group ID are decoded in the same step.
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.utils.decoding_order import random_decoding_order
>>> key = jax.random.PRNGKey(0)
>>> # Without ties: standard random order
>>> order, key = random_decoding_order(key, 5)
>>> # With ties: positions in the same group stay together
>>> tie_map = jnp.array([0, 1, 0, 2, 1])  # Groups: {0: [0, 2], 1: [1, 4], 2: [3]}
>>> order, key = random_decoding_order(key, 5, tie_map, num_groups=3)
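The docstring states two properties: tied positions are decoded together, and num_groups should equal tie_group_map.max() + 1 once group IDs are normalized. The sketch below shows one plausible way to obtain both; it is not taken from prxteinmpnn, and the names normalize_groups and sketch_random_decoding_order are hypothetical. The key idea is to shuffle groups rather than positions, then order positions by the step their group was assigned.

```python
import jax
import jax.numpy as jnp


def normalize_groups(labels):
    """Hypothetical helper: remap arbitrary group labels to 0..num_groups-1."""
    # jnp.unique returns the sorted unique labels and, with return_inverse=True,
    # the index of each input label within that sorted array.
    unique, inverse = jnp.unique(labels, return_inverse=True)
    return inverse.reshape(labels.shape), int(unique.shape[0])


def sketch_random_decoding_order(key, num_residues, tie_group_map=None, num_groups=None):
    """Hypothetical tie-aware random order: each tie group forms one contiguous block."""
    if tie_group_map is None:
        # No ties: every position is its own group, so the result is a plain permutation.
        tie_group_map = jnp.arange(num_residues)
        num_groups = num_residues
    elif num_groups is None:
        raise ValueError("num_groups is required when tie_group_map is given")
    key, subkey = jax.random.split(key)
    # Shuffle the groups, not the positions: group g is decoded at step group_rank[g].
    group_rank = jax.random.permutation(subkey, num_groups)
    # Every position inherits its group's step, so tied positions share one step.
    position_step = group_rank[tie_group_map]
    # Sorting positions by step places each group's members in one contiguous block.
    decoding_order = jnp.argsort(position_step)
    return decoding_order, key
```

Because every member of a group inherits the same step index and no two groups share a step, sorting by step can never interleave two groups. Without a tie map, each position is treated as its own group and the result reduces to an ordinary random permutation.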
- prxteinmpnn.utils.decoding_order.single_decoding_order(key, num_residues, tie_group_map=None, num_groups=None)
Generate a single decoding order (identity permutation).
- Parameters:
  - key (Union[Key[Array, ''], UInt32[Array, '2']]) – Random key (unused, for API compatibility).
  - num_residues (int) – Number of residues, N.
  - tie_group_map (Array | None) – Optional (N,) array mapping each position to a group ID. Currently ignored for the single decoding order.
  - num_groups (int | None) – Number of unique groups (unused, for API compatibility).
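- Returns:
  Tuple of (decoding_order, key), where decoding_order is the (N,) array [0, 1, …, N-1] with N = num_residues, and key is the same random key, unchanged.
- Return type:
  tuple (decoding_order, key)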
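Since single_decoding_order accepts the same arguments as random_decoding_order and ignores the ones it does not need, the two can be swapped without changing the caller. A minimal sketch of that pattern follows, using hypothetical stand-ins: identity_order mirrors the documented behaviour of single_decoding_order, shuffled_order is a simplified random order that (unlike random_decoding_order) ignores ties, and ORDER_FNS is a hypothetical registry.

```python
import jax
import jax.numpy as jnp


def identity_order(key, num_residues, tie_group_map=None, num_groups=None):
    """Hypothetical stand-in for single_decoding_order: returns [0, ..., N-1]."""
    del tie_group_map, num_groups  # ignored; kept only so the signature matches
    # No randomness is consumed, so the key is returned unchanged.
    return jnp.arange(num_residues), key


def shuffled_order(key, num_residues, tie_group_map=None, num_groups=None):
    """Hypothetical, simplified random order that ignores ties."""
    del tie_group_map, num_groups
    key, subkey = jax.random.split(key)
    return jax.random.permutation(subkey, num_residues), key


# Hypothetical registry: the strategy can be chosen by name at configuration time.
ORDER_FNS = {"single": identity_order, "random": shuffled_order}


def get_decoding_order(strategy, key, num_residues):
    # Pass arguments positionally: the documented functions name their first
    # parameter differently (prng_key vs. key).
    return ORDER_FNS[strategy](key, num_residues)
```

Passing the key positionally matters here because the documented signatures use prng_key in random_decoding_order but key in single_decoding_order.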