ste
Straight-Through Estimator (STE) for JAX.
prxteinmpnn.sampling.ste
Note: Use this only for discrete optimization problems in which gradients must pass through the otherwise non-differentiable argmax operation. A typical case is protein sequence optimization, where a model outputs per-position logits over amino acids.
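The sketch below shows the general straight-through idea in JAX. It is an illustration under stated assumptions, not this module's implementation. The forward pass emits a hard one-hot vector taken from the argmax. The backward pass uses the gradient of a softmax relaxation instead. `straight_through_one_hot` is a hypothetical name, and the softmax backward is one common STE variant that the source does not confirm.

```python
import jax
import jax.numpy as jnp


def straight_through_one_hot(logits: jax.Array) -> jax.Array:
    """Hypothetical STE helper: hard one-hot forward, softmax gradient backward.

    Not necessarily the API of prxteinmpnn.sampling.ste; shown for illustration.
    """
    # Differentiable relaxation used only for the backward pass.
    soft = jax.nn.softmax(logits, axis=-1)
    # Discrete choice used for the forward value.
    hard = jax.nn.one_hot(jnp.argmax(logits, axis=-1), logits.shape[-1], dtype=soft.dtype)
    # Forward value equals `hard` (up to float rounding). stop_gradient blocks the
    # (hard - soft) term, so gradients flow through `soft` as if no argmax were taken.
    return soft + jax.lax.stop_gradient(hard - soft)


logits = jnp.array([[2.0, 0.5, -1.0]])
sample = straight_through_one_hot(logits)  # forward value: one-hot at index 0 (up to rounding)
```

This design keeps the forward computation fully discrete, so downstream code sees a valid one-hot sequence. The gradient is biased, because it is the softmax gradient rather than the true (zero almost everywhere) gradient of argmax. Other variants exist, such as an identity backward pass or a temperature-scaled Gumbel-softmax relaxation.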
It is unclear whether the optimized sequences will be valid proteins. The method is therefore a heuristic: it enables gradient-based optimization over discrete outputs, mainly to assess how well other samplers navigate the model's landscape.
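The following toy loop sketches how such gradient-based optimization over discrete outputs might look. It is a minimal example under loud assumptions. `toy_score` is a hypothetical linear stand-in for a model's score, built from a random per-position preference matrix rather than from any real model. `optimize_logits`, the sequence length, and the learning rate are also illustrative, not part of the documented module. Each step scores the discrete argmax sequence in the forward pass and updates the continuous logits using STE gradients.

```python
import jax
import jax.numpy as jnp

NUM_AMINO_ACIDS = 20  # standard amino-acid alphabet size
SEQ_LEN = 8           # hypothetical toy sequence length


def straight_through_one_hot(logits: jax.Array) -> jax.Array:
    """Hypothetical STE: hard one-hot forward, softmax gradient backward."""
    soft = jax.nn.softmax(logits, axis=-1)
    hard = jax.nn.one_hot(jnp.argmax(logits, axis=-1), logits.shape[-1], dtype=soft.dtype)
    return soft + jax.lax.stop_gradient(hard - soft)


def toy_score(one_hot_seq: jax.Array, preferences: jax.Array) -> jax.Array:
    """Hypothetical stand-in for a model score: sum of per-position preferences."""
    return jnp.sum(one_hot_seq * preferences)


def optimize_logits(preferences: jax.Array, steps: int = 50, learning_rate: float = 1.0) -> jax.Array:
    """Gradient ascent on continuous logits through the STE (illustrative only)."""
    logits = jnp.zeros((SEQ_LEN, NUM_AMINO_ACIDS))

    def objective(lg: jax.Array) -> jax.Array:
        # The forward pass evaluates the score on the discrete argmax sequence.
        return toy_score(straight_through_one_hot(lg), preferences)

    grad_fn = jax.jit(jax.grad(objective))
    for _ in range(steps):
        logits = logits + learning_rate * grad_fn(logits)  # ascend the score
    return logits


preferences = jax.random.normal(jax.random.PRNGKey(0), (SEQ_LEN, NUM_AMINO_ACIDS))
final_logits = optimize_logits(preferences)
sequence = jnp.argmax(final_logits, axis=-1)  # discrete residue indices per position
```

Because this toy score is linear and separable, the loop simply recovers the per-position best residue. With a real model's score, the landscape is non-convex. As noted above, nothing here guarantees that the resulting sequence is a valid protein.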