Source code for prxteinmpnn.sampling.ste
"""Straight-Through Estimator (STE) for JAX.
prxteinmpnn.sampling.ste
Note: Only use this for discrete optimization problems where you want gradients
to pass through the argmax operation, such as protein sequence optimization when
a model outputs logits over amino acid sequences.
It is unclear whether the optimized sequences will be valid proteins, so this is a
heuristic approach: it enables gradient-based optimization on discrete outputs as a
way to assess how well other samplers are navigating model landscapes.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import jax
import jax.numpy as jnp
if TYPE_CHECKING:
from prxteinmpnn.utils.types import (
AtomMask,
CEELoss,
Logits,
)
# Default stabilizer that `ste_loss` adds to the mask count in its denominator,
# so an all-False mask yields 0 / eps instead of 0 / 0.
DEFAULT_EPS: float = 1.0e-8
[docs]
def straight_through_estimator(logits: Logits) -> Logits:
    """Implement the straight-through estimator (STE).

    Allow gradients to pass through the discrete argmax operation. In the
    forward pass the result is the one-hot encoding of the argmax over the
    last axis. In the backward pass only ``softmax(logits, axis=-1)`` is
    differentiated, because the correction term ``one_hot - probs`` is wrapped
    in ``stop_gradient``.

    Args:
      logits: Unnormalized scores whose last axis indexes the classes.

    Returns:
      An array with the same shape and dtype as the softmax probabilities of
      ``logits``, one-hot along the last axis in the forward pass.

    """
    probs = jax.nn.softmax(logits, axis=-1)
    # Build the one-hot in the probabilities' dtype. Otherwise `jax.nn.one_hot`
    # uses the default float dtype, which upcasts half-precision inputs (and
    # float32 inputs when x64 is enabled).
    one_hot = jax.nn.one_hot(
        jnp.argmax(probs, axis=-1),
        num_classes=probs.shape[-1],
        dtype=probs.dtype,
    )
    # Forward: (one_hot - probs) + probs numerically equals one_hot.
    # Backward: only `probs` carries gradient (the softmax surrogate).
    return jax.lax.stop_gradient(one_hot - probs) + probs
[docs]
def ste_loss(
    logits_to_optimize: Logits,
    target_logits: Logits,
    mask: AtomMask,
    eps: float = DEFAULT_EPS,
) -> CEELoss:
    """Compute the masked mean cross-entropy of an STE one-hot sequence vs. a target.

    ``logits_to_optimize`` is discretized with ``straight_through_estimator``
    (one-hot in the forward pass, softmax gradient in the backward pass). Each
    position is then scored against ``log_softmax(target_logits)``, and the
    per-position losses are averaged over the positions selected by ``mask``.

    Args:
      logits_to_optimize: Logits being optimized, shape
        (sequence_length, num_classes).
      target_logits: Logits defining the target distribution, shape
        (sequence_length, num_classes). For example, an MPNN model's
        unconditional logits that the sequence should match.
      mask: Boolean (or 0/1) per-position mask. Positions where it is
        False/0, such as padding, are multiplied by zero.
      eps: Small constant added to ``mask.sum()`` so that an all-False mask
        does not cause division by zero.

    Returns:
      Scalar loss equal to ``sum(mask * ce) / (sum(mask) + eps)``, where ``ce``
      is the per-position cross-entropy.

    Example:
      >>> logits_to_optimize = jnp.array([[0.1, 0.9], [0.8, 0.2]])
      >>> target_logits = jnp.array([[0.2, 0.8], [0.7, 0.3]])
      >>> mask = jnp.array([True, True])
      >>> loss = ste_loss(logits_to_optimize, target_logits, mask)

    """
    hard_sequence = straight_through_estimator(logits_to_optimize)
    log_target = jax.nn.log_softmax(target_logits, axis=-1)
    # NOTE(review): a -inf entry in target_logits gives a -inf log-prob, and
    # 0 * -inf = nan in the product below. Confirm that targets are finite.
    cross_entropy = -jnp.sum(hard_sequence * log_target, axis=-1)
    weighted_total = (cross_entropy * mask).sum()
    denominator = mask.sum() + eps
    return weighted_total / denominator