Source code for prxteinmpnn.utils.autoregression

"""Utilities for autoregression.

prxteinmpnn.utils.autoregression
"""

import jax

from .types import AtomMask, DecodingOrder


@jax.jit
def generate_ar_mask(decoding_order: DecodingOrder) -> AtomMask:
    """Get the autoregressive mask for the given decoding order.

    Entry ``[i, j]`` of the result is ``1`` when
    ``decoding_order[i] >= decoding_order[j]`` and ``0`` otherwise, so the
    diagonal is always ``1`` (every atom can attend to itself).

    Args:
        decoding_order: The order in which atoms are decoded. Assumed to be a
            1-D array of length ``N`` (TODO confirm against ``DecodingOrder``
            in ``.types``); the indexing below adds one axis on each side.

    Returns:
        An atom mask where each atom can only attend to itself and previous
        atoms. For a 1-D input of length ``N`` this is an ``(N, N)`` integer
        array of zeros and ones.

    """
    # Broadcast the order against itself: a column vector (N, 1) compared with
    # a row vector (1, N) yields every pairwise comparison in one (N, N) array.
    row_indices = decoding_order[:, None]
    col_indices = decoding_order[None, :]
    # The comparison produces booleans; cast to integers for use as a mask.
    return (row_indices >= col_indices).astype(int)