Diffusion Models for Gene Expression: DDPM Tutorial¶
Goal: Implement a denoising diffusion probabilistic model (DDPM) for generating gene expression profiles.
This notebook demonstrates:
- Core DDPM mechanics (forward/reverse diffusion)
- Training a time-conditional score network on gene expression data
- Conditional generation (cell type → gene expression)
- Foundation for drug-response prediction (scPPDM approach)
Dataset: PBMC 3k (small subset for fast iteration)
Next steps: Extend to perturbation response (baseline + drug → perturbed expression)
Prerequisites¶
Environment: Make sure you're in the genailab conda environment:
mamba activate genailab
Required packages: torch, scanpy, numpy, matplotlib, tqdm
Setup¶
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import scanpy as sc
from tqdm.auto import tqdm
# Set random seeds for reproducible runs (note: full determinism on GPU would
# also require torch.use_deterministic_algorithms / cuDNN flags)
np.random.seed(42)
torch.manual_seed(42)
# Device preference: CUDA > Apple-silicon MPS > CPU
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")
1. Load and Prepare Gene Expression Data¶
We'll use a small subset of PBMC 3k for fast iteration. For production, you'd use the full dataset.
# Load PBMC 3k data from disk if present, otherwise download it
data_path = Path("../data/pbmc3k_raw.h5ad")
if data_path.exists():
    adata = sc.read_h5ad(data_path)
    print(f"Loaded data: {adata.shape}")
else:
    # Download if not available
    adata = sc.datasets.pbmc3k()
    print(f"Downloaded PBMC 3k: {adata.shape}")
# Basic QC filtering: drop near-empty cells and genes seen in <3 cells
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)
# Normalize and log-transform for the diffusion model.
# Note: for count-based models (like an NB VAE) we'd keep raw counts;
# diffusion works on continuous data, so log-normalized values are used here.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# Select highly variable genes for faster training.
# BUGFIX: flavor='seurat_v3' is defined on *raw counts*, but at this point
# adata.X is already log-normalized. The 'seurat' flavor (scanpy's default)
# is the variant defined on log-transformed data, so use it instead.
sc.pp.highly_variable_genes(adata, n_top_genes=500, flavor='seurat')
adata = adata[:, adata.var.highly_variable].copy()
print(f"After preprocessing: {adata.shape}")
print(f"Gene expression range: [{adata.X.min():.2f}, {adata.X.max():.2f}]")
# Cluster cells to obtain labels for conditional generation
sc.pp.neighbors(adata, n_neighbors=10)
sc.tl.leiden(adata, resolution=0.5)
# Store cell type labels
cell_types = adata.obs['leiden'].values
n_cell_types = len(np.unique(cell_types))
print(f"Found {n_cell_types} cell type clusters")
print(adata.obs['leiden'].value_counts())
2. Create PyTorch Dataset¶
We'll create a dataset that returns:
- Gene expression vector (x)
- Cell type label (condition)
class GeneExpressionDataset(Dataset):
    """Gene expression profiles with an optional categorical condition.

    Each item is a FloatTensor of shape [n_genes], or a
    ``(expression, condition_index)`` pair when ``condition_key`` is given.
    """

    def __init__(self, adata, condition_key=None):
        """
        Args:
            adata: AnnData object with preprocessed gene expression
            condition_key: Key in adata.obs for conditioning (e.g., 'leiden', 'treatment')
        """
        # Densify sparse matrices so indexing returns plain tensors
        matrix = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X
        self.X = torch.FloatTensor(matrix)

        if condition_key is None:
            # Unconditional dataset: __getitem__ yields expression only
            self.conditions = None
            self.n_conditions = 0
        else:
            labels = adata.obs[condition_key].values
            # Map each distinct label to a stable integer id (sorted order)
            categories = np.unique(labels)
            lookup = {label: idx for idx, label in enumerate(categories)}
            self.conditions = torch.LongTensor([lookup[label] for label in labels])
            self.n_conditions = len(categories)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        if self.conditions is None:
            return self.X[idx]
        return self.X[idx], self.conditions[idx]
