ste

Straight-Through Estimator (STE) for JAX.

prxteinmpnn.sampling.ste

Note: Use this only for discrete optimization problems in which gradients need to pass through an argmax operation. It is useful for tasks such as protein sequence optimization, where a model outputs logits over amino acids at each sequence position.
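A minimal sketch of the straight-through trick in JAX follows. It assumes the common softmax-surrogate variant: the forward pass returns the hard one-hot encoding of the argmax, while gradients are taken from the softmax of the logits. The name `straight_through_one_hot` is hypothetical and is not necessarily the API of `prxteinmpnn.sampling.ste`, which may use a different backward surrogate (for example, the identity).

```python
import jax
import jax.numpy as jnp


def straight_through_one_hot(logits: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: hard one-hot forward, softmax gradient backward.

    A sketch of the softmax straight-through variant; not necessarily how
    prxteinmpnn.sampling.ste implements it.
    """
    soft = jax.nn.softmax(logits, axis=-1)
    hard = jax.nn.one_hot(
        jnp.argmax(logits, axis=-1), logits.shape[-1], dtype=soft.dtype
    )
    # Value equals `hard` (up to rounding); stop_gradient removes the
    # (hard - soft) term from differentiation, so gradients flow through `soft`.
    return soft + jax.lax.stop_gradient(hard - soft)


# Two positions, three toy "residue types".
logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.3, 0.2]])
print(straight_through_one_hot(logits))  # approximately [[1, 0, 0], [0, 1, 0]]
```

Because `jax.lax.stop_gradient` blocks the `hard - soft` term, the output's value is the one-hot sequence while its Jacobian is that of `softmax`. Differentiating through `jnp.argmax` alone would give no useful gradient.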

It is unclear whether the optimized sequences will be valid proteins, so this is a heuristic approach: it enables gradient-based optimization over discrete outputs in order to assess how well other samplers navigate model landscapes.
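As an illustration of this kind of heuristic optimization, the sketch below runs plain gradient ascent on per-position logits through a straight-through one-hot encoding. The objective is a hypothetical stand-in: a random position-specific score matrix `scores` replaces the model, and `toy_objective`, `optimize_logits`, the step count, and the learning rate are illustrative assumptions, not part of `prxteinmpnn`.

```python
import jax
import jax.numpy as jnp


def straight_through_one_hot(logits):
    # Hypothetical helper: hard one-hot forward, softmax gradient backward.
    soft = jax.nn.softmax(logits, axis=-1)
    hard = jax.nn.one_hot(
        jnp.argmax(logits, axis=-1), logits.shape[-1], dtype=soft.dtype
    )
    return soft + jax.lax.stop_gradient(hard - soft)


def toy_objective(logits, scores):
    # Hypothetical stand-in for a model score: the summed per-position
    # scores of the discrete (one-hot) sequence.
    return jnp.sum(straight_through_one_hot(logits) * scores)


def optimize_logits(scores, steps=100, learning_rate=0.5):
    logits = jnp.zeros_like(scores)
    # Gradient with respect to the logits (first argument) only.
    grad_fn = jax.jit(jax.grad(toy_objective))
    for _ in range(steps):
        # Gradient ascent: move logits toward higher-scoring residues.
        logits = logits + learning_rate * grad_fn(logits, scores)
    return logits


num_positions, num_amino_acids = 8, 20
scores = jax.random.normal(jax.random.PRNGKey(0), (num_positions, num_amino_acids))
logits = optimize_logits(scores)
designed = jnp.argmax(logits, axis=-1)  # discrete sequence as residue indices
```

In this toy setting the optimizer simply recovers the best-scoring residue at each position. With a real model the landscape is non-convex, and nothing guarantees that the resulting sequence is a valid protein.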