β-VAE: Disentanglement and the Information Bottleneck¶
Building on the VAE foundation, β-VAE introduces a single hyperparameter that controls the trade-off between reconstruction quality and latent space structure.
1. Motivation: Why Modify the VAE?¶
Standard VAEs optimize the ELBO:

\[
\mathcal{L} = \mathbb{E}_{q(z|x)}[\log p(x|z)] - D_{\mathrm{KL}}\big(q(z|x) \,\|\, p(z)\big)
\]

This treats reconstruction and regularization equally. But what if we want:
- More structured latent space → increase KL weight
- Better reconstruction → decrease KL weight
β-VAE makes this explicit.
2. The β-VAE Objective¶
Simply multiply the KL term by β:

\[
\mathcal{L}_{\beta} = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \beta \, D_{\mathrm{KL}}\big(q(z|x) \,\|\, p(z)\big)
\]
| β value | Effect |
|---|---|
| β = 1 | Standard VAE |
| β > 1 | Stronger regularization → more disentanglement, worse reconstruction |
| β < 1 | Weaker regularization → better reconstruction, less structure |
3. Information Bottleneck Interpretation¶
β-VAE can be understood through the information bottleneck framework:

\[
\max \; I(z; y) - \beta \, I(z; x)
\]

Where:

- \(I(z; y)\) — mutual information between latent and target (reconstruction)
- \(I(z; x)\) — mutual information between latent and input (compression)

Higher β forces the model to:

1. Compress more aggressively
2. Keep only the most informative features
3. Discard nuisance variation
4. Disentanglement: What Does It Mean?¶
In a disentangled representation:
- Each latent dimension captures one independent factor of variation
- Changing one dimension changes one semantic attribute
- Dimensions are statistically independent
Example (Images)¶
| Dimension | Controls |
|---|---|
| \(z_1\) | Rotation |
| \(z_2\) | Scale |
| \(z_3\) | Color |
Example (Gene Expression)¶
| Dimension | Controls |
|---|---|
| \(z_1\) | Cell type identity |
| \(z_2\) | Cell cycle phase |
| \(z_3\) | Batch effect |
5. Why β > 1 Encourages Disentanglement¶
The average KL term can be decomposed:

\[
\mathbb{E}_{p(x)}\big[D_{\mathrm{KL}}\big(q(z|x) \,\|\, p(z)\big)\big] = I(z; x) + D_{\mathrm{KL}}\big(q(z) \,\|\, p(z)\big)
\]

Increasing β:

1. Reduces \(I(z; x)\) — forces compression
2. Pushes \(q(z)\) toward \(p(z) = \mathcal{N}(0, I)\) — encourages independence
The factorial prior \(p(z) = \prod_i p(z_i)\) induces statistical independence between dimensions.
6. The Reconstruction-Disentanglement Trade-off¶
This is the fundamental tension:
| β small | β large |
|---|---|
| Better reconstruction | Better disentanglement |
| Entangled latents | Worse reconstruction |
| Overfitting risk | Posterior collapse risk |
Posterior Collapse¶
When β is too high:

- \(q(z|x) \approx p(z)\) for all \(x\)
- Latent carries no information
- Decoder ignores \(z\), generates "average" output
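A practical way to diagnose collapse is to inspect the per-dimension KL of the encoder outputs: dimensions whose posterior matches the prior contribute no information. A minimal NumPy sketch (the batch values and threshold are illustrative):

```python
import numpy as np

def kl_per_dim(mu, logvar):
    """Mean KL(q(z|x) || N(0, I)) per latent dimension, averaged over the batch."""
    return -0.5 * np.mean(1 + logvar - mu**2 - np.exp(logvar), axis=0)

# Hypothetical encoder outputs: dim 0 carries signal, dim 1 is collapsed
mu = np.array([[1.2, 0.01], [-0.9, -0.02], [0.7, 0.00]])
logvar = np.array([[-1.0, 0.0], [-1.2, 0.0], [-0.8, 0.0]])

kl = kl_per_dim(mu, logvar)
# A collapsed dimension has KL ≈ 0 (its posterior equals the prior)
collapsed = kl < 0.01
```

Monitoring this vector during training shows collapse long before reconstruction quality reveals it.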
7. Disentanglement Metrics¶
DCI (Disentanglement, Completeness, Informativeness)¶
- Disentanglement: Does each code capture at most one factor?
- Completeness: Is each factor captured by at most one code?
- Informativeness: Can factors be predicted from codes?
MIG (Mutual Information Gap)¶
Measures the gap between the most and second-most informative latent for each factor.
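A minimal MIG sketch on discretized latents (the bin count is a free choice, and the function name is ours; assumes scikit-learn for discrete mutual information):

```python
import numpy as np
from sklearn.metrics import mutual_info_score

def mig(latents, factors, n_bins=20):
    """Mutual Information Gap: per factor, the gap between the two most
    informative latents' MI, normalized by the factor's entropy."""
    gaps = []
    for k in range(factors.shape[1]):
        f = factors[:, k]
        mis = sorted(
            (mutual_info_score(f, np.digitize(latents[:, j],
                np.histogram_bin_edges(latents[:, j], n_bins)))
             for j in range(latents.shape[1])),
            reverse=True)
        entropy = mutual_info_score(f, f)  # H(f) for a discrete factor
        gaps.append((mis[0] - mis[1]) / entropy)
    return float(np.mean(gaps))
```

A score near 1 means each factor is captured by a single latent; near 0 means the information is spread across several.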
SAP (Separated Attribute Predictability)¶
Trains classifiers to predict factors from individual latents.
8. Implementation¶
Loss Function¶
```python
import torch
import torch.nn.functional as F

def beta_vae_loss(x, x_recon, mu, logvar, beta=4.0):
    """β-VAE loss with configurable β."""
    # Reconstruction term (negative log-likelihood; Gaussian decoder → MSE)
    recon_loss = F.mse_loss(x_recon, x, reduction='sum')
    # KL divergence between q(z|x) = N(mu, diag(exp(logvar))) and N(0, I)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # β-weighted negative ELBO (to be minimized)
    return recon_loss + beta * kl_loss
```
Annealing Strategies¶
Rather than fixed β, gradually increase it:
```python
def get_beta(epoch, warmup_epochs=10, target_beta=4.0):
    """Linear β annealing: ramp from 0 to target_beta over warmup_epochs."""
    if epoch < warmup_epochs:
        return target_beta * epoch / warmup_epochs
    return target_beta
```
This helps avoid posterior collapse early in training.
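Cyclical schedules (Fu et al., 2019) are a common alternative: β is repeatedly ramped up and reset, giving the decoder periodic chances to re-engage the latent. A sketch with illustrative parameter names:

```python
def get_beta_cyclical(epoch, cycle_len=20, target_beta=4.0, ramp_frac=0.5):
    """Cyclical β annealing: within each cycle, ramp β from 0 to
    target_beta over the first ramp_frac of the cycle, then hold."""
    pos = (epoch % cycle_len) / cycle_len
    if pos < ramp_frac:
        return target_beta * pos / ramp_frac
    return target_beta
```

Each restart temporarily lowers the KL pressure, which empirically helps keep the latent informative.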
9. Variants and Extensions¶
β-TCVAE (Total Correlation VAE)¶
Decomposes the average KL into three terms:

\[
\mathbb{E}_{p(x)}\big[D_{\mathrm{KL}}\big(q(z|x) \,\|\, p(z)\big)\big]
= \underbrace{I(z; x)}_{\text{index-code MI}}
+ \underbrace{D_{\mathrm{KL}}\big(q(z) \,\|\, \textstyle\prod_i q(z_i)\big)}_{\text{total correlation}}
+ \underbrace{\textstyle\sum_i D_{\mathrm{KL}}\big(q(z_i) \,\|\, p(z_i)\big)}_{\text{dimension-wise KL}}
\]

Only the total correlation term receives the extra β penalty; it directly measures dependence between dimensions.
Factor VAE¶
Adds an adversarially estimated penalty to encourage a factorial \(q(z)\):

\[
\mathcal{L}_{\text{FactorVAE}} = \mathcal{L}_{\text{VAE}} - \gamma \, D_{\mathrm{KL}}\big(q(z) \,\|\, \bar{q}(z)\big)
\]

Where \(\bar{q}(z) = \prod_i q(z_i)\) is the factorial approximation; the KL term is estimated with a discriminator via the density-ratio trick.
DIP-VAE (Disentangled Inferred Prior)¶
Matches moments of \(q(z)\) to the prior by penalizing deviations of the aggregate-posterior covariance from the identity:

\[
\lambda_{od} \sum_{i \neq j} \big[\mathrm{Cov}_{q(z)}[z]\big]_{ij}^2 + \lambda_d \sum_i \big(\big[\mathrm{Cov}_{q(z)}[z]\big]_{ii} - 1\big)^2
\]
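The DIP-VAE-I variant computes this covariance from the encoder means. A NumPy sketch (the λ weights are illustrative, not the paper's tuned values):

```python
import numpy as np

def dip_penalty(mu, lambda_od=10.0, lambda_d=5.0):
    """DIP-VAE-I style penalty: push off-diagonal covariance entries of the
    encoder means to 0 and diagonal entries to 1."""
    cov = np.cov(mu, rowvar=False)
    diag = np.diag(cov)
    off = cov - np.diag(diag)
    return lambda_od * np.sum(off ** 2) + lambda_d * np.sum((diag - 1.0) ** 2)
```

Correlated latent dimensions inflate the off-diagonal entries, so the penalty grows exactly when disentanglement degrades.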
10. Application to Gene Expression¶
Why Disentanglement Matters¶
In single-cell biology, we want latents that separate:

- Biological signal (cell type, state)
- Technical noise (batch, sequencing depth)

A disentangled model enables:

1. Batch correction: zero out batch dimensions
2. Counterfactuals: change the disease dimension, keep cell type
3. Interpretability: each dimension has biological meaning
Practical Considerations¶
| Challenge | Solution |
|---|---|
| No ground truth factors | Use known covariates (batch, donor) as proxies |
| High dimensionality | Start with PCA-reduced input |
| Sparse data | Use negative binomial likelihood |
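For the sparse-data row, the reconstruction term can swap MSE for a negative binomial negative log-likelihood. A sketch in the mean/inverse-dispersion parameterisation common in scRNA-seq modeling (function name is ours; assumes SciPy):

```python
import numpy as np
from scipy.special import gammaln

def nb_nll(x, mu, theta):
    """Negative binomial NLL for counts x with mean mu and
    inverse dispersion theta (theta → ∞ recovers Poisson)."""
    log_p = (gammaln(x + theta) - gammaln(theta) - gammaln(x + 1)
             + theta * np.log(theta / (theta + mu))
             + x * np.log(mu / (theta + mu)))
    return -np.sum(log_p)
```

In a β-VAE this term simply replaces `recon_loss`, with `mu` (and optionally `theta`) produced by the decoder.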
11. Experiments to Run¶
Experiment 1: β Sweep¶
```python
betas = [0.1, 0.5, 1.0, 2.0, 4.0, 10.0]
for beta in betas:
    model = BetaVAE(beta=beta)
    train(model)
    evaluate_reconstruction(model)
    evaluate_disentanglement(model)
```
Experiment 2: Latent Traversal¶
For each dimension \(i\):

1. Encode a sample: \(\mu, \sigma = \text{encode}(x)\)
2. Vary \(z_i\) from \(-3\) to \(+3\)
3. Decode and visualize
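The traversal loop can be sketched with a stand-in decoder (the identity `decode` below is purely illustrative; in practice it is the trained decoder network):

```python
import numpy as np

def traverse(decode, mu, dim, values=None):
    """Sweep one latent dimension across a range, decoding each point."""
    values = np.linspace(-3, 3, 7) if values is None else values
    outs = []
    for v in values:
        z = mu.copy()
        z[dim] = v
        outs.append(decode(z))
    return np.stack(outs)

# Hypothetical identity decoder, just to show the mechanics
decode = lambda z: np.eye(3) @ z
grid = traverse(decode, mu=np.zeros(3), dim=1)
```

If the model is disentangled, each traversal should change exactly one semantic attribute of the output.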
Experiment 3: Condition Prediction¶
Train linear classifiers to predict tissue/disease from individual latent dimensions.
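A sketch of the per-dimension probe on synthetic latents (the labels and the choice of which dimension carries signal are made up; assumes scikit-learn):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
# Hypothetical latents: dim 0 encodes a binary condition, dims 1-2 are noise
labels = rng.integers(0, 2, size=200)
latents = rng.standard_normal((200, 3))
latents[:, 0] += 3.0 * labels

scores = []
for j in range(latents.shape[1]):
    clf = LogisticRegression().fit(latents[:, [j]], labels)
    scores.append(clf.score(latents[:, [j]], labels))
# The signal dimension scores near 1.0; noise dimensions near chance (0.5)
```

A sharp accuracy gap between one dimension and the rest is evidence that the condition is captured in a single, separable latent.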
12. Connection to Diffusion¶
β-VAE's insight—that compression induces structure—reappears in diffusion:
- Diffusion adds noise (compression) then learns to denoise
- The noise schedule is analogous to β annealing
- Both trade reconstruction for latent regularity
This is why understanding β-VAE deeply prepares you for diffusion.
13. References¶
- Higgins et al. (2017) — "β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework"
- Chen et al. (2018) — "Isolating Sources of Disentanglement in VAEs" (β-TCVAE)
- Kim & Mnih (2018) — "Disentangling by Factorising" (Factor VAE)
- Kumar et al. (2018) — "Variational Inference of Disentangled Latent Concepts from Unlabeled Observations" (DIP-VAE)
- Locatello et al. (2019) — "Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations"
Next Steps¶
After β-VAE:

1. IWAE — tighter bounds without changing the objective structure
2. Score matching — the bridge to diffusion models
See ROADMAP.md for the full learning path.