# Build the dataset conditioned on Leiden cluster labels and a shuffled loader
dataset = GeneExpressionDataset(adata, condition_key='leiden')
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
print(f"Dataset size: {len(dataset)}")
print(f"Gene dimension: {dataset.X.shape[1]}")
print(f"Number of conditions: {dataset.n_conditions}")
class NoiseScheduler:
    """Linear noise schedule for DDPM.

    Precomputes every per-timestep quantity needed for the forward process
    q(x_t | x_0) and for the reverse-process posterior variance. The schedule
    buffers live on CPU; ``add_noise`` gathers coefficients on the buffers'
    device and moves them to the data's device, so it works with CUDA/MPS
    batches during training.
    """

    def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        """
        Args:
            num_timesteps: Number of diffusion steps (T)
            beta_start: Starting noise variance
            beta_end: Ending noise variance
        """
        self.num_timesteps = num_timesteps
        # Linear schedule for beta
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        # Precompute useful quantities
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        # For sampling
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        # Variance of the posterior q(x_{t-1} | x_t, x_0); entry 0 is 0 and is
        # never used, since sampling adds no noise at t=0
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    def add_noise(self, x_0, t, noise=None):
        """
        Forward diffusion: q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)

        Args:
            x_0: Original data [batch_size, dim] (any device)
            t: Timestep [batch_size] (any device)
            noise: Optional noise to add (for reproducibility)

        Returns:
            x_t: Noisy data at timestep t
            noise: The noise that was added
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        # BUGFIX: the schedule buffers are CPU tensors, so indexing them with a
        # CUDA/MPS `t` (as the training loop does) raised a device-mismatch
        # error. Gather on the buffers' device, then move to x_0's device.
        t_idx = t.to(self.sqrt_alphas_cumprod.device)
        sqrt_alpha_prod = self.sqrt_alphas_cumprod[t_idx].reshape(-1, 1).to(x_0.device)
        sqrt_one_minus_alpha_prod = (
            self.sqrt_one_minus_alphas_cumprod[t_idx].reshape(-1, 1).to(x_0.device)
        )
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        x_t = sqrt_alpha_prod * x_0 + sqrt_one_minus_alpha_prod * noise
        return x_t, noise
# Instantiate the scheduler used for the rest of the notebook
scheduler = NoiseScheduler(num_timesteps=1000)
# Visualize the schedule: beta_t, the cumulative product alpha_bar_t, and the
# coefficient sqrt(1 - alpha_bar_t) that scales the injected Gaussian noise
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
axes[0].plot(scheduler.betas.numpy())
axes[0].set_title('Beta Schedule')
axes[0].set_xlabel('Timestep')
axes[0].set_ylabel('Beta')
axes[1].plot(scheduler.alphas_cumprod.numpy())
axes[1].set_title('Cumulative Alpha')
axes[1].set_xlabel('Timestep')
axes[1].set_ylabel('Alpha_bar')
axes[2].plot(scheduler.sqrt_one_minus_alphas_cumprod.numpy())
axes[2].set_title('Noise Coefficient')
axes[2].set_xlabel('Timestep')
axes[2].set_ylabel('sqrt(1 - alpha_bar)')
plt.tight_layout()
plt.show()
3.2 Visualize Forward Diffusion Process¶
Let's see how a gene expression vector gets progressively noisier.
# Take a single cell's expression vector as the running example
x_0 = dataset.X[0:1] # Shape: [1, n_genes]
# Noise it at a spread of timesteps across the schedule
timesteps = [0, 100, 250, 500, 750, 999]
noisy_samples = []
for t in timesteps:
    t_tensor = torch.tensor([t])
    x_t, _ = scheduler.add_noise(x_0, t_tensor)
    noisy_samples.append(x_t[0].numpy())
# Histogram of per-gene values at each timestep; by t=999 the distribution
# should be close to standard normal (mean ~0, std ~1)
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()
for i, (t, x_t) in enumerate(zip(timesteps, noisy_samples)):
    axes[i].hist(x_t, bins=50, alpha=0.7)
    axes[i].set_title(f'Timestep t={t}')
    axes[i].set_xlabel('Expression value')
    axes[i].set_ylabel('Frequency')
    axes[i].axvline(x=0, color='r', linestyle='--', alpha=0.5)
plt.suptitle('Forward Diffusion: Gene Expression → Noise', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()
print(f"Original data mean: {x_0.mean():.3f}, std: {x_0.std():.3f}")
print(f"Final noise mean: {noisy_samples[-1].mean():.3f}, std: {noisy_samples[-1].std():.3f}")
3.3 Time-Conditional Score Network¶
The core of DDPM: a neural network that predicts the noise $\epsilon_\theta(x_t, t, c)$ given:
- Noisy data $x_t$
- Timestep $t$
- Condition $c$ (e.g., cell type, drug)
For gene expression (tabular data), we use an MLP with:
- Sinusoidal time embeddings
- Conditional embeddings
- Residual connections
class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal timestep embeddings.

    Maps a batch of scalar timesteps to vectors of size ``dim``: the first
    half holds sines, the second half cosines, over a geometric range of
    frequencies with base 10000 (assumes ``dim`` is even and > 2).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # output embedding width

    def forward(self, time):
        half = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000
        scale = np.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=time.device) * -scale)
        angles = time[:, None] * freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
class MLPBlock(nn.Module):
    """Residual MLP block, normalized after the skip connection (post-norm).

    The feed-forward path expands to 4x the width (GELU + dropout) and
    projects back; the output is LayerNorm(x + FFN(x)).
    """

    def __init__(self, dim, dropout=0.1):
        super().__init__()
        expanded = dim * 4
        self.net = nn.Sequential(
            nn.Linear(dim, expanded),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded, dim),
            nn.Dropout(dropout),
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        residual = self.net(x)
        return self.norm(x + residual)
class ConditionalScoreNetwork(nn.Module):
    """Noise-prediction network eps_theta(x_t, t, c) for tabular gene expression.

    The noisy input is concatenated with a learned timestep embedding and,
    when the model is conditional, a learned class embedding; the result is
    processed by a stack of residual MLP blocks and projected back to the
    gene dimension.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim=256,
        time_dim=64,
        n_conditions=0,
        condition_dim=32,
        n_layers=4,
        dropout=0.1,
    ):
        """
        Args:
            input_dim: Gene expression dimension
            hidden_dim: Hidden layer dimension
            time_dim: Time embedding dimension
            n_conditions: Number of condition classes (0 for unconditional)
            condition_dim: Condition embedding dimension
            n_layers: Number of MLP blocks
            dropout: Dropout rate
        """
        super().__init__()
        self.input_dim = input_dim
        self.n_conditions = n_conditions

        # Timestep -> sinusoidal features -> small MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim * 2),
            nn.GELU(),
            nn.Linear(time_dim * 2, time_dim),
        )

        # Learned class embedding, only when the model is conditional
        if n_conditions > 0:
            self.condition_embed = nn.Embedding(n_conditions, condition_dim)
        else:
            self.condition_embed = None
        concat_dim = input_dim + time_dim + (condition_dim if n_conditions > 0 else 0)

        # Project the concatenated features into the hidden width
        self.input_proj = nn.Linear(concat_dim, hidden_dim)
        self.blocks = nn.ModuleList(
            MLPBlock(hidden_dim, dropout) for _ in range(n_layers)
        )
        # Head maps back to the gene dimension (the predicted noise)
        self.output_proj = nn.Linear(hidden_dim, input_dim)

    def forward(self, x, t, condition=None):
        """
        Args:
            x: Noisy gene expression [batch_size, input_dim]
            t: Timestep [batch_size]
            condition: Condition labels [batch_size] (optional)

        Returns:
            Predicted noise [batch_size, input_dim]
        """
        features = [x, self.time_mlp(t)]
        # Falls back to unconditional behavior when labels are absent
        if self.condition_embed is not None and condition is not None:
            features.append(self.condition_embed(condition))
        h = self.input_proj(torch.cat(features, dim=-1))
        for block in self.blocks:
            h = block(h)
        return self.output_proj(h)
