Gumbel-Softmax Sampling
Straight-through Gumbel-softmax is a trick for training models with discrete latent variables, similar to the well-known straight-through estimator (STE) used in VQ-VAE.
The problem is simple:
We want to sample a discrete category in the forward pass, but we still want gradients to flow through it in the backward pass.
A motivating example
Suppose an image observation \(o_t\) is encoded into a discrete latent representation.
First, an encoder maps the image into a continuous vector:
\[ x_t = e_\theta(o_t). \]
Then this vector is quantized into a discrete code:
\[ c_t \in C. \]
Assume the codebook has \(|C|\) possible codes:
\[ C = \{c^{(1)}, c^{(2)}, \dots, c^{(|C|)}\}. \]
Now consider a world model that predicts the next latent code given the current code and action:
\[ p_\phi(c_{t+1} \mid c_t, a_t). \]
Because \(c_{t+1}\) is discrete, the dynamics model can be viewed as a classifier over \(|C|\) possible next codes.
It outputs logits:
\[ l_1, l_2, \dots, l_{|C|}. \]
After softmax, we get probabilities:
\[ p_i = \frac{\exp(l_i)} {\sum_{j=1}^{|C|}\exp(l_j)}. \]
So the next code is sampled from a categorical distribution:
\[ c_{t+1} \sim \operatorname{Categorical}(p_1, p_2, \dots, p_{|C|}). \]
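The softmax and sampling steps above can be sketched in a few lines of plain Python. This is only an illustration: the three logit values and the helper names `softmax` and `sample_code` are assumptions, not part of any particular world-model codebase.

```python
import math
import random

def softmax(logits):
    """Convert logits to probabilities; subtracting the max avoids overflow."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_code(logits, rng):
    """Draw one code index (0-based) from Categorical(softmax(logits))."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

rng = random.Random(0)
logits = [0.5, 2.0, 1.0]      # illustrative logits for a 3-code codebook
probs = softmax(logits)       # p_1, p_2, p_3
next_code = sample_code(logits, rng)
```

The sampled index is an integer, which is exactly what makes the next step hard to differentiate.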
The problem with discrete sampling
Discrete sampling is not differentiable.
For example, suppose the model predicts:
code 1: 0.1
Then we sample one code, such as:
sampled code = code 2
The sampling operation makes a hard decision. A tiny change in the probabilities may not change the sampled code at all, or it may suddenly change it completely.
Therefore, the gradient cannot naturally pass through the sampling step.
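To make the "no change, or a sudden jump" behavior concrete, here is a small sketch that fixes the random draw `u` and samples by inverse CDF. The probability vectors are made-up illustrative values, and `sample_with_fixed_noise` is a hypothetical helper.

```python
def sample_with_fixed_noise(probs, u):
    """Inverse-CDF sampling: return the first index whose cumulative probability exceeds u."""
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return i
    return len(probs) - 1

u = 0.5  # the same random draw is reused for every call below

# A small change far from a decision boundary leaves the sample unchanged.
a = sample_with_fixed_noise([0.10, 0.70, 0.20], u)
b = sample_with_fixed_noise([0.11, 0.69, 0.20], u)

# A small change across a decision boundary flips the sample.
c = sample_with_fixed_noise([0.499, 0.301, 0.20], u)
d = sample_with_fixed_noise([0.501, 0.299, 0.20], u)
```

As a function of the probabilities, the sampled index is piecewise constant: its derivative is zero almost everywhere and undefined at the jumps.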
This is a problem when we want to train a multi-step world model by backpropagation through time:
c_t -> c_{t+1} -> c_{t+2} -> ...
If the sampling operation blocks gradients, the model cannot be trained end-to-end in the usual way.
Gumbel-max trick
The Gumbel-max trick gives a way to sample from a categorical distribution using logits.
Given logits \(l_i\), we sample independent Gumbel noise \(g_i\):
\[ g_i = -\log(-\log u_i), \]
where
\[ u_i \sim \operatorname{Uniform}(0,1). \]
Then categorical sampling can be written as:
\[ k = \arg\max_i (l_i + g_i). \]
This produces a discrete sample \(k\) with \(P(k = i) = p_i\), that is, an exact sample from the categorical distribution defined by the logits.
However, \(\arg\max\) is still not differentiable.
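A quick empirical check of the Gumbel-max trick, as a sketch with assumed logits: draw many samples with `argmax(l_i + g_i)` and compare the observed frequencies with the softmax probabilities. The function names are illustrative.

```python
import math
import random
from collections import Counter

def gumbel_noise(rng):
    """g = -log(-log(u)) with u ~ Uniform(0, 1); the max() guards against u == 0."""
    u = max(rng.random(), 1e-12)
    return -math.log(-math.log(u))

def gumbel_max_sample(logits, rng):
    """Return argmax_i (l_i + g_i) as a 0-based index."""
    perturbed = [l + gumbel_noise(rng) for l in logits]
    return max(range(len(logits)), key=lambda i: perturbed[i])

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
logits = [0.5, 2.0, 1.0]   # illustrative values
n = 20000
counts = Counter(gumbel_max_sample(logits, rng) for _ in range(n))
freqs = [counts[i] / n for i in range(len(logits))]
probs = softmax(logits)
```

With enough samples, `freqs` should be close to `probs`, which is the sense in which the trick reproduces categorical sampling.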
Gumbel-softmax
Gumbel-softmax replaces the hard \(\arg\max\) with a softmax relaxation:
\[ y_i = \frac{ \exp((l_i + g_i)/\tau) }{ \sum_j \exp((l_j + g_j)/\tau) }. \]
Here \(\tau\) is the temperature.
When \(\tau\) is large, the output is soft:
[0.25, 0.45, 0.30]
When \(\tau\) is small, the output becomes close to one-hot:
[0.01, 0.98, 0.01]
So Gumbel-softmax gives a differentiable approximation to categorical sampling.
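The effect of the temperature can be sketched by reusing the same Gumbel noise with two different values of `tau`. The logits and the temperatures 5.0 and 0.1 are arbitrary choices for illustration.

```python
import math
import random

def sample_gumbels(n, rng):
    """Draw n independent Gumbel(0, 1) values via g = -log(-log(u))."""
    return [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in range(n)]

def gumbel_softmax(logits, gumbels, tau):
    """Soft sample y_i = softmax((l_i + g_i) / tau), computed stably."""
    scaled = [(l + g) / tau for l, g in zip(logits, gumbels)]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
logits = [0.5, 2.0, 1.0]
gumbels = sample_gumbels(len(logits), rng)          # shared noise, so only tau differs
y_hot = gumbel_softmax(logits, gumbels, tau=5.0)    # high temperature: spread out
y_cold = gumbel_softmax(logits, gumbels, tau=0.1)   # low temperature: near one-hot
```

Both outputs put their largest weight on the same index; lowering `tau` only sharpens the vector toward that index.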
The forward result is no longer a hard category, but a soft vector on the probability simplex:
\[ y \in \Delta^{|C|-1}. \]
This vector can be used as a soft mixture of codebook vectors:
\[ \tilde{c} = \sum_i y_i c^{(i)}. \]
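As a small worked example of the mixture, assume a hypothetical codebook of three 2-dimensional code vectors (these values are invented for illustration):

```python
def mix_codes(y, codebook):
    """Weighted sum of codebook vectors: c_tilde = sum_i y_i * c^(i)."""
    dim = len(codebook[0])
    return [sum(w * code[d] for w, code in zip(y, codebook)) for d in range(dim)]

# Hypothetical codebook: 3 codes, each a 2-dimensional vector.
codebook = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

soft = mix_codes([0.25, 0.45, 0.30], codebook)   # a blend of all three codes
hard = mix_codes([0.0, 1.0, 0.0], codebook)      # a one-hot y selects code 2 exactly
```

With a one-hot `y` the mixture reduces to a plain codebook lookup, so the soft case is a relaxation of the usual discrete selection.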
Straight-through Gumbel-softmax
Straight-through Gumbel-softmax combines two ideas:
forward pass: hard one-hot sample
backward pass: gradient of the soft Gumbel-softmax sample
In the forward pass, we convert the soft sample \(y\) into a hard one-hot vector:
\[ y_{\text{hard}} = \operatorname{onehot}(\arg\max_i y_i). \]
But in the backward pass, we pretend that the output was the soft differentiable sample \(y\).
A common implementation is:
\[ y_{\text{ST}} = y_{\text{hard}} + y - \operatorname{sg}(y), \]
or equivalently:
\[ y_{\text{ST}} = y + \operatorname{sg}(y_{\text{hard}} - y). \]
Here \(\operatorname{sg}(\cdot)\) means stop-gradient.
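Without an autodiff library, the same idea can be sketched by hand: return the hard one-hot vector as the forward value, and use the Jacobian of the soft sample \(y\) in the backward pass. This is an illustrative, manual version; in practice a framework's stop-gradient operation would do this automatically. The function names and the upstream gradient are assumptions.

```python
import math
import random

def gumbel_softmax(logits, gumbels, tau):
    """Soft sample y = softmax((logits + gumbels) / tau)."""
    scaled = [(l + g) / tau for l, g in zip(logits, gumbels)]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def straight_through(logits, gumbels, tau):
    """Return (forward value y_hard, Jacobian of the soft sample y w.r.t. the logits)."""
    y = gumbel_softmax(logits, gumbels, tau)
    n = len(y)
    k = max(range(n), key=lambda i: y[i])
    y_hard = [1.0 if i == k else 0.0 for i in range(n)]
    # Backward pass pretends the output was y:
    # d y_i / d l_j = y_i * (delta_ij - y_j) / tau
    jac = [[y[i] * ((1.0 if i == j else 0.0) - y[j]) / tau for j in range(n)]
           for i in range(n)]
    return y_hard, jac

def backprop(upstream, jac):
    """Chain rule: dL/dl_j = sum_i (dL/dy_i) * (dy_i/dl_j)."""
    n = len(upstream)
    return [sum(upstream[i] * jac[i][j] for i in range(n)) for j in range(n)]

rng = random.Random(0)
logits = [0.5, 2.0, 1.0]
gumbels = [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in logits]
y_st, jac = straight_through(logits, gumbels, tau=1.0)
grad_logits = backprop([1.0, 0.0, 0.0], jac)   # illustrative upstream gradient dL/dy_ST
```

The forward value `y_st` is exactly one-hot, yet `grad_logits` is non-zero because it is computed from the soft sample.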
The important idea is:
forward: use y_hard
backward: use the gradient of y
So the model gets discrete behavior in the forward computation, but still receives useful gradients during training.
Why is this useful for world models?
In a discrete latent world model, we may want the predicted latent state to be a real discrete code.
For example:
\[ \hat{c}_{t+1} \sim p_\phi(\hat{c}_{t+1} \mid \hat{c}_t, a_t). \]
If we train the world model for multiple steps, the prediction at one step becomes the input to the next step:
\[ \hat{c}_{t+1} \rightarrow \hat{c}_{t+2} \rightarrow \hat{c}_{t+3}. \]
Without a differentiable trick, sampling \(\hat{c}_{t+1}\) would block gradients from later losses.
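ST Gumbel-softmax addresses this by using hard discrete samples in the forward rollout while still letting gradients flow backward through the sampled codes. The resulting gradients are biased estimates, because the backward pass differentiates the soft sample rather than the hard one.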
This is especially useful for multi-step latent dynamics training.
Classification loss
Because the next latent state is discrete, the dynamics model can be trained with cross-entropy.
Suppose the target next code is \(c_{t+1}\), obtained by encoding the real next observation \(o_{t+1}\) and quantizing the result with a quantizer \(f\):
\[ c_{t+1} = f(e_\theta(o_{t+1})). \]
The dynamics model predicts:
\[ p_\phi(\hat{c}_{t+1} \mid c_t, a_t). \]
Then the latent consistency loss can be:
\[ \mathcal{L}_{\text{dyn}} = \operatorname{CE} \left( p_\phi(\hat{c}_{t+1} \mid c_t, a_t), c_{t+1} \right). \]
So the model is trained like a classifier over codebook entries.
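For a single transition, the cross-entropy reduces to the negative log-probability of the target code. A minimal sketch, assuming the dynamics model has already produced logits (the values below are invented):

```python
import math

def cross_entropy(logits, target_index):
    """-log softmax(logits)[target_index], computed via a stable log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[target_index]

logits = [0.5, 2.0, 1.0]                       # illustrative dynamics-model output
loss_likely = cross_entropy(logits, 1)         # target is the most probable code
loss_unlikely = cross_entropy(logits, 0)       # target is the least probable code
loss_uniform = cross_entropy([0.0, 0.0, 0.0], 2)
```

The loss is small when the model assigns high probability to the target code, and equals \(\log |C|\) when the prediction is uniform.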
Training vs planning
During training, sampling can be useful because the model should learn to handle its own sampled predictions during multi-step rollouts.
So training may use:
\[ \hat{c}_{h+1} \sim p_\phi(\hat{c}_{h+1} \mid \hat{c}_h, a_h), \]
implemented with ST Gumbel-softmax.
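The forward pass of such a sampled rollout can be sketched with a toy tabular dynamics model. Everything here is hypothetical: `transition_logits` stands in for \(p_\phi\), there are 3 codes and 2 actions, and only the hard forward samples are shown (the soft backward pass is omitted).

```python
import math
import random

# Hypothetical dynamics: transition_logits[code][action] -> logits over next codes.
transition_logits = [
    [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
    [[0.0, 2.0, 0.0], [0.0, 0.0, 2.0]],
    [[0.0, 0.0, 2.0], [2.0, 0.0, 0.0]],
]

def sample_next_code(logits, rng):
    """Hard sample via Gumbel-max: argmax_i (l_i + g_i)."""
    perturbed = [l - math.log(-math.log(max(rng.random(), 1e-12))) for l in logits]
    return max(range(len(logits)), key=lambda i: perturbed[i])

def rollout(start_code, actions, rng):
    """Feed each sampled code back in as the input to the next step."""
    codes = [start_code]
    for a in actions:
        logits = transition_logits[codes[-1]][a]
        codes.append(sample_next_code(logits, rng))
    return codes

rng = random.Random(0)
trajectory = rollout(start_code=0, actions=[1, 0, 1], rng=rng)
```

Each step consumes the model's own sample from the previous step, so the model is exposed to its own prediction errors during training.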
During planning, however, sampling may be undesirable.
If a planner evaluates the same action sequence multiple times, stochastic sampling can give different predicted futures and different scores. This adds noise to planning.
Therefore, some methods use the expected code during planning:
\[ \hat{c}_{h+1} = \sum_i p_\phi(\hat{c}_{h+1}=c^{(i)} \mid \hat{c}_h, a_h) c^{(i)}. \]
This is deterministic and makes trajectory evaluation more stable.
The tradeoff is that the expected code may not be a valid discrete codebook entry. But if the codebook has a meaningful geometry, the expected code can act like an interpolation between codes.
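The contrast between the two modes can be sketched as follows, again with an invented codebook and logits: the expected code is identical on every evaluation, while sampled codes vary from call to call.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def expected_code(logits, codebook):
    """Probability-weighted average of codebook vectors (deterministic)."""
    probs = softmax(logits)
    dim = len(codebook[0])
    return [sum(p * code[d] for p, code in zip(probs, codebook)) for d in range(dim)]

codebook = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # hypothetical 2-D codes
logits = [0.5, 2.0, 1.0]                          # illustrative model output

first = expected_code(logits, codebook)
second = expected_code(logits, codebook)          # same inputs, same result

rng = random.Random(0)
probs = softmax(logits)
sampled = [rng.choices(range(3), weights=probs, k=1)[0] for _ in range(50)]
```

Here `first` lies between the codebook vectors rather than on one of them, which illustrates the tradeoff described above.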
Summary
ST Gumbel-softmax is useful when a model contains discrete latent variables but still needs gradient-based training.
The key idea is:
forward pass: discrete one-hot sample
backward pass: gradient of the soft Gumbel-softmax sample
In discrete world models, this allows multi-step latent rollouts such as:
current code -> sampled next code -> sampled code after that -> ...
while still training the encoder and dynamics model end-to-end with backpropagation.
In short:
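Gumbel-softmax: differentiable soft relaxation of categorical sampling
ST Gumbel-softmax: hard sample in the forward pass, soft gradient in the backward pass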