Training Diffusion Models: The SDE Perspective¶
This document explains how diffusion models are trained from the SDE perspective, clarifying a common source of confusion: SDE solvers are NOT used during training.
Training uses closed-form solutions of the forward SDE, making it computationally simple and efficient.
Overview¶
The Key Insight¶
Common misconception: Since diffusion models are based on SDEs, training must involve solving SDEs.
Reality: Training uses closed-form marginals from the forward SDE solution. No numerical SDE solver is needed.
Training vs Sampling¶
| Phase | SDE Solver? | What's Used |
|---|---|---|
| Training | ❌ No | Closed-form \(q(x_t \mid x_0)\) |
| Sampling | ✅ Yes | Numerical discretization of reverse SDE/ODE |
This document focuses on training — see 03_sde_sampling.md for sampling.
The Forward SDE Solution¶
VP-SDE (Variance-Preserving)¶
The forward SDE is:

$$dx = -\tfrac{1}{2}\beta(t)\,x\,dt + \sqrt{\beta(t)}\,dw$$

Closed-form solution (see 03_solving_vpsde.md for derivation):

$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$$
where:
- \(\bar{\alpha}_t = \exp\left(-\int_0^t \beta(s)\,ds\right)\)
- \(\epsilon \sim \mathcal{N}(0, I)\)
Key property: We can sample \(x_t\) directly from \(x_0\) without simulating the SDE!
Marginal Distribution¶
The marginal distribution at time \(t\) is:

$$q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\; \sqrt{\bar{\alpha}_t}\,x_0,\; (1-\bar{\alpha}_t)\,I\right)$$
This is all we need for training — no SDE solver required.
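To make this concrete, here is a minimal sketch of sampling \(x_t\) from the marginal in a single step, assuming the linear schedule \(\beta(t) = \beta_{\min} + t(\beta_{\max} - \beta_{\min})\) used later in this document; `marginal_sample` is a hypothetical helper, not part of any library:

```python
import torch

def marginal_sample(x_0, t, beta_min=0.1, beta_max=20.0):
    """Draw x_t ~ q(x_t | x_0) in one step via the closed-form marginal.

    Assumes a linear schedule, so
    alpha_bar(t) = exp(-(beta_min*t + 0.5*(beta_max - beta_min)*t**2)).
    """
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    alpha_bar = torch.exp(torch.tensor(-integral))
    epsilon = torch.randn_like(x_0)
    return torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * epsilon

# With x_0 = 0 the marginal is N(0, (1 - alpha_bar(t)) I), so the empirical
# variance of many one-step samples should match 1 - alpha_bar(t).
torch.manual_seed(0)
x_t = marginal_sample(torch.zeros(100_000), t=0.5)
```

Note that no loop over intermediate times appears anywhere: one multiplication and one addition produce a valid sample at any \(t\).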
Training Objective¶
Score Matching¶
The goal is to learn the score function \(\nabla_x \log p_t(x)\).
Theoretical objective (intractable, since \(\nabla_{x_t} \log p_t(x_t)\) is unknown):

$$\mathcal{L} = \mathbb{E}_{t, x_t}\left[\lambda(t)\,\big\|s_\theta(x_t, t) - \nabla_{x_t} \log p_t(x_t)\big\|^2\right]$$

Practical objective (tractable via denoising score matching):

$$\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon}\left[\lambda(t)\,\big\|s_\theta(x_t, t) - \nabla_{x_t} \log q(x_t \mid x_0)\big\|^2\right]$$
where \(x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon\).
Connection to Noise Prediction¶
The conditional score has a simple form:

$$\nabla_{x_t} \log q(x_t \mid x_0) = -\frac{x_t - \sqrt{\bar{\alpha}_t}\,x_0}{1 - \bar{\alpha}_t} = -\frac{\epsilon}{\sqrt{1 - \bar{\alpha}_t}}$$

In practice, we predict the noise \(\epsilon_\theta(x_t, t)\) instead of the score.

Relationship:

$$s_\theta(x_t, t) = -\frac{\epsilon_\theta(x_t, t)}{\sqrt{1 - \bar{\alpha}_t}}$$
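As a quick numerical check of this identity, the sketch below (hypothetical helper name, not from any library) converts an oracle noise value into a score and compares it against the exact conditional score of \(q(x_t \mid x_0)\):

```python
import torch

def score_from_noise(epsilon_pred, alpha_bar_t):
    """s_theta(x_t, t) = -epsilon_theta(x_t, t) / sqrt(1 - alpha_bar_t)."""
    return -epsilon_pred / torch.sqrt(1 - alpha_bar_t)

torch.manual_seed(0)
alpha_bar_t = torch.tensor(0.3)
x_0 = torch.randn(4)
epsilon = torch.randn(4)
x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

# Exact conditional score: -(x_t - sqrt(ab)*x_0) / (1 - ab)
exact = -(x_t - torch.sqrt(alpha_bar_t) * x_0) / (1 - alpha_bar_t)
# Feeding the true noise as an "oracle" prediction recovers the same values
approx = score_from_noise(epsilon, alpha_bar_t)
```

Since \(x_t - \sqrt{\bar{\alpha}_t}\,x_0 = \sqrt{1-\bar{\alpha}_t}\,\epsilon\), the two expressions agree exactly.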
DDPM Training Loss¶
The standard DDPM loss is:

$$\mathcal{L}_{\text{DDPM}} = \mathbb{E}_{t, x_0, \epsilon}\left[\left\|\epsilon_\theta(x_t, t) - \epsilon\right\|^2\right]$$

Why this works:

1. Equivalent to denoising score matching under a particular choice of weighting \(\lambda(t)\)
2. Simpler to implement (predict noise, not score)
3. Better empirical performance
Training Algorithm¶
Pseudocode¶
```python
# Training loop for VP-SDE diffusion model
for epoch in range(num_epochs):
    for batch in dataloader:
        x_0 = batch  # Clean data samples

        # 1. Sample random timesteps, t ∈ [0, T]
        t = torch.rand(batch_size) * T

        # 2. Sample noise
        epsilon = torch.randn_like(x_0)

        # 3. Compute α̅_t = exp(-∫₀ᵗ β(s)ds)
        alpha_bar_t = compute_alpha_bar(t)

        # 4. Create noisy samples (closed-form!)
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

        # 5. Predict noise
        epsilon_pred = model(x_t, t)

        # 6. Compute loss
        loss = F.mse_loss(epsilon_pred, epsilon)

        # 7. Update model
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```
Key observation: Step 4 uses the closed-form solution — no SDE solver!
PyTorch Implementation¶
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiffusionTrainer:
    def __init__(self, model, beta_schedule, T=1.0):
        self.model = model
        self.T = T
        # Continuous-time schedule: α̅(t) = exp(-∫₀ᵗ β(s)ds)
        # For a linear schedule: β(t) = β_min + t(β_max - β_min)
        self.beta_min = beta_schedule['beta_min']
        self.beta_max = beta_schedule['beta_max']

    def compute_alpha_bar(self, t):
        """Compute α̅_t = exp(-∫₀ᵗ β(s)ds) for the linear schedule."""
        # For β(s) = β_min + s(β_max - β_min):
        #   ∫₀ᵗ β(s)ds = β_min*t + (β_max - β_min)*t²/2
        integral = self.beta_min * t + 0.5 * (self.beta_max - self.beta_min) * t**2
        return torch.exp(-integral)

    def training_step(self, x_0):
        """Single training step."""
        batch_size = x_0.shape[0]

        # Sample timesteps uniformly
        t = torch.rand(batch_size, device=x_0.device) * self.T

        # Sample noise
        epsilon = torch.randn_like(x_0)

        # Compute α̅_t and reshape for broadcasting over image dims
        alpha_bar_t = self.compute_alpha_bar(t).view(-1, 1, 1, 1)

        # Create noisy samples (closed-form marginal)
        sqrt_alpha_bar_t = torch.sqrt(alpha_bar_t)
        sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - alpha_bar_t)
        x_t = sqrt_alpha_bar_t * x_0 + sqrt_one_minus_alpha_bar_t * epsilon

        # Predict noise and compute loss
        epsilon_pred = self.model(x_t, t)
        loss = F.mse_loss(epsilon_pred, epsilon)
        return loss

# Usage
model = UNet(...)  # Your noise prediction network
trainer = DiffusionTrainer(
    model=model,
    beta_schedule={'beta_min': 0.1, 'beta_max': 20.0},
    T=1.0,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    for batch in dataloader:
        loss = trainer.training_step(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```
Why No SDE Solver in Training?¶
Mathematical Reason¶
The forward SDE has a closed-form solution:

$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$
This means:
- We can sample \(x_t\) directly from \(x_0\)
- No need to simulate the SDE step-by-step
- Training is fast and exact
Computational Advantage¶
With an SDE solver (hypothetical, not used):

- Start with \(x_0\)
- Simulate many small steps: \(x_0 \to x_{\Delta t} \to x_{2\Delta t} \to \cdots \to x_t\)
- Slow, and introduces discretization error

With the closed-form solution (actual approach):

- Directly compute \(x_t\) from \(x_0\) in one step
- Fast and exact
- This is what makes training practical!
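To illustrate the difference, the sketch below (purely illustrative; the linear schedule and step counts are our assumptions) simulates the hypothetical Euler-Maruyama route and compares its marginal statistics with the one-step closed form:

```python
import torch

torch.manual_seed(0)
beta_min, beta_max, t_end, n_steps = 0.1, 20.0, 0.5, 1000
x_0 = torch.ones(50_000)

# Hypothetical route: Euler-Maruyama simulation of
#   dx = -0.5*beta(t)*x dt + sqrt(beta(t)) dw, taking n_steps small steps
dt = t_end / n_steps
x = x_0.clone()
for i in range(n_steps):
    beta = beta_min + (i * dt) * (beta_max - beta_min)
    x = x - 0.5 * beta * x * dt + (beta * dt) ** 0.5 * torch.randn_like(x)

# Actual route: one-step closed form x_t = sqrt(ab)*x_0 + sqrt(1-ab)*eps
integral = beta_min * t_end + 0.5 * (beta_max - beta_min) * t_end**2
alpha_bar = torch.exp(torch.tensor(-integral))
x_closed = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * torch.randn_like(x_0)

# Both produce approximately the same marginal mean and variance,
# but the simulation needed 1000 sequential steps per sample.
```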
Connection to DDPM¶
Discrete DDPM Training¶
DDPM uses discrete timesteps \(k = 1, 2, \ldots, N\):
```python
# DDPM training
t = torch.randint(0, N, (batch_size,))
alpha_bar_t = alpha_bar[t]  # Precomputed cumulative product
x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon
```
Continuous SDE Training¶
SDE uses continuous time \(t \in [0, T]\):
```python
# SDE training
t = torch.rand(batch_size) * T
alpha_bar_t = compute_alpha_bar(t)  # exp(-∫₀ᵗ β(s)ds)
x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon
```
Key difference: how \(\bar{\alpha}_t\) is computed:

- DDPM: discrete product \(\prod_{i=1}^{t} \alpha_i\)
- SDE: continuous integral \(\exp\left(-\int_0^t \beta(s)\,ds\right)\)
Training loop: Identical structure!
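The agreement between the discrete product and the continuous integral can be checked numerically; this sketch (pure Python, linear schedule assumed) compares the two definitions of \(\bar{\alpha}_t\) at \(t = 0.5\):

```python
import math

beta_min, beta_max, t_end, N = 0.1, 20.0, 0.5, 1000
dt = t_end / N

# DDPM-style: discrete cumulative product of alpha_i = 1 - beta(t_i) * dt
alpha_bar_discrete = 1.0
for i in range(N):
    beta_i = beta_min + (i * dt) * (beta_max - beta_min)
    alpha_bar_discrete *= 1.0 - beta_i * dt

# SDE-style: continuous integral exp(-∫₀ᵗ β(s)ds) for the linear schedule
alpha_bar_continuous = math.exp(
    -(beta_min * t_end + 0.5 * (beta_max - beta_min) * t_end**2)
)

# The two agree to within O(dt) as the discretization gets finer
```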
Noise Schedules¶
Linear Schedule (Most Common)¶
$$\beta(t) = \beta_{\min} + t\,(\beta_{\max} - \beta_{\min})$$

Cumulative:

$$\bar{\alpha}_t = \exp\left(-\beta_{\min}\,t - \tfrac{1}{2}(\beta_{\max} - \beta_{\min})\,t^2\right)$$

Typical values: \(\beta_{\min} = 0.1\), \(\beta_{\max} = 20\)
Cosine Schedule¶
$$\bar{\alpha}_t = \frac{f(t)}{f(0)}, \qquad f(t) = \cos^2\!\left(\frac{t + s}{1 + s} \cdot \frac{\pi}{2}\right)$$

where \(s = 0.008\) is a small offset and \(t \in [0, 1]\).
Advantage: More uniform SNR across timesteps.
Implementation¶
```python
def linear_beta_schedule(t, beta_min=0.1, beta_max=20.0):
    """Linear noise schedule."""
    return beta_min + t * (beta_max - beta_min)

def compute_alpha_bar_linear(t, beta_min=0.1, beta_max=20.0):
    """Compute α̅_t for the linear schedule."""
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    return torch.exp(-integral)

def compute_alpha_bar_cosine(t, s=0.008):
    """Compute α̅_t for the cosine schedule."""
    def f(t):
        return torch.cos((t + s) / (1 + s) * torch.pi / 2) ** 2
    return f(t) / f(torch.zeros_like(t))
```
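A quick sanity check of the linear-schedule formula (restated in pure Python so the snippet is self-contained; `alpha_bar_linear` is our scalar version of `compute_alpha_bar_linear`):

```python
import math

def alpha_bar_linear(t, beta_min=0.1, beta_max=20.0):
    """alpha_bar(t) = exp(-(beta_min*t + 0.5*(beta_max - beta_min)*t**2))."""
    return math.exp(-(beta_min * t + 0.5 * (beta_max - beta_min) * t**2))

# Boundary behaviour: nearly clean data at t=0, nearly pure noise at t=1
print(alpha_bar_linear(0.0))  # 1.0 -> x_0 is preserved
print(alpha_bar_linear(1.0))  # ~4.3e-5 -> x_1 is essentially N(0, I)
```

This confirms why \(x_T\) can be initialized from \(\mathcal{N}(0, I)\) at sampling time: almost no signal survives at \(t = 1\).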
Training Strategies¶
Time Sampling¶
Uniform sampling (standard):

$$t \sim \mathcal{U}(0, T)$$

Importance sampling (advanced):

- Sample more frequently at difficult timesteps
- Weight the loss by the inverse sampling probability to keep the estimator unbiased
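One possible instantiation of this idea, sketched below, is entirely illustrative: the warp exponent, the bias toward large \(t\), and the helper name are all our assumptions, not a published recipe. It draws more samples at noisier timesteps and returns the matching \(1/p(t)\) loss weight:

```python
import torch

def sample_t_importance(batch_size, T=1.0, bias=2.0):
    """Draw t with density p(t) = bias * t**(bias-1) / T**bias (more mass at
    large, noisier t), plus the weight 1/p(t) that keeps the loss unbiased."""
    u = torch.rand(batch_size)
    t = T * u ** (1.0 / bias)  # inverse-CDF warp of uniform samples
    p_t = bias * (t / T) ** (bias - 1.0) / T
    return t, 1.0 / p_t

torch.manual_seed(0)
t, w = sample_t_importance(10_000)
# E[t] = 2/3 for bias=2, versus 1/2 under uniform sampling
```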
Weighting¶
Simple weighting (DDPM):

$$\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon}\left[\left\|\epsilon_\theta(x_t, t) - \epsilon\right\|^2\right]$$

SNR weighting (improved):

$$\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon}\left[\frac{1}{\text{SNR}(t)}\left\|\epsilon_\theta(x_t, t) - \epsilon\right\|^2\right]$$
where \(\text{SNR}(t) = \bar{\alpha}_t / (1 - \bar{\alpha}_t)\).
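The SNR weight itself is easy to compute from \(\bar{\alpha}_t\); here is a self-contained sketch for the linear schedule (the function name is ours):

```python
import math

def snr(t, beta_min=0.1, beta_max=20.0):
    """SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for the linear VP schedule."""
    alpha_bar = math.exp(-(beta_min * t + 0.5 * (beta_max - beta_min) * t**2))
    return alpha_bar / (1.0 - alpha_bar)

# SNR is high near t=0 (mostly signal) and tiny near t=1 (mostly noise),
# so the 1/SNR(t) factor down-weights the easy low-noise timesteps.
```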
Exponential Moving Average (EMA)¶
Maintain EMA of model parameters for better sample quality:
```python
import copy

ema_model = copy.deepcopy(model)
ema_decay = 0.9999

# After each training step
for param_ema, param in zip(ema_model.parameters(), model.parameters()):
    param_ema.data.mul_(ema_decay).add_(param.data, alpha=1 - ema_decay)

# Use ema_model for sampling
```
Network Architecture¶
U-Net (Standard for Images)¶
```python
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3,
                 base_channels=128, time_emb_dim=256):
        super().__init__()

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.GELU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # Encoder
        self.encoder = nn.ModuleList([
            ResBlock(in_channels, base_channels, time_emb_dim),
            ResBlock(base_channels, base_channels * 2, time_emb_dim),
            ResBlock(base_channels * 2, base_channels * 4, time_emb_dim),
        ])

        # Bottleneck
        self.bottleneck = ResBlock(base_channels * 4, base_channels * 4, time_emb_dim)

        # Decoder (input channels doubled by skip connections)
        self.decoder = nn.ModuleList([
            ResBlock(base_channels * 8, base_channels * 2, time_emb_dim),
            ResBlock(base_channels * 4, base_channels, time_emb_dim),
            ResBlock(base_channels * 2, base_channels, time_emb_dim),
        ])

        # Output
        self.out = nn.Conv2d(base_channels, out_channels, 1)

    def forward(self, x, t):
        # Embed time
        t_emb = self.time_mlp(t)

        # Encoder with skip connections
        skips = []
        for block in self.encoder:
            x = block(x, t_emb)
            skips.append(x)
            x = F.avg_pool2d(x, 2)

        # Bottleneck
        x = self.bottleneck(x, t_emb)

        # Decoder with skip connections
        for block in self.decoder:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = torch.cat([x, skips.pop()], dim=1)
            x = block(x, t_emb)

        return self.out(x)
```
Time Embedding¶
Sinusoidal embeddings (like Transformer positional encoding):
```python
import math

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
```
Training Tips¶
Hyperparameters¶
Learning rate:
- Start: \(1 \times 10^{-4}\)
- Use cosine annealing or constant
Batch size:
- Larger is better (256-2048)
- Use gradient accumulation if memory limited
Training steps:
- Images: 500K - 1M steps
- High resolution: 1M+ steps
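The gradient-accumulation suggestion above can be sketched as follows (toy model; `accum_steps` and the tensor sizes are arbitrary choices for illustration). Each micro-batch contributes `1/accum_steps` of the gradient before a single optimizer step:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
accum_steps = 4  # effective batch = accum_steps * micro-batch size

w_before = model.weight.detach().clone()
optimizer.zero_grad()
for step in range(accum_steps):
    x = torch.randn(16, 8)                       # one micro-batch
    loss = model(x).pow(2).mean() / accum_steps  # scale so gradients average
    loss.backward()                              # gradients accumulate in .grad
optimizer.step()                                 # one update for the whole batch
optimizer.zero_grad()
```

Dividing the loss by `accum_steps` makes the accumulated gradient equal to the gradient of the average loss over the effective batch.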
Monitoring¶
Track during training:

1. Loss: should decrease steadily
2. Sample quality: generate samples every N steps
3. EMA model: use the EMA weights for evaluation
Early stopping:
- Monitor FID or other metrics on validation set
- Diffusion models benefit from long training
Debugging¶
Common issues:

- Loss not decreasing:
  - Check data normalization (typically [-1, 1])
  - Verify \(\bar{\alpha}_t\) computation
  - Check time embedding
- Poor sample quality:
  - Train longer
  - Use EMA
  - Try a different noise schedule
- Mode collapse:
  - Rare in diffusion models
  - Check data diversity
Comparison: Training vs Sampling¶
Training (This Document)¶
What: Learn to predict noise \(\epsilon_\theta(x_t, t)\)
How:

1. Sample \(x_0\), \(t\), \(\epsilon\)
2. Compute \(x_t\) using the closed-form marginal
3. Predict \(\epsilon_\theta(x_t, t)\)
4. Minimize \(\|\epsilon_\theta - \epsilon\|^2\)
SDE solver: ❌ Not used
Speed: Fast (single forward pass per sample)
Sampling (See 03_sde_sampling.md)¶
What: Generate samples from learned model
How:

1. Start with \(x_T \sim \mathcal{N}(0, I)\)
2. Iteratively denoise using the reverse SDE/ODE
3. End with \(x_0 \approx\) a data sample
SDE solver: ✅ Used (Euler, RK4, adaptive)
Speed: Slow (many steps required)
Key Takeaways¶
Conceptual¶
- Training uses closed-form marginals — no SDE solver needed
- The forward SDE solution is exact — we can sample \(x_t\) directly from \(x_0\)
- Score matching = denoising — predicting noise is equivalent to learning scores
- Training is the same as DDPM — just continuous time instead of discrete
Practical¶
- Implementation is simple: Sample \(t\), add noise, predict, compute loss
- No numerical integration during training
- Fast and stable — closed-form solution avoids discretization errors
- SDE solvers only used for sampling (generation)
Mathematical¶
- Forward SDE: \(dx = -\frac{1}{2}\beta(t) x\,dt + \sqrt{\beta(t)}\,dw\)
- Closed-form: \(x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon\)
- Training loss: \(\mathbb{E}[\|\epsilon_\theta(x_t, t) - \epsilon\|^2]\)
- No solver needed: Direct sampling from \(q(x_t | x_0)\)
Related Documents¶
SDE Documentation¶
- 00_sde_overview.md — High-level SDE introduction
- 01_diffusion_sde_view.md — Detailed SDE formulation
- 03_solving_vpsde.md — Derivation of closed-form solution
- 03_sde_sampling.md — How to sample (where SDE solvers are used)
Related Topics¶
- DDPM Training — Discrete version
- Flow Matching Training — Alternative approach
Summary¶
Training diffusion models from the SDE perspective is simple:
- Use closed-form marginals from the forward SDE solution
- No SDE solver required — direct sampling of \(x_t\) from \(x_0\)
- Same as DDPM training — just continuous time
- Fast and exact — no discretization errors
The confusion arises because:
- The SDE formulation seems complex
- But training only uses the closed-form solution
- SDE solvers are only needed for sampling (generation)
Bottom line: Training is straightforward regression on noise prediction, using the exact solution of the forward SDE.