# Smoke-test the network at the dataset's dimensionality
model = ConditionalScoreNetwork(
    input_dim=dataset.X.shape[1],
    hidden_dim=256,
    n_conditions=dataset.n_conditions,
    n_layers=4,
).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# Forward pass on one real batch with random timesteps; the output must match
# the input shape, since the network predicts per-gene noise
x_test, c_test = next(iter(dataloader))
x_test, c_test = x_test.to(device), c_test.to(device)
t_test = torch.randint(0, 1000, (x_test.shape[0],), device=device)
noise_pred = model(x_test, t_test, c_test)
print(f"Input shape: {x_test.shape}")
print(f"Output shape: {noise_pred.shape}")
4. Training Loop¶
DDPM training is simple:
- Sample a batch of data $x_0$
- Sample random timesteps $t$
- Add noise: $x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon$
- Predict noise: $\epsilon_\theta(x_t, t, c)$
- Compute MSE loss: $\|\epsilon - \epsilon_\theta(x_t, t, c)\|^2$
def train_ddpm(model, dataloader, scheduler, num_epochs=100, lr=1e-4):
    """Train a DDPM with the simple noise-regression (MSE) objective.

    Each step samples x_0 and a uniform timestep t, forms
    x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps, and minimizes
    ||eps - eps_theta(x_t, t, c)||^2.

    Args:
        model: Score network mapping (x_t, t, condition) -> predicted noise.
        dataloader: Yields plain tensors x_0 or (x_0, condition) pairs.
        scheduler: NoiseScheduler providing num_timesteps and add_noise().
        num_epochs: Number of passes over the data.
        lr: AdamW learning rate.

    Returns:
        List of per-epoch average losses.
    """
    # Derive the device from the model instead of relying on a notebook global
    device = next(model.parameters()).device
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    losses = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
            # BUGFIX: the old check `len(batch) == 2` misclassified an
            # unconditional tensor batch whenever its batch size happened to
            # be 2 (len() of a tensor is its first dimension). Check the
            # container type produced by the default collate instead.
            if isinstance(batch, (list, tuple)):
                x_0, condition = batch
                x_0 = x_0.to(device)
                condition = condition.to(device)
            else:
                x_0 = batch.to(device)
                condition = None
            batch_size = x_0.shape[0]
            # Sample one uniform timestep per example
            t = torch.randint(0, scheduler.num_timesteps, (batch_size,), device=device)
            # Forward-diffuse to x_t with known noise
            noise = torch.randn_like(x_0)
            x_t, _ = scheduler.add_noise(x_0, t, noise)
            # Predict the noise and regress against the true noise
            noise_pred = model(x_t, t, condition)
            loss = F.mse_loss(noise_pred, noise)
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients for training stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            epoch_loss += loss.item()
        avg_loss = epoch_loss / len(dataloader)
        losses.append(avg_loss)
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    return losses
# Train the model (start with fewer epochs for testing)
losses = train_ddpm(
    model=model,
    dataloader=dataloader,
    scheduler=scheduler,
    num_epochs=50, # Increase to 200-500 for better results
    lr=1e-4,
)
# Plot the per-epoch average training loss
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('DDPM Training Loss')
plt.grid(True, alpha=0.3)
plt.show()
5. Sampling (Reverse Diffusion)¶
Generate new gene expression profiles by:
- Start with pure noise $x_T \sim \mathcal{N}(0, I)$
- Iteratively denoise: $x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t, c) \right) + \sigma_t z$
@torch.no_grad()
def sample_ddpm(model, scheduler, n_samples, condition=None, device='cpu'):
    """Generate samples by ancestral (reverse-diffusion) DDPM sampling.

    Starting from x_T ~ N(0, I), each step computes
    x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t/sqrt(1 - alpha_bar_t) * eps_theta)
              + sigma_t * z,
    where sigma_t^2 is the posterior variance and z is fresh Gaussian noise
    (no noise is added at the final step t=0).

    Args:
        model: Trained score network
        scheduler: Noise scheduler
        n_samples: Number of samples to generate
        condition: Condition labels [n_samples] (optional)
        device: Device to run on

    Returns:
        Generated samples [n_samples, input_dim]
    """
    model.eval()
    # Initialize from the standard-normal prior
    x = torch.randn(n_samples, model.input_dim, device=device)
    if condition is not None:
        condition = condition.to(device)

    steps = reversed(range(scheduler.num_timesteps))
    for t in tqdm(steps, desc="Sampling", total=scheduler.num_timesteps):
        t_batch = torch.full((n_samples,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch, condition)

        alpha_t = scheduler.alphas[t]
        alpha_bar_t = scheduler.alphas_cumprod[t]
        beta_t = scheduler.betas[t]

        # Posterior mean of x_{t-1} given x_t and the predicted noise
        mean = (1 / torch.sqrt(alpha_t)) * (
            x - (beta_t / torch.sqrt(1 - alpha_bar_t)) * eps
        )

        if t == 0:
            # Final denoising step is deterministic
            x = mean
        else:
            noise = torch.randn_like(x)
            sigma_t = torch.sqrt(scheduler.posterior_variance[t])
            x = mean + sigma_t * noise
    return x
# Generate a fixed-size batch of synthetic cells for every Leiden cluster
n_samples_per_type = 50
generated_samples = []
generated_labels = []
for cell_type_idx in range(dataset.n_conditions):
    # One constant label per generated batch
    condition = torch.full((n_samples_per_type,), cell_type_idx, dtype=torch.long)
    samples = sample_ddpm(model, scheduler, n_samples_per_type, condition, device=device)
    generated_samples.append(samples.cpu())
    generated_labels.extend([cell_type_idx] * n_samples_per_type)
# Stack all batches into one array for downstream evaluation
generated_samples = torch.cat(generated_samples, dim=0).numpy()
generated_labels = np.array(generated_labels)
print(f"Generated {generated_samples.shape[0]} samples")
print(f"Sample shape: {generated_samples.shape}")
6. Evaluation¶
Compare generated vs real gene expression distributions.
# Compare real vs generated distributions along several axes
real_data = dataset.X.numpy()
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()
# Pooled distribution over all cells and genes
axes[0].hist(real_data.flatten(), bins=50, alpha=0.5, label='Real', density=True)
axes[0].hist(generated_samples.flatten(), bins=50, alpha=0.5, label='Generated', density=True)
axes[0].set_title('Overall Distribution')
axes[0].legend()
# Per-gene mean: points near the diagonal indicate the model matches first moments
axes[1].scatter(real_data.mean(axis=0), generated_samples.mean(axis=0), alpha=0.3)
axes[1].plot([real_data.mean(axis=0).min(), real_data.mean(axis=0).max()],
             [real_data.mean(axis=0).min(), real_data.mean(axis=0).max()],
             'r--', alpha=0.5)
axes[1].set_xlabel('Real mean expression')
axes[1].set_ylabel('Generated mean expression')
axes[1].set_title('Mean Expression per Gene')
# Per-gene standard deviation (second moments)
axes[2].scatter(real_data.std(axis=0), generated_samples.std(axis=0), alpha=0.3)
axes[2].plot([real_data.std(axis=0).min(), real_data.std(axis=0).max()],
             [real_data.std(axis=0).min(), real_data.std(axis=0).max()],
             'r--', alpha=0.5)
axes[2].set_xlabel('Real std expression')
axes[2].set_ylabel('Generated std expression')
axes[2].set_title('Std Expression per Gene')
# Marginal distributions for a few individual genes
for i, gene_idx in enumerate([0, 10, 50]):
    axes[3 + i].hist(real_data[:, gene_idx], bins=30, alpha=0.5, label='Real', density=True)
    axes[3 + i].hist(generated_samples[:, gene_idx], bins=30, alpha=0.5, label='Generated', density=True)
    axes[3 + i].set_title(f'Gene {gene_idx}')
    axes[3 + i].legend()
plt.tight_layout()
plt.show()
7. Next Steps: Extending to Drug-Response Prediction¶
To implement the scPPDM approach for perturbation response:
Architecture Changes:¶
- Input: Concatenate baseline expression + drug embedding
- Output: Predict perturbed expression (not noise)
- Conditioning: Drug type + dose
Data Requirements:¶
- Paired samples: (baseline, drug, dose, perturbed_expression)
- Examples: Sci-Plex, LINCS L1000, Replogle et al. Perturb-seq
Modified Forward Process:¶
# Instead of: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise
# Use: x_t = sqrt(alpha_bar) * x_perturbed + sqrt(1 - alpha_bar) * noise
# Condition on: [x_baseline, drug_embedding, dose]
Training:¶
# Sketch (pseudocode): predict perturbed expression from baseline + drug.
# Here x_t is the noised *perturbed* expression produced by the forward
# diffusion step, passed in alongside the timestep t.
def forward(x_baseline, drug, dose, t, x_t):
    # Encode drug
    drug_emb = drug_encoder(drug, dose)
    # Concatenate baseline + drug info
    condition = torch.cat([x_baseline, drug_emb], dim=-1)
    # Predict noise for perturbed expression
    noise_pred = score_network(x_t, t, condition)
    return noise_pred
Sampling:¶
# Generate counterfactual response (illustrative — requires extending
# sample_ddpm to accept a dict-valued condition; note it also takes n_samples)
x_perturbed = sample_ddpm(
    model,
    scheduler,
    n_samples=x_baseline.shape[0],
    condition={'baseline': x_baseline, 'drug': drug_id, 'dose': dose_value},
)
Summary¶
We've implemented:
- ✅ Noise scheduler (linear beta schedule)
- ✅ Forward diffusion (adding noise)
- ✅ Time-conditional score network (MLP for tabular data)
- ✅ Training loop (simple MSE loss)
- ✅ Sampling (reverse diffusion)
- ✅ Conditional generation (cell type → expression)
Next notebook: Implement full scPPDM for drug-response prediction with perturbation datasets.