autoregression
Autoregression utilities.
- prxteinmpnn.utils.autoregression.get_decoding_step_map(tie_group_map, group_decoding_order, num_groups=None)
Map each residue to its decoding step index based on group order.
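- Parameters:
tie_group_map – (N,) array mapping each residue to a group ID.
group_decoding_order – Order in which the groups are decoded.
num_groups – Optional number of unique groups in tie_group_map.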
- Returns:
decoding_step_map – (N,) array, decoding step for each residue.
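A minimal sketch of the mapping described above, assuming group_decoding_order is a permutation of the group IDs in which entry k names the group decoded at step k (this page does not pin that convention down). The function name decoding_step_map_sketch is hypothetical, and NumPy stands in for jax.numpy; every call used here also exists in jax.numpy.

```python
import numpy as np


def decoding_step_map_sketch(tie_group_map, group_decoding_order):
    """Hypothetical sketch: map each residue to the step at which its group is decoded.

    Assumes group_decoding_order[k] is the group ID decoded at step k,
    i.e. a permutation of 0..num_groups-1.
    """
    tie_group_map = np.asarray(tie_group_map)
    group_decoding_order = np.asarray(group_decoding_order)
    # argsort of a permutation is its inverse: step_of_group[g] = step of group g.
    step_of_group = np.argsort(group_decoding_order)
    # Every residue inherits the decoding step of its group.
    return step_of_group[tie_group_map]


# Residues 0 and 2 share group 0; groups are decoded in the order 2, 0, 1.
print(decoding_step_map_sketch([0, 1, 0, 2], [2, 0, 1]))  # [1 2 1 0]
```

Because tied residues index the same entry of step_of_group, they always receive the same step.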
- prxteinmpnn.utils.autoregression.make_autoregressive_mask(decoding_step_map)
Create an (N, N) AR mask for group-based decoding.
Positions at step i can attend to all positions at steps <= i. This allows positions in the same group (same step) to attend to each other.
- Parameters:
decoding_step_map (Array) – (N,) array, decoding step for each residue.
- Returns:
mask – (N, N) boolean array where mask[i, j]=True means position i can attend to position j.
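The rule "positions at step i attend to all positions at steps <= i" reduces to one broadcast comparison. The sketch below is a hypothetical stand-in for make_autoregressive_mask, not its actual source; NumPy replaces jax.numpy so it runs without JAX, and the same expression works unchanged with jnp arrays.

```python
import numpy as np


def autoregressive_mask_sketch(decoding_step_map):
    """Hypothetical sketch: mask[i, j] is True when step[j] <= step[i]."""
    steps = np.asarray(decoding_step_map)
    # Broadcasting a (1, N) row against an (N, 1) column compares every pair.
    return steps[None, :] <= steps[:, None]


# Residues 0 and 2 share step 1, so they see each other; residue 3 (step 0) sees only itself.
print(autoregressive_mask_sketch([1, 2, 1, 0]).astype(int))
```

The diagonal is always True, since every position shares a step with itself.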
- prxteinmpnn.utils.autoregression.resolve_tie_groups(spec, combined_protein, structure_mappings=None, num_structures=None)
Resolve tie groups for tied_positions modes.
- Parameters:
spec (RunSpecification) – RunSpecification with tied_positions.
combined_protein (Protein) – Protein dataclass with batch_dim=1 (1, seq_len, …).
structure_mappings (Sequence[dict] | None) – Optional, used in ‘auto’ mode.
num_structures (int | None) – Optional number of concatenated structures, used in ‘direct’ mode.
- Returns:
tie_group_map – jnp.ndarray of shape (n,) with group IDs.
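The tied_positions formats that resolve_tie_groups accepts are not described on this page. As a hypothetical illustration of the output only, the sketch below builds a normalized tie_group_map from explicit lists of tied positions, assuming the lists are disjoint and every untied position forms its own group. The helper name and its input format are invented for illustration; NumPy stands in for jax.numpy.

```python
import numpy as np


def tie_groups_from_lists_sketch(seq_len, tied_lists):
    """Hypothetical helper: build a normalized tie_group_map from disjoint position lists.

    Positions not named in any list each get their own singleton group.
    Returns (tie_group_map, num_groups).
    """
    positions_all = np.arange(seq_len)
    group = positions_all.copy()  # start with one group per position
    for positions in tied_lists:
        positions = np.asarray(positions)
        # Label every member of this tied list with its smallest position.
        group = np.where(np.isin(positions_all, positions), positions.min(), group)
    # Renumber labels to contiguous IDs 0..num_groups-1.
    _, tie_group_map = np.unique(group, return_inverse=True)
    num_groups = int(tie_group_map.max()) + 1
    return tie_group_map, num_groups


tie_map, n_groups = tie_groups_from_lists_sketch(5, [[0, 3]])
print(tie_map, n_groups)  # [0 1 2 0 3] 4
```

The renumbering step produces IDs in [0, num_groups-1], so num_groups equals tie_group_map.max() + 1, matching the normalization that generate_ar_mask expects.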
- prxteinmpnn.utils.autoregression.generate_ar_mask(decoding_order, chain_idx=None, tie_group_map=None, num_groups=None)
Get the autoregressive mask for the given decoding order.
When tie_group_map is provided, positions in the same group can attend to each other (within the same decoding step), enabling tied sampling.
- Parameters:
decoding_order (Int[Array, 'num_residues']) – The order in which residues are decoded.
chain_idx (Array | None) – Optional chain indices. If provided, a residue can only attend to residues in the same chain that come before it in the decoding order.
tie_group_map (Array | None) – Optional (N,) array mapping each position to a group ID. When provided, positions in the same group are decoded simultaneously and can attend to each other.
num_groups (int | None) – Number of unique groups in tie_group_map. Required if tie_group_map is provided. Should equal tie_group_map.max() + 1 when groups are normalized to [0, 1, …, num_groups-1].
- Returns:
An autoregressive mask (N, N) where mask[i, j]=1 means position i can attend to position j during decoding.
- Return type:
Bool[Array, 'num_residues num_residues']
Example
>>> import jax.numpy as jnp
>>> from prxteinmpnn.utils.autoregression import generate_ar_mask
>>> # Standard AR mask
>>> order = jnp.array([0, 1, 2])
>>> mask = generate_ar_mask(order)
>>> # mask = [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
>>>
>>> # With tied positions
>>> tie_map = jnp.array([0, 1, 0])  # Positions 0 and 2 are tied
>>> mask = generate_ar_mask(order, tie_group_map=tie_map, num_groups=2)
>>> # Positions 0 and 2 can attend to each other
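The sketch below reproduces the documented behaviour of generate_ar_mask under stated assumptions: decoding_order[k] is the position decoded at step k, a tied group is decoded at the earliest step of any of its members, and the chain_idx restriction is applied by elementwise AND. ar_mask_sketch is a hypothetical name, not the library's implementation, and NumPy is used in place of jax.numpy; the operations are chosen so the same code also works with jnp arrays (no in-place updates).

```python
import numpy as np


def ar_mask_sketch(decoding_order, chain_idx=None, tie_group_map=None, num_groups=None):
    """Hypothetical sketch of a group-aware autoregressive mask.

    Assumptions: decoding_order[k] is the position decoded at step k, and a tied
    group is decoded at the earliest step of any of its members.
    """
    order = np.asarray(decoding_order)
    n = order.shape[0]
    # argsort inverts the permutation: rank[p] is the step at which position p is decoded.
    rank = np.argsort(order)
    if tie_group_map is not None:
        if num_groups is None:
            raise ValueError("num_groups is required when tie_group_map is given")
        groups = np.asarray(tie_group_map)
        # membership[g, p] is True when position p belongs to group g.
        membership = groups[None, :] == np.arange(num_groups)[:, None]
        # Earliest rank among each group's members (n acts as +infinity for non-members).
        group_step = np.where(membership, rank[None, :], n).min(axis=1)
        step = group_step[groups]
    else:
        step = rank
    # Position i may attend to position j when j is decoded no later than i.
    mask = step[None, :] <= step[:, None]
    if chain_idx is not None:
        chains = np.asarray(chain_idx)
        # Restrict attention to positions in the same chain.
        mask = mask & (chains[:, None] == chains[None, :])
    return mask


order = np.array([0, 1, 2])
print(ar_mask_sketch(order).astype(int))
print(ar_mask_sketch(order, tie_group_map=np.array([0, 1, 0]), num_groups=2).astype(int))
```

With the identity order, the first call gives the lower-triangular mask from the Example, and in the second call positions 0 and 2 attend to each other. The real function may assign a step to a tied group differently; the earliest-member rule here is only one plausible choice.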