autoregression

Autoregression utilities.

prxteinmpnn.utils.autoregression.get_decoding_step_map(tie_group_map, group_decoding_order, num_groups=None)

Map each residue to its decoding step index based on group order.

Parameters:
  • tie_group_map (Array) – (N,) array of group ids in range [0, num_groups-1].

  • group_decoding_order (Array) – (M,) array, permutation of group ids [0, …, M-1].

  • num_groups (int | None) – Number of unique groups. If not provided, inferred from group_decoding_order length.

Returns:

decoding_step_map, an (N,) array giving the decoding step for each residue.

Return type:

Array
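The mapping described above can be illustrated with a minimal sketch. It is not the library's implementation: it uses NumPy rather than JAX arrays, the helper name `decoding_step_map_sketch` is hypothetical, and it assumes that `group_decoding_order[s]` holds the id of the group decoded at step `s`.

```python
import numpy as np


def decoding_step_map_sketch(tie_group_map, group_decoding_order):
    """Hypothetical helper: per-residue decoding step from a group order."""
    order = np.asarray(group_decoding_order)
    # Invert the permutation: step_of_group[g] is the step at which group g
    # is decoded (assumes order[s] is the group decoded at step s).
    step_of_group = np.empty_like(order)
    step_of_group[order] = np.arange(order.shape[0])
    # Look up the step of each residue's group.
    return step_of_group[np.asarray(tie_group_map)]


tie_group_map = np.array([0, 1, 0, 2])    # residues 0 and 2 share group 0
group_decoding_order = np.array([2, 0, 1])  # group 2 first, then 0, then 1
steps = decoding_step_map_sketch(tie_group_map, group_decoding_order)
print(steps)
```

Under that assumption, tied residues always receive the same step, which is what lets a later mask treat them as decoded together.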

prxteinmpnn.utils.autoregression.make_autoregressive_mask(decoding_step_map)

Create an (N, N) AR mask for group-based decoding.

A position decoded at step s can attend to every position decoded at a step <= s. Positions in the same group share a step, so they can attend to each other.

Parameters:

decoding_step_map (Array) – (N,) array, decoding step for each residue.

Returns:

mask, an (N, N) boolean array where mask[i,j]=True means position i can attend to position j.

Return type:

Array
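The rule "steps <= own step" amounts to a pairwise comparison of decoding steps. The following is a minimal NumPy sketch of that rule, not the library's code; the function name `autoregressive_mask_sketch` is hypothetical.

```python
import numpy as np


def autoregressive_mask_sketch(decoding_step_map):
    """Hypothetical helper: mask[i, j] is True when step[j] <= step[i]."""
    steps = np.asarray(decoding_step_map)
    # Broadcast a column of "my step" against a row of "their step".
    return steps[:, None] >= steps[None, :]


# Residues 0 and 2 share step 1, residue 1 is at step 2, residue 3 at step 0.
mask = autoregressive_mask_sketch(np.array([1, 2, 1, 0]))
print(mask.astype(int))
```

In this example residues 0 and 2 can attend to each other because they share a step, while residue 3, decoded first, can attend only to itself.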

prxteinmpnn.utils.autoregression.resolve_tie_groups(spec, combined_protein, structure_mappings=None, num_structures=None)

Resolve tie groups for tied_positions modes.

Parameters:
  • spec (RunSpecification) – RunSpecification with tied_positions.

  • combined_protein (Protein) – Protein dataclass with batch_dim=1 (1, seq_len, …).

  • structure_mappings (Sequence[dict] | None) – Optional, for ‘auto’ mode.

  • num_structures (int | None) – Optional, number of structures concatenated (for ‘direct’ mode).

Returns:

tie_group_map, a jnp.ndarray of shape (n,) with group ids.

Return type:

jnp.ndarray

prxteinmpnn.utils.autoregression.generate_ar_mask(decoding_order, chain_idx=None, tie_group_map=None, num_groups=None)

Get the autoregressive mask for the given decoding order.

When tie_group_map is provided, positions in the same group can attend to each other (within the same decoding step), enabling tied sampling.

Parameters:
  • decoding_order (Int[Array, 'num_residues']) – The order in which residues are decoded.

  • chain_idx (Array | None) – Optional chain indices. If provided, a residue can only attend to residues in the same chain that come before it in the decoding order.

  • tie_group_map (Array | None) – Optional (N,) array mapping each position to a group ID. When provided, positions in the same group are decoded simultaneously and can attend to each other.

  • num_groups (int | None) – Number of unique groups in tie_group_map. Required if tie_group_map is provided. Should equal tie_group_map.max() + 1 when groups are normalized to [0, 1, …, num_groups-1].

Return type:

Bool[Array, 'num_residues num_residues']

Returns:

An autoregressive mask (N, N) where mask[i,j]=1 means position i can attend to position j during decoding.

Example

>>> import jax.numpy as jnp
>>> from prxteinmpnn.utils.autoregression import generate_ar_mask
>>>
>>> # Standard AR mask
>>> order = jnp.array([0, 1, 2])
>>> mask = generate_ar_mask(order)
>>> # mask = [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
>>>
>>> # With tied positions
>>> tie_map = jnp.array([0, 1, 0])  # Positions 0 and 2 are tied
>>> mask = generate_ar_mask(order, tie_group_map=tie_map, num_groups=2)
>>> # Positions 0 and 2 can attend to each other
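For a non-identity decoding order, the untied case and the chain restriction described in the parameters can be sketched as follows. This is a NumPy illustration under stated assumptions, not the library's implementation: the helper `ar_mask_sketch` is hypothetical, and it assumes `decoding_order[k]` is the index of the residue decoded at step `k`. For the identity order used in the example above, it reproduces the lower-triangular mask shown there.

```python
import numpy as np


def ar_mask_sketch(decoding_order, chain_idx=None):
    """Hypothetical helper: untied AR mask, optionally limited to one chain."""
    order = np.asarray(decoding_order)
    # rank[i] is the step at which residue i is decoded
    # (assumes order[k] is the residue decoded at step k).
    rank = np.argsort(order)
    mask = rank[:, None] >= rank[None, :]
    if chain_idx is not None:
        chains = np.asarray(chain_idx)
        # Keep only pairs of residues that belong to the same chain.
        mask = mask & (chains[:, None] == chains[None, :])
    return mask


order = np.array([2, 0, 1])  # residue 2 first, then 0, then 1
full_mask = ar_mask_sketch(order)
chain_mask = ar_mask_sketch(order, chain_idx=np.array([0, 0, 1]))
identity_mask = ar_mask_sketch(np.array([0, 1, 2]))
print(full_mask.astype(int))
print(chain_mask.astype(int))
```

With the chain restriction, residue 1 no longer attends to residue 2 even though residue 2 is decoded earlier, because the two sit in different chains.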