Understanding EDM for World Models

Diffusion models are widely used for visual generation and have recently become important in pixel-space world models. This note introduces EDM (Elucidating the Design Space of Diffusion-Based Generative Models), a continuous-noise diffusion formulation that differs from classic DDPM. I study EDM because it is used by DIAMOND, one of the world models in my research.

Sources:

  1. Karras et al., Elucidating the Design Space of Diffusion-Based Generative Models (the EDM paper)
  2. Alonso et al., Diffusion for World Modeling: Visual Details Matter in Atari (the DIAMOND paper)

Note: This is a technical post. I won't discuss the science here (i.e., the motivation, theoretical landscape, and empirical findings); I just list the equations.


1. Why EDM

EDM is a diffusion formulation with:

  1. continuous noise level \(\sigma\);
  2. preconditioned denoiser \(D_\theta\);
  3. EDM-specific loss weighting and sampling;
  4. clear separation between raw network output \(F_\theta\) and denoised estimate \(D_\theta\).

The most important point is:

EDM is not just DDPM with a different schedule. It uses a different noise parameterization, a different network target, and a different sampling view.


2. World Model Setting

A visual world model wants to predict the next observation from past observations and actions:

\[ p(o_{t+1}\mid o_{\le t}, a_{\le t}) \tag{1} \label{eq:wm-objective} \]

In DIAMOND-style diffusion world models, the target clean sample is the next frame:

\[ x_0 = o_{t+1} \tag{2} \label{eq:diamond-clean-target} \]

The condition is the previous trajectory context:

\[ \mathrm{cond}_t = (o_{\le t}, a_{\le t}) \tag{3} \label{eq:diamond-condition} \]

So the conditional diffusion model learns:

\[ p_\theta(x_0 \mid \mathrm{cond}_t) \]

or concretely:

\[ p_\theta(o_{t+1}\mid o_{\le t}, a_{\le t}) \]

This is exactly the world-model prediction problem in Equation \(\ref{eq:wm-objective}\), but implemented through conditional denoising.

EDM provides the denoising machinery used to model this conditional distribution.


3. Symbols

| Symbol | Meaning |
|---|---|
| \(x_0\) | Clean data sample; in DIAMOND, usually the target next frame |
| \(\epsilon\) | Gaussian noise, \(\epsilon\sim\mathcal{N}(0,I)\) |
| \(\sigma\) | Continuous noise level |
| \(x_\sigma\) | Noisy sample |
| \(\sigma_{\mathrm{data}}\) | Data scale hyperparameter |
| \(F_\theta\) | Raw neural network output |
| \(D_\theta\) | EDM denoiser output, interpreted as clean-sample estimate |
| \(c_{\mathrm{in}}\) | Input scaling coefficient |
| \(c_{\mathrm{out}}\) | Output scaling coefficient |
| \(c_{\mathrm{skip}}\) | Skip connection coefficient |
| \(c_{\mathrm{noise}}\) | Noise-level embedding |
| \(\mathrm{cond}\) | Conditioning information, e.g. previous frames and actions |
| \(\lambda(\sigma)\) | EDM denoiser-space loss weight |

4. EDM Forward Noising

EDM uses a simple continuous-noise formulation:

\[ x_\sigma = x_0 + \sigma \epsilon, \quad \epsilon\sim\mathcal{N}(0,I) \tag{4} \label{eq:edm-noising} \]

Here \(\sigma\) directly controls the noise strength.

When \(\sigma\) is small, \(x_\sigma\) is close to \(x_0\).
When \(\sigma\) is large, \(x_\sigma\) is mostly noise.
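
The noising rule in Equation (4) is short enough to sketch directly. The block below is a minimal toy, assuming a frame is represented as a flat list of floats; `add_noise` and the example values are hypothetical names chosen for illustration, and real code would operate on image tensors instead.

```python
import random

def add_noise(x0, sigma, rng):
    """Return (x_sigma, eps) with x_sigma = x0 + sigma * eps, element-wise."""
    eps = [rng.gauss(0.0, 1.0) for _ in x0]           # eps ~ N(0, I)
    x_sigma = [x + sigma * e for x, e in zip(x0, eps)]
    return x_sigma, eps

rng = random.Random(0)
x0 = [0.5, -0.25, 1.0]  # stand-in for a flattened clean frame (assumption)
for sigma in (0.01, 1.0, 80.0):
    x_sigma, _ = add_noise(x0, sigma, rng)
    # Small sigma stays near x0; large sigma is dominated by the noise term.
    print(sigma, [round(v, 2) for v in x_sigma])
```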

This differs from DDPM, which usually writes the noised sample as:

\[ x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon \tag{5} \label{eq:ddpm-noising} \]

So the first major difference is: EDM uses the direct continuous-noise form in Equation \(\ref{eq:edm-noising}\), while DDPM usually uses the discrete-time form in Equation \(\ref{eq:ddpm-noising}\).

| DDPM | EDM |
|---|---|
| discrete timestep \(t\) | continuous noise level \(\sigma\) |
| noise controlled by \(\bar{\alpha}_t\) | noise controlled directly by \(\sigma\) |
| \(x_t=\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon\) | \(x_\sigma=x_0+\sigma\epsilon\) |
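
The two noising forms can be related by a rescaling, which may help when reading DDPM and EDM code side by side. Dividing the DDPM sample by \(\sqrt{\bar\alpha_t}\) gives \(x_0+\sigma\epsilon\) with \(\sigma=\sqrt{(1-\bar\alpha_t)/\bar\alpha_t}\). This identity is standard algebra rather than something stated in this note, and the function names below are hypothetical; the sketch only checks it numerically on scalars.

```python
def ddpm_to_edm_sigma(alpha_bar):
    """Noise level sigma that matches a DDPM step with cumulative alpha_bar."""
    return ((1.0 - alpha_bar) / alpha_bar) ** 0.5

def ddpm_noise(x0, eps, alpha_bar):
    # DDPM form: sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps
    return alpha_bar ** 0.5 * x0 + (1.0 - alpha_bar) ** 0.5 * eps

def edm_noise(x0, eps, sigma):
    # EDM form: x0 + sigma * eps
    return x0 + sigma * eps

x0, eps, alpha_bar = 0.7, -1.3, 0.2        # arbitrary illustrative values
sigma = ddpm_to_edm_sigma(alpha_bar)
rescaled = ddpm_noise(x0, eps, alpha_bar) / alpha_bar ** 0.5
# The rescaled DDPM sample coincides with the EDM sample at the matched sigma.
assert abs(rescaled - edm_noise(x0, eps, sigma)) < 1e-12
```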

5. EDM Denoiser

The ideal denoiser should recover the clean sample:

\[ D_\theta(x_\sigma,\sigma,\mathrm{cond}) \approx x_0 \]

For DIAMOND, this means:

\[ D_\theta( \text{noisy next frame}, \sigma, \text{past frames/actions} ) \approx \text{clean next frame} \]

But EDM does not let the neural network directly output \(D_\theta\). Instead, it defines a preconditioned denoiser:

\[ D_\theta(x,\sigma,\mathrm{cond}) = c_{\mathrm{skip}}(\sigma)x + c_{\mathrm{out}}(\sigma) F_\theta( c_{\mathrm{in}}(\sigma)x, c_{\mathrm{noise}}(\sigma), \mathrm{cond} ) \tag{6} \label{eq:edm-denoiser} \]

Equation \(\ref{eq:edm-denoiser}\) is the main reason EDM code has two different objects: the raw network output \(F_\theta\), and the final denoised estimate \(D_\theta\).

Key distinction:

| Object | Meaning |
|---|---|
| \(F_\theta\) | raw network output |
| \(D_\theta\) | final denoised estimate |
| \(c_{\mathrm{skip}}x+c_{\mathrm{out}}F_\theta\) | clean-sample estimate |

In code terms:

model_output = F_theta(c_in * x_sigma, c_noise, cond)
denoised = c_skip * x_sigma + c_out * model_output

So model_output is not the final clean image. The clean estimate is denoised.


6. EDM Preconditioning Coefficients

A common EDM parameterization is:

\[ c_{\mathrm{skip}}(\sigma) = \frac{\sigma_{\mathrm{data}}^2} {\sigma^2+\sigma_{\mathrm{data}}^2} \tag{7} \label{eq:c-skip} \]

\[ c_{\mathrm{out}}(\sigma) = \frac{\sigma\sigma_{\mathrm{data}}} {\sqrt{\sigma^2+\sigma_{\mathrm{data}}^2}} \tag{8} \label{eq:c-out} \]

\[ c_{\mathrm{in}}(\sigma) = \frac{1} {\sqrt{\sigma^2+\sigma_{\mathrm{data}}^2}} \tag{9} \label{eq:c-in} \]

\[ c_{\mathrm{noise}}(\sigma) = \frac{1}{4}\log\sigma \tag{10} \label{eq:c-noise} \]

Intuition:

| Coefficient | Role |
|---|---|
| \(c_{\mathrm{in}}\) | normalizes noisy input |
| \(c_{\mathrm{skip}}\) | passes part of noisy input directly to output |
| \(c_{\mathrm{out}}\) | scales raw network output |
| \(c_{\mathrm{noise}}\) | tells the network the noise level |

Together, Equations \(\ref{eq:c-skip}\)-\(\ref{eq:c-noise}\) define the standard EDM-style preconditioning used in Equation \(\ref{eq:edm-denoiser}\).
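
As a rough sketch of how Equations (6) to (10) fit together, the block below computes the four coefficients for a scalar \(\sigma\) and wraps an arbitrary raw network into a denoiser. The names `edm_coefficients`, `denoise`, and `raw_network` are hypothetical, and `SIGMA_DATA = 0.5` is an assumption for this sketch (it is the value used in the EDM paper); a real implementation would work on tensors.

```python
import math

SIGMA_DATA = 0.5  # assumed data scale for this sketch

def edm_coefficients(sigma, sigma_data=SIGMA_DATA):
    """Return (c_skip, c_out, c_in, c_noise) for a scalar noise level sigma > 0."""
    total = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / total
    c_out = sigma * sigma_data / math.sqrt(total)
    c_in = 1.0 / math.sqrt(total)
    c_noise = 0.25 * math.log(sigma)
    return c_skip, c_out, c_in, c_noise

def denoise(raw_network, x_sigma, sigma, cond):
    """Preconditioned denoiser D = c_skip * x + c_out * F(c_in * x, c_noise, cond)."""
    c_skip, c_out, c_in, c_noise = edm_coefficients(sigma)
    model_output = raw_network(c_in * x_sigma, c_noise, cond)  # raw F_theta
    return c_skip * x_sigma + c_out * model_output             # denoised estimate

for sigma in (0.002, 0.5, 80.0):
    c_skip, c_out, c_in, _ = edm_coefficients(sigma)
    print(f"sigma={sigma:<6} c_skip={c_skip:.4f} c_out={c_out:.4f} c_in={c_in:.4f}")
```

At small \(\sigma\), \(c_{\mathrm{skip}}\) approaches 1 and \(c_{\mathrm{out}}\) approaches 0, so the denoiser mostly passes its input through. At large \(\sigma\), \(c_{\mathrm{skip}}\) approaches 0 and \(c_{\mathrm{out}}\) approaches \(\sigma_{\mathrm{data}}\), so the network output carries almost the whole prediction.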


7. Raw Network Target

EDM wants:

\[ D_\theta(x_\sigma,\sigma,\mathrm{cond}) \approx x_0 \]

Substitute the preconditioned denoiser from Equation \(\ref{eq:edm-denoiser}\):

\[ x_0 \approx c_{\mathrm{skip}}x_\sigma + c_{\mathrm{out}} F_\theta(c_{\mathrm{in}}x_\sigma,c_{\mathrm{noise}},\mathrm{cond}) \]

Solving for the raw network target gives:

\[ F_{\mathrm{target}} = \frac{ x_0-c_{\mathrm{skip}}x_\sigma }{ c_{\mathrm{out}} } \tag{11} \label{eq:edm-raw-target} \]

Therefore EDM trains \(F_\theta\) to predict a preconditioned residual target:

\[ F_\theta(c_{\mathrm{in}}x_\sigma,c_{\mathrm{noise}},\mathrm{cond}) \approx \frac{ x_0-c_{\mathrm{skip}}x_\sigma }{ c_{\mathrm{out}} } \]
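
To get a feel for what this target looks like, the sketch below evaluates Equation (11) on scalars at a very small and a very large noise level. Substituting \(x_\sigma=x_0+\sigma\epsilon\) into the target gives \((\sigma x_0/\sigma_{\mathrm{data}}-\sigma_{\mathrm{data}}\epsilon)/\sqrt{\sigma^2+\sigma_{\mathrm{data}}^2}\), so it should tend to \(-\epsilon\) as \(\sigma\to 0\) and to \(x_0/\sigma_{\mathrm{data}}\) as \(\sigma\to\infty\). The function name and the value \(\sigma_{\mathrm{data}}=0.5\) are assumptions made for illustration.

```python
def raw_target(x0, eps, sigma, sigma_data=0.5):
    """Raw network target (x0 - c_skip * x_sigma) / c_out for scalar inputs."""
    total = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / total
    c_out = sigma * sigma_data / total ** 0.5
    x_sigma = x0 + sigma * eps          # forward noising
    return (x0 - c_skip * x_sigma) / c_out

x0, eps = 0.3, -1.1                      # arbitrary illustrative values
low = raw_target(x0, eps, sigma=1e-3)    # approaches -eps
high = raw_target(x0, eps, sigma=1e3)    # approaches x0 / sigma_data
print(low, -eps)
print(high, x0 / 0.5)
```

So the raw target behaves like a (negated) noise target at low noise and like a rescaled clean-sample target at high noise.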

This is different from DDPM epsilon prediction:

\[ \epsilon_\theta(x_t,t)\approx \epsilon \tag{12} \label{eq:ddpm-eps-target} \]

So the second major difference is: DDPM commonly trains the network against the noise target in Equation \(\ref{eq:ddpm-eps-target}\), while EDM trains the raw network output against the preconditioned target in Equation \(\ref{eq:edm-raw-target}\).

| DDPM epsilon prediction | EDM |
|---|---|
| network predicts noise \(\epsilon\) | network predicts preconditioned residual |
| clean estimate is derived from predicted noise | clean estimate is \(c_{\mathrm{skip}}x+c_{\mathrm{out}}F_\theta\) |
| target is \(\epsilon\) | target is \((x_0-c_{\mathrm{skip}}x_\sigma)/c_{\mathrm{out}}\) |

8. EDM Loss

The standard EDM loss is often written on the denoised estimate:

\[ \mathcal{L}_{\mathrm{EDM}} = \mathbb{E}_{x_0,\sigma,\epsilon} \left[ \lambda(\sigma) \left\| D_\theta(x_\sigma,\sigma,\mathrm{cond}) - x_0 \right\|^2 \right] \tag{13} \label{eq:edm-denoiser-loss} \]

where \(x_\sigma\) is defined by Equation \(\ref{eq:edm-noising}\).

A common EDM weight is:

\[ \lambda(\sigma) = \frac{ \sigma^2+\sigma_{\mathrm{data}}^2 }{ (\sigma\sigma_{\mathrm{data}})^2 } \tag{14} \label{eq:edm-lambda} \]

Since \(c_{\mathrm{out}}(\sigma)\) is defined by Equation \(\ref{eq:c-out}\), we have:

\[ \lambda(\sigma) = \frac{1}{c_{\mathrm{out}}(\sigma)^2} \tag{15} \label{eq:lambda-cout} \]

This means the denoiser-space loss in Equation \(\ref{eq:edm-denoiser-loss}\) is equivalent to an unweighted MSE on the raw network output:

\[ \mathcal{L}_{\mathrm{EDM}} = \mathbb{E} \left[ \left\| F_\theta(c_{\mathrm{in}}x_\sigma,c_{\mathrm{noise}},\mathrm{cond}) - \frac{x_0-c_{\mathrm{skip}}x_\sigma}{c_{\mathrm{out}}} \right\|^2 \right] \tag{16} \label{eq:edm-raw-output-loss} \]

So implementation often looks like:

x_sigma = x0 + sigma * noise

model_output = model(c_in * x_sigma, c_noise, cond)

target = (x0 - c_skip * x_sigma) / c_out

loss = mse(model_output, target)

Important implementation point:

If the code already trains \(F_\theta\) against the preconditioned target in Equation \(\ref{eq:edm-raw-target}\), the weight \(\lambda(\sigma)=1/c_{\mathrm{out}}(\sigma)^2\) is already absorbed by the division by \(c_{\mathrm{out}}\). Do not multiply by the same \(\lambda(\sigma)\) again unless the implementation intentionally changes the weighting.
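
The equivalence between the weighted denoiser-space loss and the unweighted raw-output loss can be checked numerically. The sketch below does this for a single scalar sample, where `raw_output` is an arbitrary number standing in for whatever the network returned; the function name and \(\sigma_{\mathrm{data}}=0.5\) are assumptions for illustration only.

```python
def edm_losses(raw_output, x0, eps, sigma, sigma_data=0.5):
    """Return (weighted denoiser-space loss, unweighted raw-output loss)."""
    total = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / total
    c_out = sigma * sigma_data / total ** 0.5
    lam = total / (sigma * sigma_data) ** 2         # lambda(sigma) = 1 / c_out^2

    x_sigma = x0 + sigma * eps
    denoised = c_skip * x_sigma + c_out * raw_output
    target = (x0 - c_skip * x_sigma) / c_out

    denoiser_loss = lam * (denoised - x0) ** 2      # weighted, on D_theta
    raw_loss = (raw_output - target) ** 2           # unweighted, on F_theta
    return denoiser_loss, raw_loss

for sigma in (0.05, 0.5, 5.0):
    d_loss, r_loss = edm_losses(raw_output=0.2, x0=0.3, eps=-1.1, sigma=sigma)
    # The two values agree up to floating-point error at every noise level.
    print(sigma, d_loss, r_loss)
```

Multiplying `raw_loss` by `lam` a second time would therefore apply the factor \(1/c_{\mathrm{out}}^2\) twice.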


9. Conditional EDM for DIAMOND

For a DIAMOND-style world model, the clean target and condition are defined by Equations \(\ref{eq:diamond-clean-target}\) and \(\ref{eq:diamond-condition}\):

\[ x_0 = o_{t+1} \]

\[ \mathrm{cond}_t=(o_{\le t},a_{\le t}) \]

The noised target is:

\[ x_\sigma=o_{t+1}+\sigma\epsilon \tag{17} \label{eq:diamond-noised-target} \]

The denoiser is:

\[ D_\theta(x_\sigma,\sigma,\mathrm{cond}_t) = c_{\mathrm{skip}}x_\sigma + c_{\mathrm{out}} F_\theta(c_{\mathrm{in}}x_\sigma,c_{\mathrm{noise}},\mathrm{cond}_t) \tag{18} \label{eq:diamond-denoiser} \]

The raw target is:

\[ F_{\mathrm{target}} = \frac{ o_{t+1}-c_{\mathrm{skip}}x_\sigma }{ c_{\mathrm{out}} } \tag{19} \label{eq:diamond-raw-target} \]

So DIAMOND uses EDM as a conditional future-frame generator:

\[ p_\theta(o_{t+1}\mid o_{\le t},a_{\le t}) \]

In words:

given past frames and actions, denoise a noisy version of the next frame.

Equations \(\ref{eq:diamond-noised-target}\)-\(\ref{eq:diamond-raw-target}\) are just the conditional world-model version of the generic EDM equations.
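
A minimal sketch of how one training example could be assembled from a trajectory under these definitions is shown below. It assumes each frame is a flat list of floats and that the full history up to step \(t\) is used as the condition; the function name is hypothetical, and details such as truncating the context to a fixed window or how \(\sigma\) is sampled are not covered here.

```python
import random

def make_training_example(observations, actions, t, sigma, rng, sigma_data=0.5):
    """Build (scaled network input, cond_t, raw target) for predicting o_{t+1}."""
    total = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / total
    c_out = sigma * sigma_data / total ** 0.5
    c_in = 1.0 / total ** 0.5

    x0 = observations[t + 1]                            # clean target o_{t+1}
    cond = (observations[: t + 1], actions[: t + 1])    # (o_{<=t}, a_{<=t})
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_sigma = [x + sigma * e for x, e in zip(x0, eps)]  # noised target
    target = [(x - c_skip * xs) / c_out for x, xs in zip(x0, x_sigma)]
    scaled_input = [c_in * xs for xs in x_sigma]        # what F_theta receives
    return scaled_input, cond, target

# Toy trajectory: three 2-pixel frames and three discrete actions (assumption).
observations = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4]]
actions = [0, 1, 0]
scaled_input, cond, target = make_training_example(
    observations, actions, t=1, sigma=0.7, rng=random.Random(0)
)
print(cond)
print(target)
```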


10. EDM Sampling

At sampling time, \(x_0\) is unknown.

The sampler starts from high noise:

\[ x_{\sigma_{\max}} \sim \mathcal{N}(0,\sigma_{\max}^2I) \tag{20} \label{eq:edm-sampling-init} \]

Then it follows a decreasing noise schedule:

\[ \sigma_{\max}>\cdots>\sigma_i>\sigma_{i+1}>\cdots>\sigma_{\min} \]

At each step, compute:

\[ D_\theta(x_i,\sigma_i,\mathrm{cond}) \]

A common EDM ODE direction is:

\[ d(x_i,\sigma_i) = \frac{ x_i-D_\theta(x_i,\sigma_i,\mathrm{cond}) }{ \sigma_i } \tag{21} \label{eq:edm-ode-direction} \]

A simple Euler step is:

\[ x_{i+1} = x_i + (\sigma_{i+1}-\sigma_i) d(x_i,\sigma_i) \tag{22} \label{eq:edm-euler-step} \]
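
The Euler update can be sketched in a few lines for a scalar sample. The block below is a toy under stated assumptions: `point_mass_denoiser` is a hypothetical stand-in for a trained \(D_\theta\) (it is the exact denoiser when all data mass sits at a single value, here passed in as `cond`), and the schedule ends at \(\sigma=0\), following the EDM paper's convention of appending a final zero rather than stopping at \(\sigma_{\min}\).

```python
import random

def euler_sample(denoiser, sigmas, cond, rng):
    """Integrate from sigmas[0] down to sigmas[-1] with Euler steps (scalar toy)."""
    x = rng.gauss(0.0, sigmas[0])                  # x ~ N(0, sigma_max^2)
    for sigma_cur, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoiser(x, sigma_cur, cond)) / sigma_cur   # ODE direction
        x = x + (sigma_next - sigma_cur) * d                  # Euler step
    return x

def point_mass_denoiser(x, sigma, cond):
    # Hypothetical ideal denoiser when the data distribution is a point at `cond`.
    return cond

sigmas = [80.0, 20.0, 5.0, 1.0, 0.2, 0.0]          # decreasing schedule (assumption)
sample = euler_sample(point_mass_denoiser, sigmas, cond=0.25, rng=random.Random(0))
print(sample)  # lands on the data point, up to floating-point error
```

Only `sigma_cur` appears in a denominator, so ending the schedule at zero does not cause a division by zero.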

For a world model rollout:

given past frames and actions
-> sample next frame with conditional EDM
-> append generated frame to context
-> repeat

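So a DIAMOND rollout is expensive: each predicted frame requires a full diffusion sampling process, not a single forward pass. Each Euler step in Equation \(\ref{eq:edm-euler-step}\) only moves from \(\sigma_i\) to \(\sigma_{i+1}\) and needs one evaluation of \(D_\theta\) for the direction in Equation \(\ref{eq:edm-ode-direction}\), so sampling one frame takes as many network evaluations as there are steps in the schedule.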
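
The rollout loop and its cost can be sketched as follows. Everything here is a hypothetical toy: frames are scalars, `make_counting_sampler` replaces a real conditional EDM sampler with trivial dynamics, and `policy` is a placeholder for whatever chooses the next action. The point is only the bookkeeping, namely that the number of denoiser calls is the horizon times the number of sampling steps.

```python
def rollout(sample_next_frame, frames, actions, horizon, policy):
    """Autoregressive rollout: sample a frame, append it, pick an action, repeat."""
    frames, actions = list(frames), list(actions)   # copy so inputs are not mutated
    for _ in range(horizon):
        next_frame = sample_next_frame(frames, actions)  # one full sampling run
        frames.append(next_frame)
        actions.append(policy(frames))                   # action for the new frame
    return frames, actions

def make_counting_sampler(num_steps, counter):
    """Toy stand-in for a diffusion sampler that counts denoiser evaluations."""
    def sample_next_frame(frames, actions):
        for _ in range(num_steps):
            counter["denoiser_calls"] += 1   # one network forward pass per step
        return frames[-1] + actions[-1]      # toy dynamics, not a real model
    return sample_next_frame

counter = {"denoiser_calls": 0}
sampler = make_counting_sampler(num_steps=3, counter=counter)
frames, actions = rollout(sampler, frames=[0.0], actions=[1], horizon=5,
                          policy=lambda frames: 1)
print(len(frames) - 1, "generated frames,", counter["denoiser_calls"], "denoiser calls")
```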


11. EDM vs DDPM Summary

| Aspect | DDPM | EDM |
|---|---|---|
| Noise variable | discrete timestep \(t\) | continuous \(\sigma\) |
| Forward noising | \(\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon\) | \(x_0+\sigma\epsilon\) |
| Raw network target | often \(\epsilon\) | preconditioned residual |
| Denoiser form | depends on parameterization | \(c_{\mathrm{skip}}x+c_{\mathrm{out}}F_\theta\) |
| Loss | often epsilon MSE | denoiser MSE or equivalent raw-output MSE |
| Sampling | reverse diffusion chain | numerical integration over \(\sigma\) |
| Key practical object | \(\epsilon_\theta\) | \(D_\theta\) and \(F_\theta\) must be distinguished |

The key differences are:

  1. EDM uses continuous noise level \(\sigma\), not discrete timestep \(t\).
  2. EDM uses the preconditioned denoiser in Equation \(\ref{eq:edm-denoiser}\), not a plain network output.
  3. EDM’s raw network target is not \(\epsilon\), but the scaled residual in Equation \(\ref{eq:edm-raw-target}\).
  4. EDM sampling is naturally described as numerical integration over decreasing \(\sigma\), as in Equations \(\ref{eq:edm-ode-direction}\) and \(\ref{eq:edm-euler-step}\).
  5. In EDM code, model_output and denoised are different objects.

12. Summary

EDM noising:

\[ x_\sigma=x_0+\sigma\epsilon \]

EDM denoiser:

\[ D_\theta(x,\sigma,\mathrm{cond}) = c_{\mathrm{skip}}x + c_{\mathrm{out}} F_\theta(c_{\mathrm{in}}x,c_{\mathrm{noise}},\mathrm{cond}) \]

Raw network target:

\[ F_{\mathrm{target}} = \frac{ x_0-c_{\mathrm{skip}}x_\sigma }{ c_{\mathrm{out}} } \]

Denoised estimate:

\[ \hat x_0 = D_\theta(x_\sigma,\sigma,\mathrm{cond}) = c_{\mathrm{skip}}x_\sigma + c_{\mathrm{out}}F_\theta \]

For DIAMOND:

\[ x_0=o_{t+1}, \quad \mathrm{cond}_t=(o_{\le t},a_{\le t}) \]

In one sentence:

EDM trains a preconditioned conditional denoising model over continuous noise levels, and DIAMOND uses this machinery to generate future visual observations as a diffusion-based world model.