Straight-Through Estimator
The straight-through estimator (STE) is a technique for training models that contain non-differentiable operations, such as the quantization step used in VQ-VAE.
The core idea is:
Use the non-differentiable operation in the forward pass, but replace its gradient with a simple surrogate in the backward pass.
The problem
Some operations are useful in neural networks but are not differentiable.
For example, quantization:
\[ y = \operatorname{round}(x). \]
In the forward pass, this maps a continuous value to a discrete value.
1.2 -> 1
But \(\operatorname{round}(\cdot)\) has zero gradient almost everywhere (and no gradient at all at its jump points), so ordinary backpropagation passes no useful learning signal to the earlier layers.
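As a quick numerical check of the zero-gradient claim, the sketch below estimates the slope of rounding with a central finite difference. It relies only on Python's built-in `round`; the helper name `finite_difference` and the step size `h` are illustrative choices, not something taken from the text above.

```python
def finite_difference(f, x, h=1e-3):
    """Central-difference estimate of df/dx at x."""
    return (f(x + h) - f(x - h)) / (2 * h)

# Away from the jump points, rounding is locally constant, so the slope is 0.
for x in (0.2, 1.2, 2.7):
    print(x, "->", round(x), "slope:", finite_difference(round, x))

# Across a jump point the estimate is 1 / (2 * h): it grows as h shrinks
# instead of settling on a usable value.
print("slope near 1.5:", finite_difference(round, 1.5))
```

Either way, the true derivative gives gradient descent nothing to work with: it is zero on the flat pieces and undefined at the jumps.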
Straight-through idea
STE uses the real operation in the forward pass:
\[ y = \operatorname{round}(x). \]
In the backward pass, however, it treats the operation as if it were the identity:
\[ \frac{\partial y}{\partial x} \approx 1. \]
So the model behaves discretely in the forward computation, but gradients can still flow backward.
forward:  y = round(x)
backward: dy/dx = 1
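The forward/backward split can be written out by hand as a pair of functions. This is a minimal sketch, assuming a scalar input and an illustrative downstream loss \(L = (y - 3)^2\); the names `ste_round_forward` and `ste_round_backward` are hypothetical.

```python
def ste_round_forward(x):
    """Forward pass: apply the real, non-differentiable rounding."""
    return float(round(x))

def ste_round_backward(grad_y):
    """Backward pass: treat rounding as identity, so dy/dx is taken to be 1."""
    return grad_y * 1.0

x = 1.2
y = ste_round_forward(x)             # discrete forward value

# Illustrative downstream loss L = (y - 3)^2, so dL/dy = 2 * (y - 3).
grad_y = 2.0 * (y - 3.0)
grad_x = ste_round_backward(grad_y)  # STE: dL/dx = dL/dy * 1

print("y =", y, "dL/dx with STE =", grad_x)
```

With the true derivative of rounding, `grad_x` would be zero here; with STE the upstream gradient reaches `x` unchanged.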
Stop-gradient form
A common implementation is:
\[ y = x + \operatorname{sg}(\operatorname{round}(x) - x), \]
where \(\operatorname{sg}(\cdot)\) means stop-gradient.
In the forward pass:
\[ y = x + (\operatorname{round}(x) - x) = \operatorname{round}(x). \]
In the backward pass, the stopped part has no gradient, so:
\[ \frac{\partial y}{\partial x} = 1. \]
Thus, the forward value is quantized, but the backward gradient is copied through.
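To see the stop-gradient identity at work without any framework, the sketch below tracks a value together with its derivative with respect to \(x\) (a toy forward-mode scheme). The `Dual` class and the helpers `hard_round`, `sg`, and `ste_round` are hypothetical names introduced for illustration; in practice a framework's own stop-gradient primitive would play the role of `sg` (for example `detach()` in PyTorch or `jax.lax.stop_gradient` in JAX).

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Dual:
    """A value paired with its derivative with respect to the input x."""
    val: float
    grad: float

    def __add__(self, other):
        return Dual(self.val + other.val, self.grad + other.grad)

    def __sub__(self, other):
        return Dual(self.val - other.val, self.grad - other.grad)

def hard_round(d):
    # The true derivative of rounding is 0 almost everywhere.
    return Dual(float(round(d.val)), 0.0)

def sg(d):
    # Stop-gradient: keep the value, drop the derivative.
    return Dual(d.val, 0.0)

def ste_round(x):
    # y = x + sg(round(x) - x)
    return x + sg(hard_round(x) - x)

x = Dual(1.2, 1.0)   # dx/dx = 1
y = ste_round(x)
print("value:", y.val, "derivative:", y.grad)
print("plain rounding derivative:", hard_round(x).grad)
```

The value of `y` matches plain rounding, while its derivative comes entirely from the leading `x` term.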
Why it is useful
STE is commonly used when we want discrete behavior but still want gradient-based training.
Examples include:
binary neural networks
For example, in a discrete latent world model, an encoder \(e_\theta\) may map an input \(o\) to a continuous vector, which a quantizer \(f\) then turns into a discrete code \(c\):
\[ c = f(e_\theta(o)). \]
The quantization step is not differentiable. STE allows gradients from later losses to flow back into the encoder.
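A toy version of this setup shows the encoder parameter being trained through the quantizer. Everything specific here is an assumption made for illustration: a scalar encoder \(e_\theta(o) = \theta o\), a hypothetical four-entry codebook, nearest-code quantization for \(f\), and a squared-error loss on the code.

```python
CODEBOOK = [0.0, 1.0, 2.0, 3.0]   # hypothetical 1-D codes

def encode(theta, o):
    """Toy encoder e_theta(o) = theta * o."""
    return theta * o

def quantize(e):
    """f: snap a continuous value to the nearest codebook entry."""
    return min(CODEBOOK, key=lambda code: abs(code - e))

def train(theta, o, target, lr=0.1, steps=20):
    for _ in range(steps):
        e = encode(theta, o)
        c = quantize(e)               # discrete forward value
        grad_c = 2.0 * (c - target)   # dL/dc for L = (c - target)^2
        grad_e = grad_c               # STE: copy the gradient through f
        grad_theta = grad_e * o       # chain rule through the encoder
        theta -= lr * grad_theta
    return theta

theta = train(theta=0.2, o=1.0, target=3.0)
print("theta =", theta, "code =", quantize(encode(theta, 1.0)))
```

If `grad_e` were set to the true gradient of the quantizer (zero), `theta` would never move from its initial value.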
Relation to ST Gumbel-Softmax
Straight-through Gumbel-Softmax is one application of STE.
It uses a hard discrete sample in the forward pass, but uses the soft Gumbel-Softmax relaxation for gradients in the backward pass.
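STE:               forward = discrete operation, backward = surrogate gradient
ST Gumbel-Softmax: forward = hard sample,        backward = gradient of the soft relaxation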
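The sketch below spells out the two ingredients of straight-through Gumbel-Softmax for a single sample, using only the standard library. The function names, the temperature `tau`, and the example logits are illustrative assumptions. The Jacobian is that of the soft relaxation with respect to the logits, which is what the backward pass would use in place of the (zero) gradient of the hard one-hot sample.

```python
import math
import random

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [v / total for v in exps]

def st_gumbel_softmax(logits, tau, rng):
    """Return (hard, soft, jac) for one Gumbel-Softmax sample.

    hard: one-hot sample used as the forward value.
    soft: relaxed sample whose gradient is used in the backward pass.
    jac:  jac[i][j] = d soft[i] / d logits[j], the surrogate Jacobian.
    """
    # Gumbel(0, 1) noise: g = -log(-log(u)), with u kept inside (0, 1).
    noise = []
    for _ in logits:
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noise.append(-math.log(-math.log(u)))
    soft = softmax([(l + g) / tau for l, g in zip(logits, noise)])
    n = len(soft)
    k = max(range(n), key=soft.__getitem__)
    hard = [1.0 if i == k else 0.0 for i in range(n)]
    jac = [[soft[i] * ((1.0 if i == j else 0.0) - soft[j]) / tau
            for j in range(n)] for i in range(n)]
    return hard, soft, jac

rng = random.Random(0)
hard, soft, jac = st_gumbel_softmax([1.0, 0.5, -0.5], tau=1.0, rng=rng)
print("forward (hard):", hard)
print("soft relaxation:", [round(p, 3) for p in soft])
```

In a framework with automatic differentiation, the same effect is usually obtained with the stop-gradient form `soft + sg(hard - soft)`.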
Summary
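STE is a biased gradient estimator, because the surrogate gradient used in the backward pass is not the gradient of the operation actually applied in the forward pass. It is nevertheless simple and often effective.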
Its key idea is:
forward:  use the discrete operation
backward: use a simple surrogate gradient (for example, identity)
In short:
STE lets us train models with discrete operations using ordinary backpropagation.