# Source code for prxteinmpnn.utils.decoding_order
"""Utilities for decoding order generation.
prxteinmpnn.utils.decoding_order
"""
from collections.abc import Callable
from functools import partial
import jax
from jaxtyping import Int, PRNGKeyArray
from .types import (
DecodingOrder,
)
# Positional inputs of a decoding-order generator: a JAX PRNG key and the
# number of residues (the generators in this module annotate the latter as a
# plain ``int`` and treat it as a static jit argument).
# NOTE(review): jaxtyping's ``Int`` is normally subscripted (e.g. ``Int[Array, ""]``);
# used bare here it may not mean "Python int" to a type checker — confirm intent.
DecodingOrderInputs = tuple[PRNGKeyArray, Int]
# Outputs of a decoding-order generator: the decoding order plus a PRNG key
# that is handed back to the caller.
DecodingOrderOutputs = tuple[DecodingOrder, PRNGKeyArray]
# Call signature shared by decoding-order generators: fn(prng_key, num_residues).
# NOTE(review): ``[*DecodingOrderInputs]`` unpacks a ``tuple[...]`` alias, which
# presumably relies on Python 3.11+ (PEP 646) — confirm the supported version.
DecodingOrderFn = Callable[[*DecodingOrderInputs], DecodingOrderOutputs]
# [docs]
@partial(jax.jit, static_argnames=("num_residues",))
def random_decoding_order(
    prng_key: PRNGKeyArray,
    num_residues: int,
) -> DecodingOrderOutputs:
    """Draw a random decoding order and hand back a fresh PRNG key.

    Args:
        prng_key: JAX PRNG key. It is split in two. The first half drives the
            shuffle, and the second half is returned to the caller.
        num_residues: Length of the order. It is a static argument under
            ``jax.jit``, so every distinct value compiles a separate program.

    Returns:
        A ``(decoding_order, next_key)`` pair:

        - ``decoding_order`` is an ``int32`` array holding a random
          permutation of ``0 .. num_residues - 1``.
        - ``next_key`` is the half of the split that was not consumed here.
    """
    shuffle_key, carry_key = jax.random.split(prng_key)
    # Given an integer, ``permutation`` shuffles ``arange(num_residues)``.
    order = jax.random.permutation(shuffle_key, num_residues)
    return order.astype(jax.numpy.int32), carry_key
# [docs]
@partial(jax.jit, static_argnames=("num_residues",))
def single_decoding_order(
    prng_key: PRNGKeyArray,
    num_residues: int,
) -> DecodingOrderOutputs:
    """Return the decoding order fixed by ``prng_key``, without advancing the key.

    This does NOT return the sequential (identity) order
    ``0 .. num_residues - 1``. It returns a random permutation drawn directly
    from ``prng_key``.

    The key is neither split nor consumed; it is returned unchanged. As a
    result, repeated calls with the same key and ``num_residues`` yield the
    same order. By contrast, ``random_decoding_order`` splits the key and
    returns a new one.

    NOTE(review): an earlier docstring described this as "(identity)". If a
    sequential order was actually intended, the permutation below should be
    replaced by ``jax.numpy.arange(num_residues, dtype=jax.numpy.int32)``.
    Confirm with callers before changing behavior.

    Args:
        prng_key: JAX PRNG key used directly for the shuffle.
        num_residues: Length of the order. It is a static argument under
            ``jax.jit``, so every distinct value compiles a separate program.

    Returns:
        A ``(decoding_order, prng_key)`` pair:

        - ``decoding_order`` is an ``int32`` permutation of
          ``0 .. num_residues - 1``.
        - ``prng_key`` is the input key, returned as-is.
    """
    decoding_order = jax.random.permutation(prng_key, jax.numpy.arange(0, num_residues))
    decoding_order = jax.numpy.asarray(decoding_order, dtype=jax.numpy.int32)
    # The key is deliberately returned unchanged (no split), which is what
    # makes the order reproducible across calls with the same key.
    return decoding_order, prng_key