decoding_order

Utilities for decoding order generation.

prxteinmpnn.utils.decoding_order

prxteinmpnn.utils.decoding_order.random_decoding_order(prng_key, num_residues, tie_group_map=None, num_groups=None)

Return a random decoding order, optionally respecting tied positions.

Parameters:
  • prng_key (Union[Key[Array, ''], UInt32[Array, '2']]) – PRNG key for randomness.

  • num_residues (int) – Total number of residues.

  • tie_group_map (Array | None) – Optional (N,) array mapping each position to a group ID. Positions with the same group ID are tied and will be decoded together in the same step.

  • num_groups (int | None) – Number of unique groups in tie_group_map. Required if tie_group_map is provided. Should equal tie_group_map.max() + 1 when groups are normalized to [0, 1, …, num_groups-1].
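If your tie labels are arbitrary integers (for example chain-specific IDs), they must first be remapped to the dense range [0, 1, …, num_groups-1] for num_groups to equal tie_group_map.max() + 1. The sketch below shows one way to do that with jnp.unique; normalize_tie_groups is a hypothetical helper, not part of prxteinmpnn.

```python
import jax.numpy as jnp


def normalize_tie_groups(group_ids):
    """Hypothetical helper: remap arbitrary group labels to 0..num_groups-1.

    Returns (tie_group_map, num_groups) such that
    num_groups == tie_group_map.max() + 1.
    """
    # return_inverse gives, for every position, the index of its label in the
    # sorted array of unique labels, i.e. a dense group ID in [0, num_groups).
    unique_labels, dense_ids = jnp.unique(group_ids, return_inverse=True)
    num_groups = int(unique_labels.shape[0])
    # reshape(-1) keeps the result 1-D regardless of the JAX version's inverse shape.
    return dense_ids.reshape(-1), num_groups


# Labels 7, 3, 9 become dense IDs 1, 0, 2 (sorted order of the unique labels).
tie_map, num_groups = normalize_tie_groups(jnp.array([7, 3, 7, 9, 3]))
```

Because jnp.unique produces a data-dependent output size, this normalization is meant to run outside jax.jit, before num_groups is passed as a static Python int.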

Return type:

tuple[Int[Array, 'num_residues'], Union[Key[Array, ''], UInt32[Array, '2']]]

Returns:

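Tuple of (decoding_order, next_key), where decoding_order is a permutation of the num_residues position indices in which all positions of a tie group are decoded in the same step, and next_key is the key to use for subsequent random operations.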

Example

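>>> import jax
>>> import jax.numpy as jnp
>>> from prxteinmpnn.utils.decoding_order import random_decoding_order
>>> key = jax.random.PRNGKey(0)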
>>> # Without ties: standard random order
>>> order, key = random_decoding_order(key, 5)
>>>
>>> # With ties: positions in same group stay together
>>> tie_map = jnp.array([0, 1, 0, 2, 1])  # Groups: {0: [0,2], 1: [1,4], 2: [3]}
>>> order, key = random_decoding_order(key, 5, tie_map, num_groups=3)
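One way to produce a random order that keeps tied positions together is to shuffle the groups rather than the positions, then sort positions by the rank of their group. The sketch below illustrates that idea under the assumption that tie_group_map already holds dense IDs in [0, num_groups); tied_random_order_sketch is hypothetical and is not presented as the library's actual implementation.

```python
import jax
import jax.numpy as jnp


def tied_random_order_sketch(key, tie_group_map, num_groups):
    """Hypothetical sketch: random order where tied positions are contiguous.

    Assumes tie_group_map contains dense group IDs in [0, num_groups) and that
    num_groups is a concrete Python int.
    """
    key, subkey = jax.random.split(key)
    # Shuffle the groups: group g is decoded at step group_rank[g].
    group_rank = jax.random.permutation(subkey, num_groups)
    # Every position inherits the decoding step of its group.
    position_rank = group_rank[tie_group_map]
    # Ranks are distinct across groups and equal within a group, so sorting
    # by rank places all members of a group next to each other.
    decoding_order = jnp.argsort(position_rank)
    return decoding_order, key


key = jax.random.PRNGKey(0)
tie_map = jnp.array([0, 1, 0, 2, 1])  # Groups: {0: [0,2], 1: [1,4], 2: [3]}
order, next_key = tied_random_order_sketch(key, tie_map, num_groups=3)
```

Shuffling groups instead of positions means a tie group of any size consumes exactly one decoding step, which matches the documented behavior that tied positions are decoded together.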

prxteinmpnn.utils.decoding_order.single_decoding_order(key, num_residues, tie_group_map=None, num_groups=None)

Generate a single decoding order (identity permutation).

Parameters:
  • key (Union[Key[Array, ''], UInt32[Array, '2']]) – Random key (unused, for API compatibility).

  • num_residues (int) – Number of residues.

  • tie_group_map (Array | None) – Optional (N,) array mapping each position to a group ID. Currently ignored for single decoding order.

  • num_groups (int | None) – Number of unique groups (unused, for API compatibility).

Returns:

Tuple of (decoding_order, key), where decoding_order is the (N,) array [0, 1, …, N-1] and key is the input key, returned unchanged.

Return type:

tuple of ((N,) integer array, PRNG key)
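The behavior documented above amounts to returning jnp.arange(num_residues) together with the untouched key, while accepting the same arguments as random_decoding_order so the two can be swapped without changing call sites. The sketch below mirrors that documented contract; identity_order_sketch is a hypothetical stand-in, not the library's source.

```python
import jax
import jax.numpy as jnp


def identity_order_sketch(key, num_residues, tie_group_map=None, num_groups=None):
    """Hypothetical sketch mirroring single_decoding_order's documented behavior."""
    # Tie information is accepted only so the signature matches the random
    # variant; it does not influence the identity order.
    del tie_group_map, num_groups
    # Decode residues strictly left to right and hand the key back unchanged.
    return jnp.arange(num_residues), key


key = jax.random.PRNGKey(0)
order, out_key = identity_order_sketch(key, 4)
tied_order, _ = identity_order_sketch(key, 3, jnp.array([0, 0, 1]), 2)
```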