Flow Matching Training¶
This document covers the practical aspects of training flow matching models: loss functions, network architectures, implementation details, training strategies, and best practices.
Training Overview¶
The Training Loop¶
Flow matching training is remarkably simple compared to diffusion models:
Algorithm:
for batch in dataloader:
    # 1. Sample data and noise
    x0 = batch             # data
    x1 = sample_noise()    # noise
    # 2. Sample time uniformly (one value per example)
    t = uniform(0, 1)
    # 3. Interpolate
    xt = (1 - t) * x0 + t * x1
    # 4. Compute target velocity
    target = x1 - x0
    # 5. Predict and compute loss
    pred = model(xt, t)
    loss = mse_loss(pred, target)
    # 6. Update (clear stale gradients before the backward pass)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Key simplicity: Direct regression with MSE loss, no complex score matching objectives.
Loss Functions¶
Conditional Flow Matching Loss¶
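The standard loss for flow matching:
\[\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t,\, x_0,\, x_1} \left[ \left\| v_\theta(x_t, t) - u_t(x_0, x_1) \right\|^2 \right]\]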
Components:
- \(t \sim \text{Uniform}[0, 1]\): Random time
- \(x_0 \sim p_{\text{data}}\): Data sample
- \(x_1 \sim p_{\text{noise}}\): Noise sample
- \(x_t = \psi_t(x_0, x_1)\): Interpolated point
- \(u_t(x_0, x_1) = \frac{d}{dt}\psi_t(x_0, x_1)\): Target velocity
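To make the target-velocity definition concrete, the sketch below approximates \(\frac{d}{dt}\psi_t(x_0, x_1)\) with a central finite difference for the linear interpolant and compares it with \(x_1 - x_0\). The helper names `linear_interpolant` and `finite_difference_velocity` are illustrative only and do not come from any library.

```python
def linear_interpolant(x0, x1, t):
    """psi_t(x0, x1) = (1 - t) * x0 + t * x1, applied elementwise to lists."""
    return [(1 - t) * a + t * b for a, b in zip(x0, x1)]


def finite_difference_velocity(interpolant, x0, x1, t, eps=1e-5):
    """Approximate d/dt psi_t(x0, x1) with a central difference."""
    forward = interpolant(x0, x1, t + eps)
    backward = interpolant(x0, x1, t - eps)
    return [(f - b) / (2 * eps) for f, b in zip(forward, backward)]


x0 = [0.5, -1.0, 2.0]    # toy "data" sample
x1 = [1.5, 0.25, -0.5]   # toy "noise" sample
u_numeric = finite_difference_velocity(linear_interpolant, x0, x1, t=0.3)
u_exact = [b - a for a, b in zip(x0, x1)]  # x1 - x0
print(u_numeric)
print(u_exact)
```

The same check works for any differentiable interpolant, which can be handy when a custom path \(\psi_t\) is introduced and the hand-derived velocity needs a quick sanity test.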
Rectified Flow Loss¶
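For linear interpolation \(x_t = (1-t)x_0 + tx_1\):
\[\mathcal{L}_{\text{RF}}(\theta) = \mathbb{E}_{t,\, x_0,\, x_1} \left[ \left\| v_\theta(x_t, t) - (x_1 - x_0) \right\|^2 \right]\]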
Simplification: Target velocity is constant: \(u_t = x_1 - x_0\)
PyTorch implementation:
import torch.nn.functional as F

def rectified_flow_loss(model, x0, x1, t):
    """
    Compute rectified flow loss.
    Args:
        model: Neural network v_theta(x, t)
        x0: Data samples [batch_size, ...]
        x1: Noise samples [batch_size, ...]
        t: Time values [batch_size]
    Returns:
        loss: Scalar loss value
    """
    # Interpolate (reshape t so it broadcasts over the non-batch dimensions)
    t_expanded = t.view(-1, *([1] * (x0.ndim - 1)))
    xt = (1 - t_expanded) * x0 + t_expanded * x1
    # Target velocity
    target = x1 - x0
    # Predict velocity
    pred = model(xt, t)
    # MSE loss
    loss = F.mse_loss(pred, target)
    return loss
Variance-Preserving Loss¶
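For VP interpolation \(x_t = \sqrt{1-\sigma_t^2} \, x_0 + \sigma_t \, x_1\):
Target velocity:
\[u_t(x_0, x_1) = -\frac{\sigma_t \, \sigma_t'}{\sqrt{1-\sigma_t^2}} \, x_0 + \sigma_t' \, x_1\]
where \(\sigma_t' = \frac{d\sigma_t}{dt}\).
Common choice: \(\sigma_t = t\), so \(\sigma_t' = 1\):
\[u_t(x_0, x_1) = -\frac{t}{\sqrt{1-t^2}} \, x_0 + x_1\]
The \(x_0\) coefficient diverges as \(t \to 1\), so \(t\) should be kept strictly below 1 (for example by sampling from \([0, 1-\epsilon]\)).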
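As a quick check of the VP target velocity for the choice \(\sigma_t = t\), the sketch below differentiates \(x_t = \sqrt{1-t^2}\, x_0 + t\, x_1\) numerically and compares the result with the closed form \(-\frac{t}{\sqrt{1-t^2}}\, x_0 + x_1\) obtained from the chain rule. Scalars are used for brevity, and the function names are illustrative.

```python
import math


def vp_interpolant(x0, x1, t):
    """x_t = sqrt(1 - t^2) * x0 + t * x1 (VP path with sigma_t = t)."""
    return math.sqrt(1 - t * t) * x0 + t * x1


def vp_velocity(x0, x1, t):
    """Closed-form time derivative of vp_interpolant."""
    return -t / math.sqrt(1 - t * t) * x0 + x1


x0, x1, t, eps = 0.8, -0.3, 0.6, 1e-6
numeric = (vp_interpolant(x0, x1, t + eps) - vp_interpolant(x0, x1, t - eps)) / (2 * eps)
analytic = vp_velocity(x0, x1, t)
print(numeric, analytic)
```

Unlike the linear case, this velocity depends on \(t\), so the regression target has to be recomputed for every sampled time.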
Weighted Loss¶
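Add time-dependent weighting:
\[\mathcal{L}_{w}(\theta) = \mathbb{E}_{t,\, x_0,\, x_1} \left[ w(t) \left\| v_\theta(x_t, t) - u_t(x_0, x_1) \right\|^2 \right]\]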
Common weights:
1. Uniform: \(w(t) = 1\) (standard)
2. SNR-based: \(w(t) = \frac{1}{\text{SNR}(t)}\) (from diffusion)
3. Endpoint emphasis: \(w(t) = t^2\) or \(w(t) = (1-t)^2\)
4. Min-SNR: \(w(t) = \min(\text{SNR}(t), \gamma)\) (clip large weights)
When to use:
- Uniform works well for most cases
- Endpoint emphasis if one end of the path is critical: \(w(t) = (1-t)^2\) emphasizes the data end (\(t=0\)), \(w(t) = t^2\) the noise end (\(t=1\))
- SNR-based for VP flows to match diffusion performance
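A weighted loss can be sketched as below, assuming the weight is applied per example after averaging the squared error over the non-batch dimensions. The names `weighted_fm_loss` and `data_end_weight` are hypothetical; any callable mapping a `[batch_size]` tensor of times to weights could be substituted.

```python
import torch


def weighted_fm_loss(pred, target, t, weight_fn=None):
    """MSE between predicted and target velocity, weighted per example by w(t)."""
    # Mean squared error per example (average over all non-batch dimensions)
    per_example = ((pred - target) ** 2).flatten(start_dim=1).mean(dim=1)
    if weight_fn is None:
        return per_example.mean()          # uniform weighting, w(t) = 1
    return (weight_fn(t) * per_example).mean()


def data_end_weight(t):
    """w(t) = (1 - t)^2: emphasizes times close to the data end (t = 0)."""
    return (1 - t) ** 2


torch.manual_seed(0)
pred = torch.randn(8, 3, 4, 4)
target = torch.randn(8, 3, 4, 4)
t = torch.rand(8)
uniform_loss = weighted_fm_loss(pred, target, t)
weighted_loss = weighted_fm_loss(pred, target, t, data_end_weight)
print(uniform_loss.item(), weighted_loss.item())
```

With `weight_fn=None` this reduces to the ordinary MSE, so the weighted variant can be swapped in without changing the rest of a training loop.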
Network Architectures¶
Architecture Requirements¶
Flow matching networks \(v_\theta(x, t)\) must:
- Input: Accept data \(x\) and time \(t\)
- Output: Velocity vector same shape as \(x\)
- Time conditioning: Incorporate \(t\) throughout the network
- Expressiveness: Capture complex velocity fields
U-Net Architecture¶
Standard choice for images:
class FlowMatchingUNet(nn.Module):
    # Simplified sketch: ResnetBlock(in_ch, out_ch, time_emb_dim) is assumed to be
    # defined elsewhere, and this version has no down/upsampling layers.
    def __init__(self, channels=3, dim=64, dim_mults=(1, 2, 4, 8)):
        super().__init__()
        time_dim = dim * 4
        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )
        # Encoder: each block's input width is the previous block's output width
        dims = [dim * mult for mult in dim_mults]
        self.downs = nn.ModuleList([])
        in_ch = channels
        for out_ch in dims:
            self.downs.append(ResnetBlock(in_ch, out_ch, time_emb_dim=time_dim))
            in_ch = out_ch
        # Bottleneck (also time-conditioned, since forward passes t_emb to it)
        self.mid = ResnetBlock(in_ch, in_ch, time_emb_dim=time_dim)
        # Decoder: input is the current features concatenated with the matching skip
        self.ups = nn.ModuleList([])
        for skip_ch in reversed(dims):
            self.ups.append(ResnetBlock(in_ch + skip_ch, skip_ch, time_emb_dim=time_dim))
            in_ch = skip_ch
        # Output
        self.final = nn.Conv2d(in_ch, channels, 1)

    def forward(self, x, t):
        # Time embedding
        t_emb = self.time_mlp(t)
        # Encoder
        h = []
        for down in self.downs:
            x = down(x, t_emb)
            h.append(x)
        # Bottleneck
        x = self.mid(x, t_emb)
        # Decoder
        for up in self.ups:
            x = torch.cat([x, h.pop()], dim=1)
            x = up(x, t_emb)
        # Output velocity
        return self.final(x)
Key components:
- Sinusoidal time embedding: Encodes \(t \in [0, 1]\)
- ResNet blocks: With time conditioning via FiLM
- Skip connections: Preserve spatial information
- U-Net structure: Encoder-decoder with bottleneck
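The U-Net code relies on a `ResnetBlock` that is not shown. One possible minimal version, assuming the call signature `ResnetBlock(in_ch, out_ch, time_emb_dim)` and a FiLM-style scale/shift after the first convolution, is sketched here. The `groups` argument and layer layout are assumptions for illustration, and `out_ch` must be divisible by `groups`.

```python
import torch
import torch.nn as nn


class ResnetBlock(nn.Module):
    """Two 3x3 convolutions with a time-dependent scale/shift and a residual path."""

    def __init__(self, in_ch, out_ch, time_emb_dim, groups=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.act = nn.SiLU()
        # Projects the time embedding to a per-channel (scale, shift) pair
        self.time_proj = nn.Linear(time_emb_dim, out_ch * 2)
        # 1x1 conv so the residual matches out_ch when the widths differ
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.norm1(self.conv1(x))
        # [batch, 2 * out_ch] -> two tensors of shape [batch, out_ch, 1, 1]
        scale, shift = self.time_proj(t_emb)[:, :, None, None].chunk(2, dim=1)
        h = self.act(h * (1 + scale) + shift)  # FiLM-style modulation
        h = self.act(self.norm2(self.conv2(h)))
        return h + self.skip(x)


block = ResnetBlock(in_ch=3, out_ch=64, time_emb_dim=256)
x = torch.randn(2, 3, 32, 32)
t_emb = torch.randn(2, 256)
out = block(x, t_emb)
print(out.shape)
```

The block preserves spatial resolution, so a full U-Net would still need separate downsampling and upsampling layers between stages.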
Diffusion Transformer (DiT)¶
Modern architecture for images:
class DiTBlock(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # batch_first=True so inputs are [batch, tokens, dim]
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        # AdaLN modulation
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim)
        )

    def forward(self, x, c):
        # x: [batch, tokens, dim]; c: [batch, dim] is the time + condition embedding
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(c).chunk(6, dim=-1)
        # Self-attention with AdaLN. modulate(x, shift, scale) is assumed to compute
        # x * (1 + scale[:, None]) + shift[:, None].
        h = modulate(self.norm1(x), shift_msa, scale_msa)
        # nn.MultiheadAttention takes (query, key, value) and returns (output, weights)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + gate_msa.unsqueeze(1) * attn_out
        # MLP with AdaLN
        h = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(h)
        return x
class FlowMatchingDiT(nn.Module):
    def __init__(self, img_size=32, patch_size=2, dim=512, depth=12, num_heads=8):
        super().__init__()
        # Patchify
        self.patch_embed = PatchEmbed(img_size, patch_size, 3, dim)
        # Time + condition embedding
        self.time_embed = TimestepEmbedder(dim)
        # Transformer blocks
        self.blocks = nn.ModuleList([
            DiTBlock(dim, num_heads) for _ in range(depth)
        ])
        # Output
        self.final_layer = FinalLayer(dim, patch_size, 3)

    def forward(self, x, t):
        # Patchify
        x = self.patch_embed(x)
        # Time embedding
        c = self.time_embed(t)
        # Transformer
        for block in self.blocks:
            x = block(x, c)
        # Unpatchify and output velocity
        return self.final_layer(x)
Advantages:
- Scalability: Scales better to large models
- Flexibility: Handles variable resolutions
- Attention: Captures long-range dependencies
- Modern: State-of-the-art for image generation
Time Conditioning¶
Sinusoidal embedding (standard):
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return emb
FiLM conditioning (Feature-wise Linear Modulation):
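class FiLM(nn.Module):
    def __init__(self, dim, time_emb_dim):
        super().__init__()
        self.scale_shift = nn.Linear(time_emb_dim, dim * 2)

    def forward(self, x, time_emb):
        # x: [batch, dim, ...]; time_emb: [batch, time_emb_dim]
        scale, shift = self.scale_shift(time_emb).chunk(2, dim=-1)
        # Add trailing singleton dims so [batch, dim] broadcasts over spatial axes
        extra = (1,) * (x.ndim - 2)
        scale = scale.view(*scale.shape, *extra)
        shift = shift.view(*shift.shape, *extra)
        return x * (1 + scale) + shift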
AdaLN conditioning (Adaptive Layer Normalization):
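A minimal sketch of AdaLN, assuming token inputs of shape `[batch, tokens, dim]` and a conditioning vector of shape `[batch, cond_dim]`: the layer normalizes without learned affine parameters and predicts the scale and shift from the conditioning instead. The class name `AdaLN` and the zero initialization are illustrative choices, not a reference implementation.

```python
import torch
import torch.nn as nn


class AdaLN(nn.Module):
    """LayerNorm whose scale and shift are predicted from a conditioning vector."""

    def __init__(self, dim, cond_dim):
        super().__init__()
        # No learned affine parameters: the conditioning supplies them instead
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_scale_shift = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 2 * dim))
        # Zero init so the layer starts out as a plain LayerNorm
        nn.init.zeros_(self.to_scale_shift[-1].weight)
        nn.init.zeros_(self.to_scale_shift[-1].bias)

    def forward(self, x, c):
        # x: [batch, tokens, dim], c: [batch, cond_dim]
        scale, shift = self.to_scale_shift(c).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


layer = AdaLN(dim=16, cond_dim=32)
x = torch.randn(4, 10, 16)
c = torch.randn(4, 32)
y = layer(x, c)
print(y.shape)
```

Compared with FiLM, the modulation is applied to normalized activations, which keeps the scale of the conditioned features under control regardless of the input statistics.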
Training Strategies¶
Data Preprocessing¶
Images:
# Normalize to [-1, 1]
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
Gene expression:
# Log-normalize and standardize
def preprocess_gene_expression(X):
    # Log1p transform
    X = np.log1p(X)
    # Standardize per gene
    X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)
    return X
Noise Distribution¶
Standard Gaussian (most common):
Matched variance:
Domain-specific (for gene expression):
# Sparse noise matching dropout structure
def sparse_noise(x0, dropout_rate=0.1):
    noise = torch.randn_like(x0)
    # Keep each entry with probability (1 - dropout_rate); zero out the rest
    mask = torch.rand_like(x0) > dropout_rate
    return noise * mask
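The two simpler noise choices named in this section can be sketched as follows. "Matched variance" is interpreted here as scaling standard Gaussian noise by a standard deviation estimated from the data, which is an assumption; the function names are illustrative.

```python
import torch


def standard_gaussian_noise(x0):
    """x1 ~ N(0, I), same shape as the data batch."""
    return torch.randn_like(x0)


def matched_variance_noise(x0, data_std):
    """x1 ~ N(0, data_std^2 I); data_std is estimated once from the training set."""
    return torch.randn_like(x0) * data_std


torch.manual_seed(0)
x0 = torch.randn(4096, 16) * 3.0   # toy data with standard deviation around 3
data_std = x0.std()
x1_standard = standard_gaussian_noise(x0)
x1_matched = matched_variance_noise(x0, data_std)
print(x1_standard.std().item(), x1_matched.std().item())
```

If the data are already standardized during preprocessing, the two choices coincide.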
Time Sampling¶
Uniform (standard):
Stratified (better coverage):
# Divide [0,1] into bins
n_bins = batch_size
bins = torch.linspace(0, 1, n_bins + 1, device=device)
t = bins[:-1] + torch.rand(n_bins, device=device) / n_bins
Importance sampling (emphasize difficult times):
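The uniform sampler and one possible non-uniform sampler can be sketched as follows. The non-uniform choice here, a logit-normal distribution that concentrates samples around the middle of the path, is an assumption made for illustration; which times are actually difficult depends on the model and data.

```python
import torch


def sample_uniform_t(batch_size, device="cpu"):
    """t ~ Uniform[0, 1]."""
    return torch.rand(batch_size, device=device)


def sample_logit_normal_t(batch_size, mean=0.0, std=1.0, device="cpu"):
    """t = sigmoid(z) with z ~ N(mean, std^2); mass concentrates near sigmoid(mean)."""
    z = torch.randn(batch_size, device=device) * std + mean
    return torch.sigmoid(z)


def middle_fraction(t):
    """Fraction of samples that fall in the middle half of the path, (0.25, 0.75)."""
    return ((t > 0.25) & (t < 0.75)).float().mean().item()


torch.manual_seed(0)
t_uniform = sample_uniform_t(10_000)
t_logit_normal = sample_logit_normal_t(10_000)
print(middle_fraction(t_uniform), middle_fraction(t_logit_normal))
```

Sampling \(t\) from a non-uniform density changes the effective weighting of the loss over time; dividing each example's loss by that density recovers the uniform objective if that is what is wanted.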
Batch Size and Learning Rate¶
Batch size:
- Images: 128-512 (larger is better for diverse pairs)
- Gene expression: 64-256 (depends on dataset size)
- Rule of thumb: As large as GPU memory allows
Learning rate:
- AdamW: 1e-4 to 5e-4 (standard)
- Warmup: 1000-5000 steps
- Schedule: Cosine decay or constant
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
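The scheduler above has no warmup phase. A warmup followed by cosine decay can be expressed as a per-step multiplier on the base learning rate, as in this sketch; the function name and the `min_ratio` parameter are assumptions for illustration.

```python
import math


def warmup_cosine_multiplier(step, warmup_steps, total_steps, min_ratio=0.0):
    """Learning-rate multiplier: linear warmup, then cosine decay to min_ratio."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_ratio + (1.0 - min_ratio) * cosine


base_lr = 1e-4
for step in (0, 999, 1000, 50_500, 100_000):
    print(step, base_lr * warmup_cosine_multiplier(step, 1000, 100_000))
```

A function like this can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, in which case `scheduler.step()` should be called after every optimizer step rather than once per epoch.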
EMA (Exponential Moving Average)¶
Use EMA for better sampling quality:
class EMA:
    def __init__(self, model, decay=0.9999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    @torch.no_grad()
    def update(self):
        # shadow <- decay * shadow + (1 - decay) * param
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] -= (1 - self.decay) * (self.shadow[name] - param.data)

    @torch.no_grad()
    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # Copy values so the model never shares storage with the shadow weights
                self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.backup[name])
        self.backup = {}

# Usage
ema = EMA(model, decay=0.9999)
for batch in dataloader:
    optimizer.zero_grad()
    # train_step is assumed to run the forward and backward pass and return the loss
    loss = train_step(model, batch)
    optimizer.step()
    ema.update()  # Update EMA after each step
# For sampling, use EMA weights
ema.apply_shadow()
samples = sample(model, ...)
ema.restore()
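To get a feel for what a given decay means, the sketch below computes the EMA half-life (the number of updates after which an old value's contribution has halved) and tracks a scalar with the same update rule as the `EMA` class. The helper names are illustrative.

```python
import math


def ema_half_life(decay):
    """Number of updates after which an old weight's contribution has halved."""
    return math.log(0.5) / math.log(decay)


def ema_track(values, decay):
    """Scalar EMA using the update shadow -= (1 - decay) * (shadow - value)."""
    shadow = values[0]
    for v in values[1:]:
        shadow -= (1 - decay) * (shadow - v)
    return shadow


for decay in (0.999, 0.9999):
    print(decay, round(ema_half_life(decay)))

# A parameter that jumps from 0 to 1: after one half-life the EMA is about halfway
steps = round(ema_half_life(0.999))
print(ema_track([0.0] + [1.0] * steps, decay=0.999))
```

With a decay of 0.9999 the half-life is several thousand updates, so for short training runs the EMA weights may still be dominated by the initialization; a smaller decay could be worth trying in that case.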
Complete Training Script¶
Full Implementation¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
def train_flow_matching(
    model,
    train_loader,
    num_epochs=100,
    lr=1e-4,
    device='cuda',
    use_ema=True,
    save_every=10
):
    """
    Complete training loop for flow matching.
    """
    model = model.to(device)
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    # Learning rate scheduler (stepped once per epoch)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs
    )
    # EMA
    if use_ema:
        ema = EMA(model, decay=0.9999)
    # Training loop
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for batch_idx, batch in enumerate(pbar):
            # Datasets such as CIFAR10 yield (data, label) pairs; keep only the data
            x0 = batch[0] if isinstance(batch, (list, tuple)) else batch
            x0 = x0.to(device)
            batch_size = x0.shape[0]
            # Sample noise
            x1 = torch.randn_like(x0)
            # Sample time
            t = torch.rand(batch_size, device=device)
            # Interpolate
            t_expanded = t.view(-1, *([1] * (x0.ndim - 1)))
            xt = (1 - t_expanded) * x0 + t_expanded * x1
            # Target velocity
            target = x1 - x0
            # Predict velocity
            pred = model(xt, t)
            # Compute loss
            loss = F.mse_loss(pred, target)
            # Backward
            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # Update
            optimizer.step()
            # Update EMA
            if use_ema:
                ema.update()
            # Logging
            epoch_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})
        # Scheduler step
        scheduler.step()
        # Log epoch
        avg_loss = epoch_loss / len(train_loader)
        print(f'Epoch {epoch+1}: Average Loss = {avg_loss:.6f}')
        # Save checkpoint
        if (epoch + 1) % save_every == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }
            if use_ema:
                checkpoint['ema_state_dict'] = ema.shadow
            torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pt')
    return model
Training Example¶
# Load data
from torchvision import datasets, transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
train_dataset = datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# Create model
model = FlowMatchingUNet(channels=3, dim=64).cuda()
# Train
trained_model = train_flow_matching(
model,
train_loader,
num_epochs=100,
lr=1e-4,
device='cuda',
use_ema=True
)
Reflow: Iterative Refinement¶
The Reflow Algorithm¶
Reflow iteratively straightens flow paths for faster sampling.
Algorithm:
def reflow(model, train_loader, num_iterations=3, num_epochs=100, lr=1e-4, device='cuda'):
    """
    Iterative reflow to straighten paths.
    sample_ode(model, x1) is assumed to integrate the learned ODE from
    noise x1 (t=1) back to a data sample (t=0).
    """
    models = [model]
    for iteration in range(1, num_iterations):
        print(f'Reflow iteration {iteration}')
        # Generate synthetic (data, noise) pairs using the current model
        synthetic_data = []
        models[-1].eval()
        for batch in tqdm(train_loader, desc='Generating synthetic data'):
            # Only the batch shape is used: reflow pairs fresh noise with model samples
            ref = batch[0] if isinstance(batch, (list, tuple)) else batch
            x1 = torch.randn_like(ref).to(device)
            # Sample from current model
            with torch.no_grad():
                x0_synthetic = sample_ode(models[-1], x1)
            synthetic_data.append((x0_synthetic, x1))
        # Train new model on synthetic data
        new_model = FlowMatchingUNet(channels=3, dim=64).to(device)
        optimizer = torch.optim.AdamW(new_model.parameters(), lr=lr)
        for epoch in range(num_epochs):
            for x0_syn, x1 in synthetic_data:
                # Standard flow matching training
                t = torch.rand(x0_syn.shape[0], device=device)
                t_exp = t.view(-1, *([1] * (x0_syn.ndim - 1)))
                xt = (1 - t_exp) * x0_syn + t_exp * x1
                target = x1 - x0_syn
                pred = new_model(xt, t)
                loss = F.mse_loss(pred, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        models.append(new_model)
    return models
Effect (approximate number of sampling steps needed; actual values depend on the model and dataset):
- Iteration 1: 50-100 steps needed
- Iteration 2: 20-30 steps needed
- Iteration 3: 10-15 steps needed
Trade-off: More training time for faster sampling.
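The reflow code calls a `sample_ode` helper that is not defined in this document. A minimal fixed-step Euler version, assuming the convention used throughout (data at \(t=0\), noise at \(t=1\)) and the call signature `sample_ode(model, x1)`, might look like this. The `num_steps` argument is an assumption, and higher-order solvers are an alternative.

```python
import torch


@torch.no_grad()
def sample_ode(model, x1, num_steps=50):
    """Euler integration of dx/dt = v(x, t) from t=1 (noise) down to t=0 (data)."""
    x = x1.clone()
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t_value = 1.0 - i * dt
        t = torch.full((x.shape[0],), t_value, device=x.device, dtype=x.dtype)
        x = x - dt * model(x, t)  # step backward in time
    return x


# Toy check with a constant velocity field v(x, t) = c:
# integrating from t=1 to t=0 should shift every point by -c.
c = torch.tensor([2.0, -1.0])
constant_field = lambda x, t: c.expand_as(x)
x1 = torch.randn(5, 2)
x0 = sample_ode(constant_field, x1, num_steps=10)
print((x1 - x0)[0])
```

For a perfectly straight flow the velocity is constant along each path, which is why reflowed models can get away with very few Euler steps.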
Monitoring and Debugging¶
Training Metrics¶
Track during training:
- Loss: Should decrease steadily
- Gradient norm: Should be stable (clip if exploding)
- Learning rate: Follow schedule
- Sample quality: Generate samples periodically
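# Log metrics
# grad_norm can be taken from the return value of clip_grad_norm_:
#   grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
wandb.log({
    'loss': loss.item(),
    'grad_norm': float(grad_norm),
    'lr': scheduler.get_last_lr()[0],
    'epoch': epoch
})
# Generate samples every N epochs
if epoch % 10 == 0:
    with torch.no_grad():
        # sample_ode is assumed to draw its own noise when given num_samples
        samples = sample_ode(model, num_samples=64)
    wandb.log({'samples': wandb.Image(samples)})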
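The gradient norm to log can be obtained directly from `torch.nn.utils.clip_grad_norm_`, which returns the total norm measured before clipping. The small example below uses a toy linear model and an artificially large loss, both chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 4)
loss = (model(torch.randn(16, 4)) * 100.0).pow(2).mean()  # deliberately large loss
loss.backward()

# clip_grad_norm_ returns the total gradient norm measured *before* clipping
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Recompute the norm after clipping to confirm it was capped at max_norm
clipped_norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
print(float(grad_norm), float(clipped_norm))
```

Logging the pre-clipping value is usually more informative, since the post-clipping norm saturates at the threshold whenever clipping is active.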
Common Issues¶
1. Loss not decreasing:
- Check learning rate (try 1e-4)
- Check data normalization
- Verify target velocity computation
- Increase batch size
2. NaN loss:
- Gradient clipping (torch.nn.utils.clip_grad_norm_)
- Lower learning rate
- Check for inf/nan in data
- Use mixed precision carefully
3. Poor sample quality:
- Train longer
- Use EMA
- Increase model capacity
- Try more sampling steps
- Check noise distribution
4. Mode collapse:
- Increase batch size
- Use diverse noise samples
- Check data augmentation
- Verify loss computation
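Several of the checks above (finite inputs, valid times, correct target velocity) can be automated for a single batch. The sketch below assumes the linear interpolant with data at \(t=0\) and noise at \(t=1\); the function name and the returned dictionary keys are hypothetical.

```python
import torch


def check_flow_matching_batch(x0, x1, t, xt, target, atol=1e-5):
    """Return a dict of basic consistency checks for one training batch."""
    t_exp = t.view(-1, *([1] * (x0.ndim - 1)))
    return {
        "finite_inputs": bool(torch.isfinite(x0).all() and torch.isfinite(x1).all()),
        "t_in_unit_interval": bool(((t >= 0) & (t <= 1)).all()),
        # Following the target velocity from x_t for the remaining time reaches x1
        "reaches_noise": torch.allclose(xt + (1 - t_exp) * target, x1, atol=atol),
        # Following it backward for time t returns to x0
        "reaches_data": torch.allclose(xt - t_exp * target, x0, atol=atol),
    }


torch.manual_seed(0)
x0, x1, t = torch.randn(8, 3, 4, 4), torch.randn(8, 3, 4, 4), torch.rand(8)
t_exp = t.view(-1, 1, 1, 1)
xt = (1 - t_exp) * x0 + t_exp * x1
good = check_flow_matching_batch(x0, x1, t, xt, target=x1 - x0)
bad = check_flow_matching_batch(x0, x1, t, xt, target=x0 - x1)  # sign flipped
print(good)
print(bad)
```

A flipped sign or swapped endpoints in the target is an easy mistake that still produces a decreasing loss, so a check of this kind on the first batch can save a long debugging session.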
Best Practices¶
Do's¶
✅ Use EMA for better sampling quality
✅ Clip gradients to prevent instability
✅ Normalize data to [-1, 1] or standardize
✅ Use large batch sizes for diverse pairs
✅ Monitor samples during training
✅ Save checkpoints regularly
✅ Use mixed precision for faster training (with caution)
Don'ts¶
❌ Don't skip EMA (significant quality improvement)
❌ Don't use tiny batch sizes (<32)
❌ Don't ignore gradient norms (clip if >1.0)
❌ Don't overtrain (diminishing returns after convergence)
❌ Don't forget data normalization
Summary¶
Key Training Steps¶
- Sample data \(x_0\) and noise \(x_1\)
- Sample time \(t \sim U[0, 1]\)
- Interpolate \(x_t = (1-t)x_0 + tx_1\)
- Compute target \(u_t = x_1 - x_0\)
- Predict \(v_\theta(x_t, t)\)
- Optimize MSE loss
Key Hyperparameters¶
- Batch size: 128-512 (larger is better)
- Learning rate: 1e-4 to 5e-4
- EMA decay: 0.9999
- Gradient clip: 1.0
- Epochs: 100-500 (depends on dataset)
Architecture Choices¶
- Images: U-Net or DiT
- Sequences: Transformer
- Time conditioning: Sinusoidal + FiLM/AdaLN
Related Documents¶
- Flow Matching Foundations — Theory and mathematics
- Flow Matching Sampling — ODE solvers and sampling
- DDPM Training — Comparison with diffusion training
- Diffusion Transformers — DiT architecture
References¶
- Lipman, Y., et al. (2023). Flow Matching for Generative Modeling. ICLR.
- Liu, X., et al. (2023). Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow. ICLR.
- Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV.
- Karras, T., et al. (2022). Elucidating the Design Space of Diffusion-Based Generative Models. NeurIPS.