ste#
Straight-Through Estimator (STE) for JAX.
prxteinmpnn.sampling.ste
Note: Only use this for discrete optimization problems where you want gradients to pass through the argmax operation. It is useful for tasks like protein sequence optimization, where a model outputs logits over amino acid sequences.
It is unclear whether the optimized sequences will be valid proteins, so treat this as a heuristic: it enables gradient-based optimization over discrete outputs, for example to assess how well other samplers navigate a model's landscape.
- prxteinmpnn.sampling.ste.straight_through_estimator(logits)[source]#
Implement the straight-through estimator (STE).
Allow gradients to pass through the discrete argmax operation.
- Return type:
Float[Array, 'num_residues num_classes']
- Parameters:
logits (Logits)
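A minimal sketch of the standard straight-through trick, assuming the usual `stop_gradient` formulation (the library's exact implementation may differ): the forward pass emits the hard one-hot of the argmax, while the backward pass behaves like the identity on the logits.

```python
import jax
import jax.numpy as jnp


def straight_through_estimator(logits):
    """Hard one-hot forward, identity gradient backward (sketch)."""
    one_hot = jax.nn.one_hot(jnp.argmax(logits, axis=-1), logits.shape[-1])
    # Forward value: one_hot (the logits terms cancel).
    # Backward: stop_gradient blocks the subtracted copy, so the
    # gradient w.r.t. logits is the identity.
    return one_hot + logits - jax.lax.stop_gradient(logits)


logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
hard = straight_through_estimator(logits)          # one-hot values
grad = jax.grad(lambda l: straight_through_estimator(l).sum())(logits)
```

The value of `hard` is the exact one-hot encoding of the argmax, yet `grad` is all ones: gradients flow to the logits as if the operation were the identity.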
- prxteinmpnn.sampling.ste.ste_loss(logits_to_optimize, target_logits, mask, eps=1e-08)[source]#
Calculate cross-entropy between one-hot sequence (from STE) and target distribution.
- Parameters:
logits_to_optimize (Float[Array, 'num_residues num_classes']) – Logits to optimize, shape (sequence_length, num_classes).
target_logits (Float[Array, 'num_residues num_classes']) – Target logits for the sequence, shape (sequence_length, num_classes). These are the logits from the model that we want to match, such as an MPNN model's unconditional logits.
mask (Int[Array, 'num_residues num_atoms']) – Boolean mask indicating valid positions in the sequence. Used to ignore padding or invalid positions.
eps (float) – Small value to avoid division by zero.
- Return type:
Float[Array, '']
- Returns:
Loss value as a scalar.
Example
>>> logits_to_optimize = jnp.array([[0.1, 0.9], [0.8, 0.2]])
>>> target_logits = jnp.array([[0.2, 0.8], [0.7, 0.3]])
>>> mask = jnp.array([True, True])
>>> loss = ste_loss(logits_to_optimize, target_logits, mask)
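To show how the pieces fit together, here is a hedged end-to-end sketch: the `ste_loss` body below is an assumption reconstructed from the docstring (masked mean cross-entropy between the STE one-hot and the softmax of the target logits), not the library's exact code, and the gradient-descent loop is purely illustrative.

```python
import jax
import jax.numpy as jnp


def ste_loss(logits_to_optimize, target_logits, mask, eps=1e-8):
    """Sketch (assumed form): masked-mean cross-entropy between the STE
    one-hot sequence and the target distribution."""
    num_classes = logits_to_optimize.shape[-1]
    one_hot = jax.nn.one_hot(jnp.argmax(logits_to_optimize, axis=-1), num_classes)
    # Straight-through: hard one-hot forward, identity gradient backward.
    one_hot = one_hot + logits_to_optimize - jax.lax.stop_gradient(logits_to_optimize)
    log_probs = jax.nn.log_softmax(target_logits, axis=-1)
    per_position = -jnp.sum(one_hot * log_probs, axis=-1)
    valid = mask.astype(per_position.dtype)
    # eps guards against an all-invalid mask.
    return jnp.sum(per_position * valid) / (jnp.sum(valid) + eps)


target_logits = jnp.array([[0.2, 0.8], [0.7, 0.3]])
mask = jnp.array([True, True])
logits = jnp.array([[0.9, 0.1], [0.1, 0.9]])  # starts with the wrong argmax

# Plain gradient descent on the logits; gradients reach them via the STE.
grad_fn = jax.grad(ste_loss)
for _ in range(100):
    logits = logits - 0.1 * grad_fn(logits, target_logits, mask)
```

After optimization, the argmax of `logits` matches the argmax of `target_logits` at every valid position, which is the behavior the STE is meant to enable.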