Advanced Score Network Architectures for Realistic Data¶
Overview¶
This document explores advanced neural network architectures for score-based diffusion models when working with realistic, complex data such as: - Medical imaging (CT, MRI, pathology slides) - Gene expression data (bulk RNA-seq, microarray) - Single-cell RNA-seq (scRNA-seq) data - High-resolution natural images
The simple MLP used in the tutorial notebook works well for 2D toy data, but realistic datasets require more sophisticated architectures.
Referenced From¶
- Notebook:
notebooks/diffusion/02_sde_formulation/02_sde_formulation.ipynb - Related: Time Embedding and FiLM — Component details for time conditioning
Why Simple MLPs Are Not Enough¶
The Tutorial Architecture¶
The notebook uses a simple MLP:
class SimpleScoreNetwork(nn.Module):
def __init__(self, data_dim=2, hidden_dim=128, time_dim=32):
self.net = nn.Sequential(
nn.Linear(data_dim + time_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, data_dim),
)
This works for 2D toy data because:
- Only 2 dimensions to model
- Simple distributions (e.g., mixture of Gaussians)
- No spatial structure to capture
Limitations for Realistic Data¶
| Data Type | Challenge | Why MLP Fails |
|---|---|---|
| Images | Spatial correlations, local patterns | MLPs treat pixels independently |
| Gene expression | High dimensionality (~20,000 genes) | Too many parameters, no structure |
| scRNA-seq | Sparse, high-dimensional, cell-type structure | Can't leverage biological priors |
| Medical imaging | Multi-scale features, fine details | No hierarchical feature extraction |
Architecture 1: U-Net for Images¶
Why U-Net?¶
U-Net is the dominant architecture for image diffusion models (DDPM, Stable Diffusion) because:
- Multi-scale processing: Encoder-decoder structure captures features at multiple resolutions
- Skip connections: Preserve fine details while processing global structure
- Fully convolutional: Handles arbitrary image sizes
- Proven effectiveness: State-of-the-art in image generation
Architecture Overview¶
Input Image (H×W×C) + Time Embedding
↓
┌───────────────┐
│ Encoder │ (Downsample: capture global structure)
│ Block 1 │ ────────────────┐
└───────────────┘ │ Skip
↓ │
┌───────────────┐ │
│ Encoder │ │
│ Block 2 │ ────────────┐ │
└───────────────┘ │ │
↓ │ │
┌───────────────┐ │ │
│ Bottleneck │ │ │
└───────────────┘ │ │
↓ │ │
┌───────────────┐ │ │
│ Decoder │ ←───────────┘ │
│ Block 1 │ │
└───────────────┘ │
↓ │
┌───────────────┐ │
│ Decoder │ ←───────────────┘
│ Block 2 │
└───────────────┘
↓
Output Score (H×W×C)
Key Components¶
1. Residual Blocks with Time Conditioning¶
class ResBlock(nn.Module):
"""Residual block with time conditioning via FiLM."""
def __init__(self, in_channels, out_channels, time_dim):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
# Time conditioning (FiLM)
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_dim, 2 * out_channels) # scale and shift
)
# Skip connection if dimensions change
self.skip = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
self.norm1 = nn.GroupNorm(8, in_channels)
self.norm2 = nn.GroupNorm(8, out_channels)
self.act = nn.SiLU()
def forward(self, x, time_emb):
h = self.norm1(x)
h = self.act(h)
h = self.conv1(h)
# Apply time conditioning
scale, shift = self.time_mlp(time_emb).chunk(2, dim=1)
scale = scale[:, :, None, None] # Broadcast to spatial dims
shift = shift[:, :, None, None]
h = h * (1 + scale) + shift
h = self.norm2(h)
h = self.act(h)
h = self.conv2(h)
return h + self.skip(x)
2. Attention Layers¶
For capturing long-range dependencies (especially important for large images):
class SelfAttention(nn.Module):
"""Self-attention for spatial features."""
def __init__(self, channels, num_heads=4):
super().__init__()
self.attention = nn.MultiheadAttention(channels, num_heads, batch_first=True)
self.norm = nn.GroupNorm(8, channels)
def forward(self, x):
b, c, h, w = x.shape
# Reshape to sequence: (batch, h*w, channels)
x_flat = x.flatten(2).transpose(1, 2)
# Apply attention
x_norm = self.norm(x).flatten(2).transpose(1, 2)
attn_out, _ = self.attention(x_norm, x_norm, x_norm)
# Reshape back
out = attn_out.transpose(1, 2).reshape(b, c, h, w)
return x + out
Complete U-Net Skeleton¶
class UNet(nn.Module):
"""U-Net for score estimation in diffusion models."""
def __init__(self, in_channels=3, base_channels=64, time_dim=256):
super().__init__()
# Time embedding
self.time_mlp = nn.Sequential(
SinusoidalTimeEmbedding(time_dim),
nn.Linear(time_dim, time_dim * 4),
nn.SiLU(),
nn.Linear(time_dim * 4, time_dim)
)
# Encoder
self.enc1 = ResBlock(in_channels, base_channels, time_dim)
self.enc2 = ResBlock(base_channels, base_channels * 2, time_dim)
self.enc3 = ResBlock(base_channels * 2, base_channels * 4, time_dim)
self.down1 = nn.Conv2d(base_channels, base_channels, 3, stride=2, padding=1)
self.down2 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, stride=2, padding=1)
# Bottleneck with attention
self.bottleneck = nn.Sequential(
ResBlock(base_channels * 4, base_channels * 4, time_dim),
SelfAttention(base_channels * 4),
ResBlock(base_channels * 4, base_channels * 4, time_dim)
)
# Decoder (with skip connections)
self.dec3 = ResBlock(base_channels * 8, base_channels * 2, time_dim) # *8 from concat
self.dec2 = ResBlock(base_channels * 4, base_channels, time_dim)
self.dec1 = ResBlock(base_channels * 2, base_channels, time_dim)
self.up2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 4, 4, stride=2, padding=1)
self.up1 = nn.ConvTranspose2d(base_channels * 2, base_channels * 2, 4, stride=2, padding=1)
# Output
self.out = nn.Sequential(
nn.GroupNorm(8, base_channels),
nn.SiLU(),
nn.Conv2d(base_channels, in_channels, 3, padding=1)
)
def forward(self, x, t):
# Time embedding
t_emb = self.time_mlp(t)
# Encoder
h1 = self.enc1(x, t_emb)
h2 = self.enc2(self.down1(h1), t_emb)
h3 = self.enc3(self.down2(h2), t_emb)
# Bottleneck
h = self.bottleneck(h3)
# Decoder with skip connections
h = self.up2(h)
h = self.dec3(torch.cat([h, h2], dim=1), t_emb)
h = self.up1(h)
h = self.dec2(torch.cat([h, h1], dim=1), t_emb)
h = self.dec1(h, t_emb)
return self.out(h)
Medical Imaging Considerations¶
For medical imaging (CT, MRI, pathology):
- 3D U-Net: Replace 2D convolutions with 3D for volumetric data
- Higher resolution: More downsampling stages for high-res pathology
- Domain-specific augmentation: Intensity normalization, organ-specific priors
- Conditional generation: Condition on patient metadata, modality, etc.
# 3D convolution for volumetric medical imaging
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
Architecture 2: Vision Transformer (DiT)¶
Why Vision Transformers?¶
The Diffusion Transformer (DiT) architecture shows that pure Transformers can match or exceed U-Net:
- Scalability: Performance scales predictably with model size
- Simplicity: Fewer inductive biases than CNNs
- Long-range attention: Global receptive field from the start
- Transfer learning: Leverage pretrained Vision Transformers
Architecture Overview¶
Input Image → Patchify → Linear Projection
↓
Positional Encoding
↓
┌──────────────────────────────┐
│ Transformer Block 1 │ ← Time/Class Conditioning
│ (Self-Attention + FFN) │ (via AdaLN)
└──────────────────────────────┘
↓
┌──────────────────────────────┐
│ Transformer Block 2 │ ← Time/Class Conditioning
└──────────────────────────────┘
↓
...
↓
┌──────────────────────────────┐
│ Transformer Block N │ ← Time/Class Conditioning
└──────────────────────────────┘
↓
Linear Projection
↓
Unpatchify → Score
Key Innovation: Adaptive Layer Normalization (AdaLN)¶
DiT conditions on time using Adaptive Layer Normalization:
class AdaLN(nn.Module):
"""Adaptive Layer Normalization for DiT."""
def __init__(self, hidden_dim, condition_dim):
super().__init__()
self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(condition_dim, 6 * hidden_dim) # scale, shift, gate for 2 norms
)
def forward(self, x, condition):
# Get modulation parameters
params = self.adaLN_modulation(condition)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params.chunk(6, dim=-1)
return shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
DiT Block¶
class DiTBlock(nn.Module):
"""Diffusion Transformer block with AdaLN conditioning."""
def __init__(self, hidden_dim, num_heads, mlp_ratio=4.0, condition_dim=256):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
self.norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
mlp_hidden = int(hidden_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(hidden_dim, mlp_hidden),
nn.GELU(),
nn.Linear(mlp_hidden, hidden_dim)
)
# AdaLN modulation
self.adaLN = nn.Sequential(
nn.SiLU(),
nn.Linear(condition_dim, 6 * hidden_dim)
)
def forward(self, x, condition):
# Get modulation parameters
shift1, scale1, gate1, shift2, scale2, gate2 = self.adaLN(condition).chunk(6, dim=-1)
# Attention with AdaLN
h = self.norm1(x)
h = h * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)
h, _ = self.attn(h, h, h)
x = x + gate1.unsqueeze(1) * h
# MLP with AdaLN
h = self.norm2(x)
h = h * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)
h = self.mlp(h)
x = x + gate2.unsqueeze(1) * h
return x
When to Use DiT vs U-Net¶
| Aspect | U-Net | DiT |
|---|---|---|
| Inductive bias | Strong (locality, hierarchy) | Weak (learns structure) |
| Data efficiency | Better with less data | Needs more data |
| Scaling | Saturates earlier | Scales well |
| Compute | More efficient for small models | More efficient at scale |
| Best for | Small-medium datasets | Large-scale training |
Architecture 3: Networks for Tabular/Biological Data¶
Gene Expression Data¶
Gene expression data (bulk RNA-seq, microarray) is: - High-dimensional: ~20,000 genes - Tabular: No spatial structure - Structured: Gene pathways, co-expression patterns
Architecture Options¶
Option 1: Deep MLP with Residual Connections
class GeneExpressionScoreNet(nn.Module):
"""Score network for gene expression data."""
def __init__(self, n_genes=20000, hidden_dim=2048, n_layers=6, time_dim=128):
super().__init__()
self.time_embed = SinusoidalTimeEmbedding(time_dim)
# Initial projection
self.input_proj = nn.Linear(n_genes, hidden_dim)
# Deep residual blocks
self.blocks = nn.ModuleList([
ResidualMLPBlock(hidden_dim, time_dim)
for _ in range(n_layers)
])
# Output
self.output_proj = nn.Sequential(
nn.LayerNorm(hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, n_genes)
)
def forward(self, x, t):
t_emb = self.time_embed(t)
h = self.input_proj(x)
for block in self.blocks:
h = block(h, t_emb)
return self.output_proj(h)
class ResidualMLPBlock(nn.Module):
def __init__(self, hidden_dim, time_dim):
super().__init__()
self.norm = nn.LayerNorm(hidden_dim)
self.mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 4),
nn.GELU(),
nn.Linear(hidden_dim * 4, hidden_dim)
)
self.time_proj = nn.Linear(time_dim, hidden_dim * 2) # FiLM
def forward(self, x, t_emb):
scale, shift = self.time_proj(t_emb).chunk(2, dim=-1)
h = self.norm(x)
h = h * (1 + scale) + shift
h = self.mlp(h)
return x + h
Option 2: Graph Neural Network (for pathway structure)
If gene-gene interaction networks are available:
class PathwayAwareScoreNet(nn.Module):
"""Score network using gene pathway structure."""
def __init__(self, n_genes, hidden_dim, adjacency_matrix, time_dim=128):
super().__init__()
# Register adjacency as buffer (not parameter)
self.register_buffer('adj', adjacency_matrix)
self.time_embed = SinusoidalTimeEmbedding(time_dim)
self.node_embed = nn.Linear(1, hidden_dim)
# Graph convolution layers
self.gcn_layers = nn.ModuleList([
GraphConvWithTime(hidden_dim, time_dim)
for _ in range(4)
])
self.output = nn.Linear(hidden_dim, 1)
def forward(self, x, t):
# x: (batch, n_genes)
t_emb = self.time_embed(t)
# Treat each gene as a node
h = self.node_embed(x.unsqueeze(-1)) # (batch, n_genes, hidden)
for gcn in self.gcn_layers:
h = gcn(h, self.adj, t_emb)
return self.output(h).squeeze(-1) # (batch, n_genes)
scRNA-seq Data¶
Single-cell RNA-seq data has additional challenges: - Sparse: Many zero counts (dropout) - Cell-type structure: Cells cluster by type - Batch effects: Technical variation across experiments
Architecture Considerations¶
class scRNAScoreNet(nn.Module):
"""Score network for scRNA-seq with sparsity handling."""
def __init__(self, n_genes, hidden_dim=512, time_dim=64):
super().__init__()
self.time_embed = SinusoidalTimeEmbedding(time_dim)
# Sparse-aware input processing
self.input_encoder = nn.Sequential(
nn.Linear(n_genes, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU()
)
# Attention over genes (learn which genes are important)
self.gene_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)
# Score prediction
self.decoder = nn.Sequential(
ResidualMLPBlock(hidden_dim, time_dim),
ResidualMLPBlock(hidden_dim, time_dim),
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, n_genes)
)
def forward(self, x, t):
t_emb = self.time_embed(t)
# Encode
h = self.input_encoder(x)
# Self-attention to learn gene relationships
h_attended, _ = self.gene_attention(h.unsqueeze(1), h.unsqueeze(1), h.unsqueeze(1))
h = h + h_attended.squeeze(1)
# Decode with time conditioning
return self.decoder(h, t_emb)
Choosing the Right Architecture¶
Decision Framework¶
What is your data type?
│
├── Images (2D)
│ │
│ ├── Small dataset (<50K) → U-Net
│ ├── Large dataset (>1M) → DiT
│ └── Medical imaging → 2D/3D U-Net with domain priors
│
├── Volumes (3D)
│ └── 3D U-Net
│
├── Tabular (gene expression, features)
│ │
│ ├── No structure → Deep MLP with residual connections
│ └── Known graph structure → GNN
│
└── Sequences
└── Transformer-based
Practical Recommendations¶
| Data Type | Recommended Architecture | Key Components |
|---|---|---|
| Natural images | U-Net or DiT | Time embedding, FiLM/AdaLN, attention |
| Medical imaging | 3D U-Net | Domain-specific normalization, high-res handling |
| Gene expression | Deep residual MLP | FiLM, layer normalization |
| scRNA-seq | MLP + attention | Sparsity handling, gene attention |
| Pathology slides | Hierarchical U-Net | Multi-scale, memory-efficient |
Implementation Tips¶
1. Start Simple, Then Scale¶
# Start with small model for debugging
model_small = UNet(base_channels=32)
# Scale up once working
model_large = UNet(base_channels=256)
2. Memory Management for Large Data¶
# Gradient checkpointing for U-Net
def forward_with_checkpointing(self, x, t):
t_emb = self.time_mlp(t)
# Use checkpointing for encoder blocks
h1 = torch.utils.checkpoint.checkpoint(self.enc1, x, t_emb)
h2 = torch.utils.checkpoint.checkpoint(self.enc2, self.down1(h1), t_emb)
# ... etc
3. Mixed Precision Training¶
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
with autocast():
score = model(x_t, t)
loss = score_matching_loss(score, noise, std)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Summary¶
| Architecture | Best For | Key Features |
|---|---|---|
| Simple MLP | 2D toy data, quick experiments | Concatenation or FiLM |
| U-Net | Images, medical imaging | Skip connections, multi-scale, FiLM |
| DiT | Large-scale image generation | Transformer, AdaLN, scalable |
| Deep MLP | Tabular/gene expression | Residual connections, FiLM |
| GNN | Data with graph structure | Pathway-aware, message passing |
The key insight: match the architecture's inductive biases to your data's structure.
Further Reading¶
- DDPM U-Net: Ho et al., "Denoising Diffusion Probabilistic Models" (2020)
- DiT: Peebles & Xie, "Scalable Diffusion Models with Transformers" (2023)
- Medical Diffusion: Kazerouni et al., "Diffusion Models in Medical Imaging" (2023)
- scRNA Diffusion: Various recent works on single-cell generation
Related Documents¶
- Score Network Basics — SiLU activation, SimpleScoreNetwork MLP, architecture comparison
- Time Embedding and FiLM — Detailed explanation of time conditioning components