ste

Straight-Through Estimator (STE) for JAX.

prxteinmpnn.sampling.ste

Note: Use this only for discrete optimization problems where gradients must pass through the argmax operation. Useful for tasks such as protein sequence optimization, where a model outputs logits over amino acid sequences.

It is unclear whether the optimized sequences will be valid proteins; this is therefore a heuristic that enables gradient-based optimization over discrete outputs, for example to assess how well other samplers navigate a model's landscape.

prxteinmpnn.sampling.ste.straight_through_estimator(logits)

Implement the straight-through estimator (STE).

Allow gradients to pass through the discrete argmax operation.

Parameters:

logits (Logits)

Return type:

Float[Array, 'num_residues num_classes']
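The STE trick is commonly implemented with `jax.lax.stop_gradient`: the forward pass returns the hard one-hot of the argmax, while the backward pass treats the operation as the identity. A minimal sketch (the actual prxteinmpnn implementation may differ):

```python
import jax
import jax.numpy as jnp


def straight_through_estimator(logits):
    """Forward: hard one-hot of argmax. Backward: identity gradient.

    Sketch of the standard STE trick, assumed for illustration.
    """
    num_classes = logits.shape[-1]
    one_hot = jax.nn.one_hot(jnp.argmax(logits, axis=-1), num_classes)
    # stop_gradient blocks the hard path's (zero) gradient; the `+ logits`
    # term routes gradients straight through to the input logits.
    return one_hot + logits - jax.lax.stop_gradient(logits)


logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
out = straight_through_estimator(logits)
# Forward value equals the hard one-hot exactly: [[0., 1.], [1., 0.]]
```

Because the `logits - stop_gradient(logits)` term is numerically zero, the forward output is exactly the one-hot matrix, yet `jax.grad` sees an identity Jacobian with respect to `logits`.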

prxteinmpnn.sampling.ste.ste_loss(logits_to_optimize, target_logits, mask, eps=1e-08)

Calculate cross-entropy between one-hot sequence (from STE) and target distribution.

Parameters:
  • logits_to_optimize (Float[Array, 'num_residues num_classes']) – Logits to optimize, shape (sequence_length, num_classes).

  • target_logits (Float[Array, 'num_residues num_classes']) – Target logits for the sequence, shape (sequence_length, num_classes). These are the logits from the model that we want to match, such as an MPNN model's unconditional logits.

  • mask (Int[Array, 'num_residues num_atoms']) – Boolean mask indicating valid positions in the sequence. Used to ignore padding or invalid positions.

  • eps (float) – Small value to avoid division by zero.

Return type:

Float[Array, '']

Returns:

Loss value as a scalar.

Example

>>> import jax.numpy as jnp
>>> logits_to_optimize = jnp.array([[0.1, 0.9], [0.8, 0.2]])
>>> target_logits = jnp.array([[0.2, 0.8], [0.7, 0.3]])
>>> mask = jnp.array([True, True])
>>> loss = ste_loss(logits_to_optimize, target_logits, mask)
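Putting both functions together, the loss can drive a simple gradient-descent loop over the logits. The sketch below assumes a plausible form for both functions (hard one-hot forward pass, masked cross-entropy against the target's log-softmax with `eps` guarding the mask-sum division); the actual prxteinmpnn implementations may differ:

```python
import jax
import jax.numpy as jnp


def straight_through_estimator(logits):
    # Assumed STE: hard one-hot forward, identity gradient backward.
    one_hot = jax.nn.one_hot(jnp.argmax(logits, axis=-1), logits.shape[-1])
    return one_hot + logits - jax.lax.stop_gradient(logits)


def ste_loss(logits_to_optimize, target_logits, mask, eps=1e-8):
    # Assumed form: cross-entropy between the STE one-hot and the
    # target distribution, averaged over valid (masked) positions.
    one_hot = straight_through_estimator(logits_to_optimize)
    log_probs = jax.nn.log_softmax(target_logits, axis=-1)
    per_residue = -(one_hot * log_probs).sum(axis=-1)
    mask = mask.astype(jnp.float32)
    return (per_residue * mask).sum() / (mask.sum() + eps)


# Plain gradient descent on the logits toward the target distribution.
target = jnp.array([[0.2, 0.8], [0.7, 0.3]])
mask = jnp.array([True, True])
opt_logits = jnp.zeros_like(target)
grad_fn = jax.grad(ste_loss)
for _ in range(100):
    opt_logits = opt_logits - 0.1 * grad_fn(opt_logits, target, mask)
# The optimized argmax sequence now matches the target's per-residue argmax.
```

In practice an optimizer such as Adam (e.g. via optax) would replace the hand-rolled update, but the mechanics are identical: the STE makes the discrete argmax differentiable enough for this loop to make progress.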