Latent Diffusion Applications: Computational Biology¶
This document covers applications of latent diffusion models in computational biology, including single-cell generation, perturbation prediction, multi-omics translation, trajectory modeling, and spatial transcriptomics.
Prerequisites: Understanding of latent diffusion foundations and training.
Overview¶
Latent diffusion models are particularly well-suited for computational biology because:
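- Efficiency: 10-100× faster than diffusion in the original data space (gene space here, the analogue of pixel space for images)
- Quality: Sharper samples than a plain VAE, more stable training than a GAN
- Flexibility: Multi-modal, multi-task, controllable
- Interpretability: Latent dimensions can align with biological factors such as cell type or state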
1. Single-Cell Generation¶
1.1 Task Definition¶
Goal: Generate realistic single-cell gene expression profiles.
Applications:
- Data augmentation for rare cell types
- Synthetic controls for experiments
- Batch effect removal
- Counterfactual cell states
1.2 Architecture¶
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# GeneExpressionVAE and LatentDiffusionModel are assumed to be defined
# elsewhere (see Related Documents).


class SingleCellLatentDiffusion(nn.Module):
    """
    Latent diffusion for single-cell generation.

    Args:
        num_genes: Number of genes
        latent_dim: Latent dimension
        num_cell_types: Number of cell types (for conditioning)
    """

    def __init__(
        self,
        num_genes=5000,
        latent_dim=256,
        num_cell_types=20,
    ):
        super().__init__()
        # VAE with ZINB decoder
        self.vae = GeneExpressionVAE(
            num_genes=num_genes,
            latent_dim=latent_dim,
            decoder_type='zinb',
        )
        # Latent diffusion with cell type conditioning
        self.diffusion = LatentDiffusionModel(
            latent_dim=latent_dim,
            model_type='dit',
        )
        # Cell type embedding
        self.cell_type_embed = nn.Embedding(num_cell_types, 128)

    def forward(self, x, cell_type, library_size=None):
        """
        Training forward pass.

        Args:
            x: Gene expression (B, num_genes)
            cell_type: Cell type indices (B,)
            library_size: Library size (B,)

        Returns:
            vae_loss, diffusion_loss
        """
        # VAE loss
        recon_params, mu, logvar = self.vae(x, library_size)
        vae_loss, _ = self.vae.loss(x, recon_params, mu, logvar, library_size)

        # Diffuse the posterior mean. detach() keeps the diffusion loss from
        # back-propagating into the encoder; wrapping `z0 = mu` in
        # torch.no_grad() would not, because mu already carries a graph.
        z0 = mu.detach()

        # Cell type conditioning
        cell_type_emb = self.cell_type_embed(cell_type)

        # Diffusion loss (noise prediction)
        t = torch.randint(0, self.diffusion.num_steps, (z0.shape[0],), device=z0.device)
        zt, noise = self.diffusion.add_noise(z0, t)
        noise_pred = self.diffusion(zt, t, cell_type_emb)
        diffusion_loss = F.mse_loss(noise_pred, noise)
        return vae_loss, diffusion_loss

    @torch.no_grad()
    def generate(self, cell_type, num_samples=100, library_size=None):
        """
        Generate single-cell profiles.

        Args:
            cell_type: Cell type index (int) or tensor of indices (num_samples,)
            num_samples: Number of samples
            library_size: Optional library size

        Returns:
            x_gen: Generated expression (num_samples, num_genes)
        """
        device = next(self.parameters()).device

        # Cell type conditioning
        if isinstance(cell_type, int):
            cell_type = torch.full((num_samples,), cell_type, dtype=torch.long, device=device)
        else:
            cell_type = cell_type.to(device)
        cell_type_emb = self.cell_type_embed(cell_type)

        # Sample latent from diffusion
        z0 = self.diffusion.sample(num_samples, cell_type_emb)

        # Decode to gene expression
        x_gen = self.vae.decode(z0, library_size)
        return x_gen
```
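The `add_noise` call in the forward pass above is not defined in this document. As a rough illustration of what such a method typically computes, here is a minimal sketch assuming a standard DDPM forward process with a linear beta schedule. The function names, schedule, and step count are illustrative assumptions, not the `LatentDiffusionModel` API.

```python
import torch


def make_alpha_bars(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products of (1 - beta_t) for a linear beta schedule (assumed)."""
    betas = torch.linspace(beta_start, beta_end, num_steps)
    return torch.cumprod(1.0 - betas, dim=0)


def add_noise(z0, t, alpha_bars):
    """Sample z_t ~ N(sqrt(abar_t) * z0, (1 - abar_t) * I) and return (z_t, noise)."""
    noise = torch.randn_like(z0)
    abar = alpha_bars.to(z0.device)[t].unsqueeze(-1)  # (B, 1), broadcasts over latent dims
    zt = abar.sqrt() * z0 + (1.0 - abar).sqrt() * noise
    return zt, noise


torch.manual_seed(0)
alpha_bars = make_alpha_bars()
z0 = torch.randn(4, 256)                 # stand-in for encoder means
t = torch.randint(0, 1000, (4,))         # one timestep per cell
zt, noise = add_noise(z0, t, alpha_bars)
```

At small `t` the result stays close to `z0`; at the last step almost all signal has been replaced by noise, which is what lets sampling start from a standard normal draw.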
1.3 Training Strategy¶
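```python
import torch


def train_single_cell_generation(
    model,
    train_loader,
    val_loader,
    num_epochs=100,
    lr=1e-4,
    device='cuda',
):
    """
    Train single-cell generation model.

    Args:
        model: SingleCellLatentDiffusion
        train_loader: Training data (x, cell_type, library_size)
        val_loader: Validation data, same batch format
        num_epochs: Number of epochs
        lr: Learning rate
        device: Device
    """
    model.to(device)

    # Separate optimizers for the VAE and for the diffusion model + embedding
    optimizer_vae = torch.optim.AdamW(model.vae.parameters(), lr=lr)
    optimizer_diffusion = torch.optim.AdamW(
        list(model.diffusion.parameters()) + list(model.cell_type_embed.parameters()),
        lr=lr,
    )

    def run_batch(batch):
        x, cell_type, library_size = (b.to(device) for b in batch)
        return model(x, cell_type, library_size)

    for epoch in range(num_epochs):
        model.train()
        train_vae, train_diff, n_train = 0.0, 0.0, 0
        for batch in train_loader:
            vae_loss, diffusion_loss = run_batch(batch)

            # Backpropagate both losses before either optimizer steps, so no
            # parameter is updated in place while its graph is still needed.
            optimizer_vae.zero_grad()
            optimizer_diffusion.zero_grad()
            (vae_loss + diffusion_loss).backward()
            optimizer_vae.step()
            optimizer_diffusion.step()

            train_vae += vae_loss.item()
            train_diff += diffusion_loss.item()
            n_train += 1

        # Validation pass (no gradient tracking)
        model.eval()
        val_vae, val_diff, n_val = 0.0, 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                vae_loss, diffusion_loss = run_batch(batch)
                val_vae += vae_loss.item()
                val_diff += diffusion_loss.item()
                n_val += 1

        print(
            f"Epoch {epoch+1}: "
            f"train VAE={train_vae / max(n_train, 1):.4f}, Diff={train_diff / max(n_train, 1):.4f} | "
            f"val VAE={val_vae / max(n_val, 1):.4f}, Diff={val_diff / max(n_val, 1):.4f}"
        )
```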
1.4 Evaluation¶
```python
import numpy as np
import torch
from scipy.stats import wasserstein_distance


@torch.no_grad()
def evaluate_single_cell_generation(model, test_data, cell_types, device='cuda'):
    """
    Evaluate single-cell generation quality.

    Args:
        model: Trained model
        test_data: Real test data (N, num_genes)
        cell_types: Cell type labels (N,)
        device: Device

    Returns:
        metrics: Dict of per-cell-type evaluation metrics
    """
    model.eval()
    model.to(device)

    # Compare on CPU so real and synthetic tensors share a device
    test_data = test_data.float().cpu()
    cell_types = cell_types.cpu()

    metrics = {}
    for ct in torch.unique(cell_types):
        ct = int(ct)

        # Real data for this cell type
        real_data = test_data[cell_types == ct]

        # Generate the same number of synthetic cells
        num_samples = len(real_data)
        synthetic_data = model.generate(ct, num_samples).float().cpu()

        # 1. Mean expression correlation
        real_mean = real_data.mean(dim=0)
        synth_mean = synthetic_data.mean(dim=0)
        corr = torch.corrcoef(torch.stack([real_mean, synth_mean]))[0, 1]

        # 2. Wasserstein distance (per gene, then averaged)
        real_np = real_data.numpy()
        synth_np = synthetic_data.numpy()
        w_dists = [
            wasserstein_distance(real_np[:, g], synth_np[:, g])
            for g in range(real_np.shape[1])
        ]

        # 3. Sparsity similarity (fraction of near-zero entries)
        real_sparsity = (real_data < 0.1).float().mean()
        synth_sparsity = (synthetic_data < 0.1).float().mean()

        metrics[f'cell_type_{ct}'] = {
            'mean_correlation': corr.item(),
            'wasserstein_distance': float(np.mean(w_dists)),
            'real_sparsity': real_sparsity.item(),
            'synth_sparsity': synth_sparsity.item(),
        }
    return metrics
```
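Because the evaluation above generates exactly as many synthetic cells as there are real ones, the per-gene loop can be avoided: for two equal-sized 1-D samples, the Wasserstein-1 distance is the mean absolute difference between their sorted values. A vectorized sketch follows; the helper name `per_gene_wasserstein` is hypothetical.

```python
import torch


def per_gene_wasserstein(real, synth):
    """Per-gene 1-D Wasserstein-1 distance for equal-sized samples.

    real, synth: (N, num_genes) tensors with the same N.
    Returns a (num_genes,) tensor.
    """
    if real.shape != synth.shape:
        raise ValueError("this shortcut requires equal sample counts")
    real_sorted, _ = torch.sort(real, dim=0)
    synth_sorted, _ = torch.sort(synth, dim=0)
    return (real_sorted - synth_sorted).abs().mean(dim=0)


real = torch.tensor([[0.0, 1.0], [2.0, 1.0], [4.0, 1.0]])
synth = torch.tensor([[5.0, 1.0], [1.0, 1.0], [3.0, 1.0]])
w = per_gene_wasserstein(real, synth)   # gene 0 is shifted by 1, gene 1 is identical
```

With unequal sample counts the shortcut no longer applies and `scipy.stats.wasserstein_distance` remains the safer choice.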
2. Perturbation Prediction¶
2.1 Task Definition¶
Goal: Predict cellular response to genetic/chemical perturbations.
Applications:
- Virtual screening (predict without experiment)
- Combination prediction (multiple perturbations)
- Mechanism discovery (analyze latent changes)
2.2 Architecture¶
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PerturbationLatentDiffusion(nn.Module):
    """
    Latent diffusion for perturbation prediction.
    Predicts perturbed state from baseline + perturbation.

    Args:
        num_genes: Number of genes
        latent_dim: Latent dimension
        perturbation_dim: Perturbation embedding dimension
    """

    def __init__(
        self,
        num_genes=5000,
        latent_dim=256,
        perturbation_dim=128,
    ):
        super().__init__()
        # VAE (assumed pretrained; it receives no loss in forward())
        self.vae = GeneExpressionVAE(
            num_genes=num_genes,
            latent_dim=latent_dim,
            decoder_type='zinb',
        )
        # Perturbation encoder
        self.perturbation_encoder = nn.Sequential(
            nn.Linear(num_genes, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Linear(512, perturbation_dim),
        )
        # Latent diffusion conditioned on baseline + perturbation
        self.diffusion = LatentDiffusionModel(
            latent_dim=latent_dim,
            model_type='dit',
        )
        # Conditioning projection
        self.condition_proj = nn.Linear(latent_dim + perturbation_dim, latent_dim)

    def _condition(self, z_baseline, perturbation_indicator):
        """Build the conditioning vector from baseline latent + perturbation."""
        pert_emb = self.perturbation_encoder(perturbation_indicator)
        return self.condition_proj(torch.cat([z_baseline, pert_emb], dim=-1))

    def forward(self, x_baseline, x_perturbed, perturbation_indicator):
        """
        Training forward pass.

        Args:
            x_baseline: Baseline expression (B, num_genes)
            x_perturbed: Perturbed expression (B, num_genes)
            perturbation_indicator: Multi-hot perturbed genes (B, num_genes)

        Returns:
            diffusion_loss
        """
        # Encode baseline and perturbed cells with the frozen VAE, so the
        # diffusion loss cannot reshape the latent space it is trained on.
        with torch.no_grad():
            z_baseline = self.vae.encode(x_baseline)
            z_perturbed = self.vae.encode(x_perturbed)

        condition = self._condition(z_baseline, perturbation_indicator)

        # Diffusion: predict the noise added to the perturbed latent
        t = torch.randint(
            0, self.diffusion.num_steps, (z_perturbed.shape[0],), device=z_perturbed.device
        )
        zt, noise = self.diffusion.add_noise(z_perturbed, t)
        noise_pred = self.diffusion(zt, t, condition)
        diffusion_loss = F.mse_loss(noise_pred, noise)
        return diffusion_loss

    @torch.no_grad()
    def predict_perturbation(self, x_baseline, perturbation_indicator, library_size=None):
        """
        Predict perturbed state.

        Args:
            x_baseline: Baseline expression (B, num_genes)
            perturbation_indicator: Multi-hot perturbed genes (B, num_genes)
            library_size: Library size (B,)

        Returns:
            x_perturbed_pred: Predicted perturbed expression (B, num_genes)
        """
        z_baseline = self.vae.encode(x_baseline)
        condition = self._condition(z_baseline, perturbation_indicator)

        # Sample perturbed latent, then decode
        z_perturbed = self.diffusion.sample(len(x_baseline), condition)
        x_perturbed_pred = self.vae.decode(z_perturbed, library_size)
        return x_perturbed_pred
```
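The `perturbation_indicator` input above has one column per gene, so a combination of knockouts is simply a row with several ones. A minimal sketch of building it from gene names follows; the helper name and the example gene list are hypothetical. This encoding covers genetic perturbations only; chemical perturbations would need a separate representation that this document does not specify.

```python
import torch


def make_perturbation_indicator(perturbed_genes, gene_names):
    """Build a (B, num_genes) multi-hot matrix from per-cell lists of perturbed genes."""
    gene_to_idx = {g: i for i, g in enumerate(gene_names)}
    indicator = torch.zeros(len(perturbed_genes), len(gene_names))
    for row, genes in enumerate(perturbed_genes):
        for g in genes:
            indicator[row, gene_to_idx[g]] = 1.0
    return indicator


gene_names = ["GATA1", "TP53", "MYC", "KLF1"]
batch = [["TP53"], ["GATA1", "KLF1"], []]   # single, combination, unperturbed control
indicator = make_perturbation_indicator(batch, gene_names)
```

An all-zero row represents an unperturbed control cell, which gives the model a natural reference point for the conditioning.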
2.3 Delta-in-Latent Formulation¶
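Alternative approach: predict the perturbation-induced change (delta) in latent space with a deterministic regressor, instead of sampling the perturbed latent with diffusion.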
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeltaLatentPerturbation(nn.Module):
    """
    Predict delta in latent space.
    Often more stable than predicting the absolute perturbed state, but
    deterministic: it returns a point estimate, not a distribution.
    """

    def __init__(
        self,
        num_genes=5000,
        latent_dim=256,
        perturbation_dim=128,
    ):
        super().__init__()
        # VAE (assumed pretrained; kept frozen during delta training)
        self.vae = GeneExpressionVAE(num_genes, latent_dim, 'zinb')
        self.perturbation_encoder = nn.Sequential(
            nn.Linear(num_genes, 512),
            nn.GELU(),
            nn.Linear(512, perturbation_dim),
        )
        # Predict delta
        self.delta_predictor = nn.Sequential(
            nn.Linear(latent_dim + perturbation_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Linear(512, latent_dim),
        )

    def _predict_delta(self, z_baseline, perturbation_indicator):
        pert_emb = self.perturbation_encoder(perturbation_indicator)
        return self.delta_predictor(torch.cat([z_baseline, pert_emb], dim=-1))

    def forward(self, x_baseline, x_perturbed, perturbation_indicator):
        """
        Training: predict delta in latent space.
        """
        # Encode without gradients. If the encoder were trained on this loss,
        # it could shrink all latents together and make the delta trivially 0.
        with torch.no_grad():
            z_baseline = self.vae.encode(x_baseline)
            z_perturbed = self.vae.encode(x_perturbed)

        # True delta
        delta_true = z_perturbed - z_baseline

        # Predicted delta
        delta_pred = self._predict_delta(z_baseline, perturbation_indicator)

        # MSE loss on delta
        loss = F.mse_loss(delta_pred, delta_true)
        return loss

    @torch.no_grad()
    def predict_perturbation(self, x_baseline, perturbation_indicator, library_size=None):
        """
        Predict perturbed state via delta.
        """
        z_baseline = self.vae.encode(x_baseline)
        delta_pred = self._predict_delta(z_baseline, perturbation_indicator)

        # Add delta to baseline, then decode
        z_perturbed = z_baseline + delta_pred
        x_perturbed_pred = self.vae.decode(z_perturbed, library_size)
        return x_perturbed_pred
```
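In symbols, the diffusion and delta formulations differ in what is regressed. Write $z_b$ and $z_p$ for the encoded baseline and perturbed cells and $p$ for the perturbation embedding. Assuming the usual DDPM noising $z_t = \sqrt{\bar\alpha_t}\, z_p + \sqrt{1 - \bar\alpha_t}\,\epsilon$ (the schedule $\bar\alpha_t$ is not specified in this document), the two objectives can be sketched as follows, up to the averaging constant in `F.mse_loss`:

$$
\mathcal{L}_{\text{diffusion}} = \mathbb{E}_{t,\epsilon}\,\big\| \epsilon - \epsilon_\theta(z_t, t, c) \big\|^2,
\qquad c = \mathrm{proj}\big([z_b;\, p]\big)
$$

$$
\mathcal{L}_{\Delta} = \big\| f_\theta\big([z_b;\, p]\big) - (z_p - z_b) \big\|^2,
\qquad \hat{z}_p = z_b + f_\theta\big([z_b;\, p]\big)
$$

The diffusion variant yields a distribution over $z_p$ through repeated sampling, whereas the delta variant returns a single point estimate $\hat{z}_p$.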
2.4 Evaluation¶
```python
import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import pearsonr


@torch.no_grad()
def evaluate_perturbation_prediction(
    model,
    test_data,
    perturbations,
    device='cuda',
):
    """
    Evaluate perturbation prediction.

    Args:
        model: Trained model
        test_data: Iterable of (baseline, perturbed, perturbation_indicator) batches
        perturbations: List of perturbation names (not used by the metrics below)
        device: Device

    Returns:
        metrics: Evaluation metrics
    """
    model.eval()
    model.to(device)

    all_correlations = []
    all_mse = []
    for x_baseline, x_perturbed_true, pert_indicator in test_data:
        x_baseline = x_baseline.to(device)
        x_perturbed_true = x_perturbed_true.to(device)
        pert_indicator = pert_indicator.to(device)

        # Predict
        x_perturbed_pred = model.predict_perturbation(x_baseline, pert_indicator)

        # Pearson correlation across genes, one value per cell
        true_np = x_perturbed_true.cpu().numpy()
        pred_np = x_perturbed_pred.cpu().numpy()
        for i in range(len(true_np)):
            corr, _ = pearsonr(true_np[i], pred_np[i])
            all_correlations.append(corr)

        # MSE (one value per batch)
        mse = F.mse_loss(x_perturbed_pred, x_perturbed_true)
        all_mse.append(mse.item())

    # nan-aware reductions: pearsonr returns nan for constant profiles
    metrics = {
        'mean_correlation': float(np.nanmean(all_correlations)),
        'median_correlation': float(np.nanmedian(all_correlations)),
        'mean_mse': float(np.mean(all_mse)),
    }
    return metrics
```
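Correlation on raw expression, as computed above, can look high even when the predicted response is wrong, because most genes barely change and the baseline profile dominates. A common complement is to correlate the predicted change with the true change. The sketch below uses a hypothetical helper `delta_pearson` on NumPy arrays and toy numbers chosen only to make the effect visible.

```python
import numpy as np


def delta_pearson(x_baseline, x_true, x_pred):
    """Per-cell Pearson correlation between true and predicted expression changes."""
    d_true = x_true - x_baseline
    d_pred = x_pred - x_baseline
    d_true = d_true - d_true.mean(axis=1, keepdims=True)
    d_pred = d_pred - d_pred.mean(axis=1, keepdims=True)
    num = (d_true * d_pred).sum(axis=1)
    den = np.sqrt((d_true ** 2).sum(axis=1) * (d_pred ** 2).sum(axis=1))
    return num / np.maximum(den, 1e-12)   # guard against zero-variance deltas


baseline = np.array([[10.0, 0.0, 5.0, 1.0]])
true = np.array([[12.0, 0.0, 3.0, 1.0]])    # change: +2, 0, -2, 0
good = np.array([[11.0, 0.0, 4.0, 1.0]])    # right direction, half the size
bad = np.array([[9.0, 0.0, 6.0, 1.0]])      # wrong direction

r_good = delta_pearson(baseline, true, good)
r_bad = delta_pearson(baseline, true, bad)
raw_bad = np.corrcoef(true[0], bad[0])[0, 1]   # raw-expression correlation for comparison
```

In this toy case the wrong-direction prediction still has a strongly positive raw correlation with the truth, while its delta correlation is negative.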
3. Multi-Omics Translation¶
3.1 Task Definition¶
Goal: Predict one modality from another (e.g., protein from RNA).
Applications:
- Fill missing modalities
- Cross-modality validation
- Integrated multi-omics analysis
3.2 Architecture¶
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiOmicsLatentDiffusion(nn.Module):
    """
    Latent diffusion for multi-omics translation.
    Shared latent space for RNA and protein.

    Args:
        num_genes_rna: Number of RNA genes
        num_genes_protein: Number of measured proteins
        latent_dim: Shared latent dimension
        align_weight: Weight of the latent alignment loss
    """

    def __init__(
        self,
        num_genes_rna=5000,
        num_genes_protein=100,
        latent_dim=256,
        align_weight=10.0,
    ):
        super().__init__()
        self.align_weight = align_weight
        # RNA encoder
        self.rna_encoder = GeneExpressionEncoder(
            num_genes=num_genes_rna,
            latent_dim=latent_dim,
        )
        # Protein encoder
        self.protein_encoder = GeneExpressionEncoder(
            num_genes=num_genes_protein,
            latent_dim=latent_dim,
        )
        # RNA decoder (ZINB)
        self.rna_decoder = ZINBDecoder(
            latent_dim=latent_dim,
            num_genes=num_genes_rna,
        )
        # Protein decoder (NB, less sparse)
        self.protein_decoder = NegativeBinomialDecoder(
            latent_dim=latent_dim,
            num_genes=num_genes_protein,
        )
        # Latent diffusion over the shared latent. Note that forward() and the
        # translate_* methods below do not use it; they are purely VAE-based.
        self.diffusion = LatentDiffusionModel(
            latent_dim=latent_dim,
            model_type='dit',
        )

    @staticmethod
    def _kl(mu, logvar):
        """KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, batch mean."""
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1).mean()

    def forward(self, x_rna, x_protein):
        """
        Training: align RNA and protein in latent space.

        Args:
            x_rna: RNA expression (B, num_genes_rna)
            x_protein: Protein expression (B, num_genes_protein)

        Returns:
            loss_dict: Dictionary of losses
        """
        # Encode both modalities
        mu_rna, logvar_rna = self.rna_encoder(x_rna)
        mu_protein, logvar_protein = self.protein_encoder(x_protein)
        z_rna = self.rna_encoder.reparameterize(mu_rna, logvar_rna)
        z_protein = self.protein_encoder.reparameterize(mu_protein, logvar_protein)

        # Reconstruction losses (negative log-likelihood)
        rna_recon = self.rna_decoder(z_rna)
        protein_recon = self.protein_decoder(z_protein)
        loss_rna_recon = -self.rna_decoder.log_prob(x_rna, *rna_recon).mean()
        loss_protein_recon = -self.protein_decoder.log_prob(x_protein, *protein_recon).mean()

        # KL losses
        kl_rna = self._kl(mu_rna, logvar_rna)
        kl_protein = self._kl(mu_protein, logvar_protein)

        # Alignment loss (paired cells should map to similar latents)
        loss_align = F.mse_loss(z_rna, z_protein)

        # Total loss
        loss = (
            loss_rna_recon + loss_protein_recon
            + kl_rna + kl_protein
            + self.align_weight * loss_align
        )
        return {
            'loss': loss,
            'rna_recon': loss_rna_recon.item(),
            'protein_recon': loss_protein_recon.item(),
            'kl_rna': kl_rna.item(),
            'kl_protein': kl_protein.item(),
            'align': loss_align.item(),
        }

    @torch.no_grad()
    def translate_rna_to_protein(self, x_rna):
        """
        Translate RNA to protein.

        Args:
            x_rna: RNA expression (B, num_genes_rna)

        Returns:
            x_protein_pred: Predicted protein (B, num_genes_protein)
        """
        # Encode RNA to latent (posterior mean)
        mu_rna, _ = self.rna_encoder(x_rna)
        # Decode to protein
        protein_params = self.protein_decoder(mu_rna)
        x_protein_pred = protein_params[0]  # NB mean
        return x_protein_pred

    @torch.no_grad()
    def translate_protein_to_rna(self, x_protein):
        """
        Translate protein to RNA.

        Args:
            x_protein: Protein expression (B, num_genes_protein)

        Returns:
            x_rna_pred: Predicted RNA (B, num_genes_rna)
        """
        # Encode protein to latent (posterior mean)
        mu_protein, _ = self.protein_encoder(x_protein)
        # Decode to RNA
        rna_params = self.rna_decoder(mu_protein)
        mean, _, dropout = rna_params
        # ZINB expectation, assuming `dropout` is a probability (not a logit)
        x_rna_pred = (1 - dropout) * mean
        return x_rna_pred
```
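The KL terms in the training loss above use the closed form for a diagonal Gaussian against a standard normal prior. As a sanity check, the sketch below compares that expression with `torch.distributions.kl_divergence` on random inputs; the function name `kl_to_standard_normal` is hypothetical.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over dims, batch mean."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1).mean()


torch.manual_seed(0)
mu = torch.randn(8, 16)
logvar = 0.5 * torch.randn(8, 16)

closed_form = kl_to_standard_normal(mu, logvar)

# Reference value from torch.distributions; std = exp(0.5 * logvar)
posterior = Normal(mu, (0.5 * logvar).exp())
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
reference = kl_divergence(posterior, prior).sum(dim=-1).mean()
```

Both values agree up to floating-point error, which confirms that `logvar` is interpreted as the log of the variance rather than of the standard deviation.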
4. Trajectory Modeling¶
4.1 Task Definition¶
Goal: Model developmental or disease trajectories over time.
Applications:
- Predict future cell states
- Identify branch points
- Model differentiation
4.2 Architecture¶
```python
import torch
import torch.nn as nn


class TrajectoryLatentDiffusion(nn.Module):
    """
    Latent diffusion for trajectory modeling.
    Predicts future states conditioned on time.

    Args:
        num_genes: Number of genes
        latent_dim: Latent dimension
    """

    def __init__(
        self,
        num_genes=5000,
        latent_dim=256,
    ):
        super().__init__()
        self.vae = GeneExpressionVAE(num_genes, latent_dim, 'zinb')
        # Time encoder (SinusoidalPositionEmbeddings is assumed to be defined elsewhere)
        self.time_encoder = nn.Sequential(
            SinusoidalPositionEmbeddings(128),
            nn.Linear(128, 128),
            nn.GELU(),
        )
        # Latent diffusion conditioned on current state + time
        self.diffusion = LatentDiffusionModel(latent_dim, 'dit')

    @torch.no_grad()
    def predict_trajectory(self, x_start, time_points, library_size=None):
        """
        Predict trajectory from starting state.

        Args:
            x_start: Starting expression (B, num_genes)
            time_points: Time points to predict (T,)
            library_size: Library size (B,)

        Returns:
            trajectory: Predicted trajectory (B, T, num_genes)
        """
        device = x_start.device
        batch_size = len(x_start)

        # Encode starting state
        z_current = self.vae.encode(x_start)

        trajectory = []
        for t in time_points:
            # Time embedding; float(t) accepts Python numbers and 0-dim tensors
            t_tensor = torch.full((batch_size,), float(t), device=device)
            t_emb = self.time_encoder(t_tensor)

            # Condition on current state + time: (B, latent_dim + 128).
            # The diffusion model must be configured for this conditioning width.
            condition = torch.cat([z_current, t_emb], dim=-1)

            # Sample next state
            z_next = self.diffusion.sample(batch_size, condition)

            # Decode
            x_next = self.vae.decode(z_next, library_size)
            trajectory.append(x_next)

            # Autoregressive update: the sampled state seeds the next step
            z_current = z_next

        trajectory = torch.stack(trajectory, dim=1)  # (B, T, num_genes)
        return trajectory
```
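The time encoder above relies on `SinusoidalPositionEmbeddings`, which this document does not define. One plausible definition, assuming the standard transformer-style sine/cosine embedding with base 10000 (the frequency base and the sin-then-cos layout are assumptions), is sketched here:

```python
import math

import torch
import torch.nn as nn


class SinusoidalPositionEmbeddings(nn.Module):
    """Map scalar times (B,) to fixed sinusoidal features (B, dim)."""

    def __init__(self, dim):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError("dim must be even")
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Geometrically spaced frequencies from 1 down to roughly 1/10000
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / half)
        args = t.float().unsqueeze(-1) * freqs           # (B, half)
        return torch.cat([args.sin(), args.cos()], dim=-1)


embed = SinusoidalPositionEmbeddings(128)
emb = embed(torch.tensor([0.0, 1.5, 7.0]))   # three time points
```

Because the module has no parameters, the following `nn.Linear` and `nn.GELU` layers in the time encoder are what adapt these fixed features to the task.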
5. Spatial Transcriptomics¶
5.1 Task Definition¶
Goal: Generate spatial gene expression patterns.
Applications:
- Super-resolution (increase spatial resolution)
- Missing region imputation
- 3D reconstruction
5.2 Architecture¶
```python
import torch
import torch.nn as nn
from scipy.spatial import cKDTree


class SpatialLatentDiffusion(nn.Module):
    """
    Latent diffusion for spatial transcriptomics.
    Conditioned on spatial coordinates.

    Args:
        num_genes: Number of genes
        latent_dim: Latent dimension
        spatial_dim: Spatial coordinate dimension (2 or 3)
    """

    def __init__(
        self,
        num_genes=5000,
        latent_dim=256,
        spatial_dim=2,
    ):
        super().__init__()
        self.vae = GeneExpressionVAE(num_genes, latent_dim, 'zinb')
        # Spatial encoder
        self.spatial_encoder = nn.Sequential(
            nn.Linear(spatial_dim, 128),
            nn.GELU(),
            nn.Linear(128, 128),
            nn.GELU(),
        )
        self.diffusion = LatentDiffusionModel(latent_dim, 'dit')

    @torch.no_grad()
    def generate_at_location(self, coordinates, library_size=None):
        """
        Generate expression at spatial locations.

        Args:
            coordinates: Spatial coordinates (B, spatial_dim)
            library_size: Library size (B,)

        Returns:
            x_gen: Generated expression (B, num_genes)
        """
        # Encode spatial coordinates
        spatial_emb = self.spatial_encoder(coordinates)
        # Sample latent conditioned on location
        z = self.diffusion.sample(len(coordinates), spatial_emb)
        # Decode
        x_gen = self.vae.decode(z, library_size)
        return x_gen

    @torch.no_grad()
    def super_resolution(self, x_low_res, coords_low_res, coords_high_res, library_size=None):
        """
        Super-resolution: predict high-res from low-res.

        This is a deterministic latent interpolation; the diffusion model is
        not used here.

        Args:
            x_low_res: Low-resolution expression (N_low, num_genes), N_low >= 3
            coords_low_res: Low-res coordinates (N_low, spatial_dim)
            coords_high_res: High-res coordinates (N_high, spatial_dim)
            library_size: Optional library size (N_high,)

        Returns:
            x_high_res: High-resolution expression (N_high, num_genes)
        """
        # Encode low-res to latent
        z_low_res = self.vae.encode(x_low_res)

        # Find the 3 nearest low-res spots for every high-res location
        tree = cKDTree(coords_low_res.cpu().numpy())
        distances, indices = tree.query(coords_high_res.cpu().numpy(), k=3)

        # Inverse-distance weights, normalized per high-res location
        weights = 1.0 / (distances + 1e-8)
        weights = weights / weights.sum(axis=1, keepdims=True)
        weights = torch.as_tensor(weights, dtype=z_low_res.dtype, device=z_low_res.device)
        indices = torch.as_tensor(indices, dtype=torch.long, device=z_low_res.device)

        # Weighted average of neighbor latents: (N_high, 3, D) -> (N_high, D)
        z_high_res = (weights.unsqueeze(-1) * z_low_res[indices]).sum(dim=1)

        # Decode
        x_high_res = self.vae.decode(z_high_res, library_size)
        return x_high_res
```
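The inverse-distance interpolation used for super-resolution above can also stay entirely on the GPU. The sketch below replaces the SciPy k-d tree with `torch.cdist` and `torch.topk`; the function name `idw_interpolate` is hypothetical, and the dense distance matrix costs O(N_high × N_low) memory, so this suits moderate spot counts rather than very large slides.

```python
import torch


def idw_interpolate(z_low, coords_low, coords_high, k=3, eps=1e-8):
    """Inverse-distance-weighted average of the k nearest low-res latents."""
    k = min(k, coords_low.shape[0])
    dists = torch.cdist(coords_high, coords_low)                  # (N_high, N_low)
    nn_dists, nn_idx = torch.topk(dists, k, dim=1, largest=False)  # k smallest distances
    weights = 1.0 / (nn_dists + eps)
    weights = weights / weights.sum(dim=1, keepdim=True)
    return (weights.unsqueeze(-1) * z_low[nn_idx]).sum(dim=1)     # (N_high, D)


coords_low = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
z_low = torch.tensor([[0.0], [1.0], [2.0], [3.0]])    # 1-D latents for readability
coords_high = torch.tensor([[0.0, 0.0], [0.5, 0.5]])  # an existing spot and the grid centre
z_high = idw_interpolate(z_low, coords_low, coords_high, k=4)
```

A query that coincides with a measured spot recovers that spot's latent almost exactly, while the grid centre receives the plain average of its four equidistant neighbors.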
Comparison with Existing Methods¶
Single-Cell Generation¶
| Method | Approach | Pros | Cons |
|---|---|---|---|
| scGAN | GAN | Fast sampling | Mode collapse, unstable |
| scVI | VAE | Fast, interpretable | Over-smoothed samples |
| Latent Diffusion | VAE + Diffusion | Sharp, diverse | Slower sampling |
Perturbation Prediction¶
| Method | Approach | Pros | Cons |
|---|---|---|---|
| scGen | VAE + arithmetic | Simple, fast | Linear assumption |
| CPA | Autoencoder + perturbation | Flexible | No uncertainty |
| Latent Diffusion | VAE + Diffusion | Uncertainty, flexible | More complex |
Multi-Omics¶
| Method | Approach | Pros | Cons |
|---|---|---|---|
| MOFA | Factor analysis | Interpretable | Linear |
| totalVI | Joint VAE | Unified framework | Fixed architecture |
| Latent Diffusion | Joint VAE + Diffusion | Flexible, high-quality | Training complexity |
Key Takeaways¶
Applications¶
- Single-cell generation — Data augmentation, rare cell types
- Perturbation prediction — Virtual screening, combinations
- Multi-omics translation — Fill missing modalities
- Trajectory modeling — Predict future states
- Spatial transcriptomics — Super-resolution, imputation
Advantages¶
- Efficiency: 10-100× faster than diffusion in the original data space
- Quality: Sharper samples than a plain VAE, more stable training than a GAN
- Flexibility: Multi-modal, multi-task
- Uncertainty: Probabilistic predictions via repeated sampling
Best Practices¶
- Start with VAE — Get good latent space first
- Condition carefully — Use appropriate conditioning mechanism
- Validate biologically — Not just loss metrics
- Compare baselines — scGen, CPA, totalVI
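The first practice, getting a good latent space before training diffusion, amounts to a two-stage schedule: train the VAE, freeze it, then fit the denoiser on its latents. The sketch below shows only the freezing mechanics with toy stand-ins. The `nn.Linear` encoder, the MLP denoiser, and the single fixed noise level are placeholders, not the models from this document.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def freeze(module):
    """Disable gradients and switch to eval mode (stage 1 is finished)."""
    for p in module.parameters():
        p.requires_grad_(False)
    return module.eval()


torch.manual_seed(0)
encoder = freeze(nn.Linear(50, 8))    # stand-in for a trained VAE encoder
denoiser = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))
optimizer = torch.optim.AdamW(denoiser.parameters(), lr=1e-4)

# Stage 2: one denoiser update on latents from the frozen encoder
x = torch.randn(16, 50)
with torch.no_grad():
    z0 = encoder(x)
noise = torch.randn_like(z0)
abar = 0.5                                            # one fixed noise level, for brevity
zt = abar ** 0.5 * z0 + (1 - abar) ** 0.5 * noise
loss = F.mse_loss(denoiser(zt), noise)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Only the denoiser receives gradients, so the latent space the diffusion model learns on stays fixed throughout stage 2.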
Related Documents¶
- 00_latent_diffusion_overview.md — High-level concepts
- 01_latent_diffusion_foundations.md — Architecture details
- 02_latent_diffusion_training.md — Training strategies
- 04_latent_diffusion_combio.md — Complete implementation
References¶
Single-Cell Generation:
- Marouf et al. (2020): "Realistic in silico generation and augmentation of single-cell RNA-seq data using generative adversarial networks" (scGAN)
- Lopez et al. (2018): "Deep generative modeling for single-cell transcriptomics" (scVI)
Perturbation Prediction:
- Lotfollahi et al. (2019): "scGen predicts single-cell perturbation responses"
- Lotfollahi et al. (2023): "Predicting cellular responses to complex perturbations in high-throughput screens" (CPA)
Multi-Omics:
- Argelaguet et al. (2018): "Multi-Omics Factor Analysis" (MOFA)
- Gayoso et al. (2021): "Joint probabilistic modeling of single-cell multi-omic data with totalVI"
Spatial Transcriptomics:
- Cable et al. (2022): "Robust decomposition of cell type mixtures in spatial transcriptomics"
- Biancalani et al. (2021): "Deep learning and alignment of spatially resolved single-cell transcriptomes with Tangram"