Straight-Through Estimator

The straight-through estimator (STE) is a technique for training models that contain non-differentiable operations, such as the quantization step used in VQ-VAE.

The core idea is:

Use the non-differentiable operation in the forward pass, but replace its gradient with a simple surrogate in the backward pass.

The problem

Some operations are useful in neural networks but are not differentiable.

For example, quantization:

\[ y = \operatorname{round}(x). \]

In the forward pass, this maps a continuous value to a discrete value.

1.2 -> 1
1.8 -> 2

But \(\operatorname{round}(\cdot)\) is piecewise constant: its gradient is zero almost everywhere and undefined at the jump points. Ordinary backpropagation therefore passes no useful gradient signal to the earlier layers.
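A quick numerical check makes the zero-gradient claim concrete. The sketch below uses Python's built-in `round` and a central finite difference; the helper name `finite_difference` and the sample points are illustrative choices, not part of any library.

```python
def finite_difference(f, x, eps=1e-3):
    """Central-difference estimate of df/dx at x."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

# Sample points chosen away from the jump locations (the half-integers).
points = [0.2, 1.2, 1.8, 2.7]
slopes = [finite_difference(round, x) for x in points]
print(slopes)  # [0.0, 0.0, 0.0, 0.0]
```

Every estimated slope is zero, so a loss computed from `round(x)` gives the layers before it nothing to learn from.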

Straight-through idea

STE uses the real operation in the forward pass:

\[ y = \operatorname{round}(x). \]

But in the backward pass, it pretends that the operation was the identity:

\[ \frac{\partial y}{\partial x} \approx 1. \]

So the model behaves discretely in the forward computation, but gradients can still flow backward.

forward:
    use round(x)

backward:
    pretend round(x) = x
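The forward/backward split above can be sketched with a hand-written backward rule. This is a minimal scalar illustration, not a framework implementation: the class `RoundSTE`, the toy pipeline, and the values of `w`, `a`, and `target` are all assumptions made for the example.

```python
class RoundSTE:
    """Rounding with a straight-through backward rule (hypothetical helper)."""

    @staticmethod
    def forward(x):
        # Forward pass: apply the real, non-differentiable operation.
        return float(round(x))

    @staticmethod
    def backward(grad_output):
        # Backward pass: pretend the forward pass was the identity,
        # so the incoming gradient is passed through unchanged.
        return grad_output * 1.0

# Toy pipeline: x = w * a, y = round(x), loss = (y - target)^2
w, a, target = 0.6, 3.0, 1.0
x = w * a                            # about 1.8
y = RoundSTE.forward(x)              # 2.0
loss = (y - target) ** 2             # 1.0

# Manual chain rule, with the STE rule used for the rounding step.
grad_y = 2.0 * (y - target)          # d loss / d y = 2.0
grad_x = RoundSTE.backward(grad_y)   # copied through: 2.0
grad_w = grad_x * a                  # d loss / d w = 6.0
print(y, grad_w)  # 2.0 6.0
```

With the true derivative of rounding, `grad_x` and therefore `grad_w` would be zero; the straight-through rule is what lets `w` receive a nonzero update signal.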

Stop-gradient form

A common implementation is:

\[ y = x + \operatorname{sg}(\operatorname{round}(x) - x), \]

where \(\operatorname{sg}(\cdot)\) means stop-gradient.

In the forward pass:

\[ y = x + (\operatorname{round}(x) - x) = \operatorname{round}(x). \]

In the backward pass, the stopped part has no gradient, so:

\[ \frac{\partial y}{\partial x} = 1. \]

Thus, the forward value is quantized, but the backward gradient is copied through.
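To see the stop-gradient identity work mechanically, here is a small sketch that tracks a value and its derivative together (forward-mode dual numbers). The `Dual` class and the helpers `hard_round`, `sg`, and `round_ste` are hypothetical names written for this illustration. In real frameworks the stop-gradient is usually provided for you, for example `.detach()` in PyTorch or `jax.lax.stop_gradient` in JAX.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Dual:
    """A value paired with its derivative with respect to the input x."""
    val: float
    grad: float

    def __add__(self, other):
        return Dual(self.val + other.val, self.grad + other.grad)

    def __sub__(self, other):
        return Dual(self.val - other.val, self.grad - other.grad)

def hard_round(d):
    # True rounding: piecewise constant, so its derivative is 0.
    return Dual(float(round(d.val)), 0.0)

def sg(d):
    # Stop-gradient: keep the value, drop the derivative.
    return Dual(d.val, 0.0)

def round_ste(x):
    # y = x + sg(round(x) - x)
    return x + sg(hard_round(x) - x)

x = Dual(1.8, 1.0)        # seed: dx/dx = 1
y_plain = hard_round(x)   # value 2.0, derivative 0.0
y_ste = round_ste(x)      # value 2.0, derivative 1.0
print(y_plain, y_ste)
```

Both versions produce the same forward value, but only the stop-gradient form carries a derivative of 1 back to `x`.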

Why it is useful

STE is commonly used when we want discrete behavior but still want gradient-based training.

Examples include:

binary neural networks
quantized neural networks
vector quantization
finite scalar quantization
discrete latent variables
straight-through Gumbel-Softmax

For example, in a discrete latent world model, an encoder may produce a continuous vector and then quantize it into a code:

\[ c = f(e_\theta(o)). \]

Here \(e_\theta\) is the encoder, \(o\) is its input, and \(f\) is the quantization step. Because \(f\) is not differentiable, STE is used so that gradients from later losses can flow back into the encoder.
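As a rough illustration of this setup, the sketch below assumes that \(f\) is a nearest-neighbour lookup in a small codebook, as in vector quantization. The codebook entries, the vector `z` standing in for \(e_\theta(o)\), and the downstream `target` are made-up values, and the squared-error loss is only a placeholder for "later losses".

```python
def nearest_code(z, codebook):
    """Return (index, vector) of the codebook entry closest to z in squared L2."""
    def sq_dist(c):
        return sum((zi - ci) ** 2 for zi, ci in zip(z, c))
    idx = min(range(len(codebook)), key=lambda k: sq_dist(codebook[k]))
    return idx, codebook[idx]

codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]  # hypothetical codes
z = [0.8, 0.7]        # stand-in for the encoder output e_theta(o)
target = [1.0, 0.0]   # hypothetical downstream target

# Forward pass: the discrete lookup is applied for real.
idx, c = nearest_code(z, codebook)
loss = sum((ci - ti) ** 2 for ci, ti in zip(c, target))

# Backward pass: gradient of the loss with respect to the code c ...
grad_c = [2.0 * (ci - ti) for ci, ti in zip(c, target)]
# ... copied straight through to the encoder output z (the STE step).
grad_z = list(grad_c)
print(idx, c, grad_z)  # 1 [1.0, 1.0] [0.0, 2.0]
```

From `grad_z` onward, ordinary backpropagation through the encoder can proceed as if the lookup had been the identity.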

Relation to ST Gumbel-Softmax

Straight-through Gumbel-Softmax is one application of STE.

It uses a hard discrete sample in the forward pass, but uses the soft Gumbel-Softmax relaxation for gradients in the backward pass.

STE:
    general trick for non-differentiable operations

ST Gumbel-Softmax:
    STE applied to categorical sampling
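A minimal pure-Python sketch of the straight-through Gumbel-Softmax idea follows. The function names, the temperature `tau`, the logits, and the upstream gradient are assumptions for illustration. The backward rule is the gradient of the soft sample \(\operatorname{softmax}((\text{logits} + g)/\tau)\) with respect to the logits, applied to the upstream gradient as if the soft sample had been the output.

```python
import math
import random

def sample_gumbel(rng):
    # Gumbel(0, 1) noise via inverse transform; clamp u to avoid log(0).
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def st_gumbel_softmax(logits, tau, rng):
    """Forward pass: return the hard one-hot sample and the soft relaxation."""
    noisy = [(l + sample_gumbel(rng)) / tau for l in logits]
    y_soft = softmax(noisy)
    k = max(range(len(y_soft)), key=lambda i: y_soft[i])
    y_hard = [1.0 if i == k else 0.0 for i in range(len(y_soft))]
    return y_hard, y_soft

def st_backward(grad_output, y_soft, tau):
    """Backward pass: gradient w.r.t. the logits through the soft relaxation."""
    dot = sum(g * y for g, y in zip(grad_output, y_soft))
    return [y * (g - dot) / tau for g, y in zip(grad_output, y_soft)]

rng = random.Random(0)
logits = [1.0, 2.0, 0.5]
y_hard, y_soft = st_gumbel_softmax(logits, tau=1.0, rng=rng)
# Hypothetical upstream gradient arriving at the hard sample.
grad_logits = st_backward([1.0, 0.0, 0.0], y_soft, tau=1.0)
print(y_hard, grad_logits)
```

Downstream computation sees only the one-hot `y_hard`, while the logits receive the gradient of the soft relaxation.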

Summary

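STE is a biased gradient estimator, because the backward pass uses a surrogate that differs from the true gradient of the forward operation, but it is simple and often effective.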

Its key idea is:

forward:
    use the real discrete or non-differentiable operation

backward:
    replace its gradient with a simple differentiable surrogate

In short:

STE lets us train models with discrete operations using ordinary backpropagation.