Source code for prxteinmpnn.utils.aa_convert

"""Utility functions for converting between AlphaFold and ProteinMPNN amino acid orders."""

from collections.abc import Mapping

import jax
import jax.numpy as jnp
import numpy as np

from prxteinmpnn.utils.residue_constants import (
  restype_order,
  restype_order_with_x,
  unk_restype_index,
)
from prxteinmpnn.utils.types import ProteinSequence

# One-letter residue alphabets; a residue's position in the string is its integer index.
MPNN_ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"
AF_ALPHABET = "ARNDCQEGHILKMFPSTWYVX"

# Entry i holds the MPNN_ALPHABET position of the residue found at AF_ALPHABET[i].
_AF_TO_MPNN_PERM = jnp.array(list(map(MPNN_ALPHABET.index, AF_ALPHABET)))

# Entry i holds the AF_ALPHABET position of the residue found at MPNN_ALPHABET[i].
_MPNN_TO_AF_PERM = jnp.array(list(map(AF_ALPHABET.index, MPNN_ALPHABET)))


def af_to_mpnn(sequence: ProteinSequence) -> ProteinSequence:
  """Re-index residues from the AlphaFold alphabet order to the ProteinMPNN order.

  Args:
    sequence: Integer residue indices expressed as positions in `AF_ALPHABET`.

  Returns:
    The same residues expressed as positions in `MPNN_ALPHABET`, cast to int8.

  """
  remapped = _AF_TO_MPNN_PERM[sequence]
  return remapped.astype(jnp.int8)
def mpnn_to_af(sequence: ProteinSequence) -> ProteinSequence:
  """Re-index residues from the ProteinMPNN alphabet order to the AlphaFold order.

  Args:
    sequence: Integer residue indices expressed as positions in `MPNN_ALPHABET`.

  Returns:
    The same residues expressed as positions in `AF_ALPHABET`, cast to int8.

  """
  remapped = _MPNN_TO_AF_PERM[sequence]
  return remapped.astype(jnp.int8)
[docs] def string_key_to_index( string_keys: np.ndarray, key_map: Mapping[str, int], unk_index: int | None = None, ) -> jax.Array: """Convert string keys to integer indices based on a mapping. Efficient vectorized implementation to convert a 1D array of string keys to a 1D array of integer indices using a provided mapping. If a key is not found in the mapping, it is replaced with a specified unknown index. Args: string_keys: A 1D array of string keys. key_map: A dictionary mapping string keys to integer indices. unk_index: The index to use for unknown keys not found in the mapping. If None, uses the length of the key_map as the unknown index. Returns: A 1D array of integer indices corresponding to the string keys. """ if unk_index is None: unk_index = len(key_map) sorted_keys = np.array(sorted(key_map.keys())) sorted_values = np.array([key_map[k] for k in sorted_keys]) indices = np.searchsorted(sorted_keys, string_keys) indices = np.clip(indices, 0, len(sorted_keys) - 1) found_keys = sorted_keys[indices] is_known = found_keys == string_keys return jnp.where(is_known, sorted_values[indices], unk_index)
def string_to_protein_sequence(
  sequence: str,
  aa_map: dict | None = None,
  unk_index: int | None = None,
) -> ProteinSequence:
  """Encode a one-letter amino acid string as integer residue indices.

  Args:
    sequence: The protein sequence, one character per residue.
    aa_map: Mapping from residue character to integer index. If None,
      `restype_order` is used and the result is additionally passed through
      `af_to_mpnn`; a caller-supplied map is applied as-is.
    unk_index: Index assigned to characters missing from the map. If None,
      `unk_restype_index` is used.

  Returns:
    An int8 array of residue indices, one per character of `sequence`.

  """
  use_default_map = aa_map is None
  lookup = restype_order if use_default_map else aa_map
  fallback = unk_restype_index if unk_index is None else unk_index

  # Split into single characters so each residue is looked up independently.
  residues = np.array(list(sequence))
  encoded = string_key_to_index(residues, lookup, fallback).astype(jnp.int8)

  # Only the default-map path is re-ordered into the ProteinMPNN alphabet.
  if use_default_map:
    return af_to_mpnn(encoded)
  return encoded
def protein_sequence_to_string(
  sequence: ProteinSequence,
  aa_map: dict | None = None,
) -> str:
  """Convert a ProteinSequence to a string.

  The indices are first passed through `mpnn_to_af`, then each resulting index is
  looked up in `aa_map`. Indices absent from the map are rendered as "X".

  Args:
    sequence: A ProteinSequence containing amino acid type indices.
    aa_map: A dictionary mapping integer indices to residue characters. If None,
      the inverse of `restype_order_with_x` is used.
      NOTE(review): `mpnn_to_af` is applied even when a custom map is supplied, so
      the map's keys are presumably expected in AlphaFold order - confirm with callers.

  Returns:
    A string with one character per element of `sequence`.

  """
  if aa_map is None:
    aa_map = {index: letter for letter, index in restype_order_with_x.items()}

  af_seq = mpnn_to_af(jnp.asarray(sequence)).astype(np.int32)

  # Convert to a host NumPy array once instead of converting every element of the
  # JAX array individually inside the loop.
  af_indices = np.asarray(af_seq).tolist()

  return "".join(aa_map.get(int(index), "X") for index in af_indices)