Architecture Q&A: Gene Expression Diffusion Models¶
This document addresses common questions about architectural choices for gene expression data in diffusion models, particularly around the latent token approach.
Related: 02a_diffusion_arch_gene_expression.md — Main architecture document
Question 1: Handling Thousands of Samples¶
Question: The latent token architecture shows encoding gene expression to tokens. But how does this account for gene expression data where each gene is represented by thousands of samples with expression levels?
Answer: Sample-by-Sample Processing¶
The architecture processes gene expression sample-by-sample, not all samples at once. Let me clarify the dimensions:
What the Architecture Actually Does¶
# ONE sample (single cell or bulk RNA-seq measurement)
x = gene_expression # Shape: (num_genes,) = (20000,)
# This is ONE measurement: [gene1_count, gene2_count, ..., gene20000_count]
# Encode to tokens
z = encoder(x) # Shape: (num_tokens, token_dim) = (64, 256)
# Creates 64 "semantic tokens" from the 20K gene values
# In practice, we batch multiple samples
x_batch = gene_expressions # Shape: (batch_size, num_genes) = (32, 20000)
z_batch = encoder(x_batch) # Shape: (32, 64, 256)
# 32 samples, each encoded to 64 tokens
Data Flow Clarification¶
DATASET STRUCTURE:
──────────────────
You have: N samples × 20,000 genes
- Sample 1: [gene1=5, gene2=120, ..., gene20000=3]
- Sample 2: [gene1=8, gene2=95, ..., gene20000=7]
- ...
- Sample N: [gene1=12, gene2=150, ..., gene20000=2]
TRAINING:
─────────
Each training iteration:
1. Sample a batch (e.g., 32 samples)
2. Each sample processed independently through encoder
Sample 1 → Encoder → 64 tokens of dim 256
Sample 2 → Encoder → 64 tokens of dim 256
...
Sample 32 → Encoder → 64 tokens of dim 256
3. Batch shape: (32, 64, 256)
- 32 samples (batch dimension)
- 64 tokens per sample (sequence dimension)
- 256 features per token (feature dimension)
4. Transformer processes each sample's token sequence
5. Decoder: tokens → gene expression prediction
The Key Insight¶
The encoder is NOT trying to encode all samples into one representation. Instead:
- Input: One gene expression profile (20K genes for one cell/sample)
- Output: A compressed representation as 64 tokens (each 256-dim)
- Batching: Process multiple samples in parallel (standard minibatch training)
Think of it like processing images:
# Images
images = (batch=32, channels=3, height=224, width=224)
# Each image processed independently
# Gene expression
gene_expr = (batch=32, num_genes=20000)
# Each sample processed independently → (batch=32, num_tokens=64, token_dim=256)
Complete Training Example¶
# ═══════════════════════════════════════════════════════════
# GENE EXPRESSION DATA STRUCTURE
# ═══════════════════════════════════════════════════════════
# Dataset: Collection of samples
dataset = {
'sample_1': [gene1=5, gene2=120, ..., gene20000=3], # Cell 1
'sample_2': [gene1=8, gene2=95, ..., gene20000=7], # Cell 2
'sample_3': [gene1=12, gene2=150, ..., gene20000=2], # Cell 3
...
'sample_N': [...]
}
# ═══════════════════════════════════════════════════════════
# DURING TRAINING (Minibatch)
# ═══════════════════════════════════════════════════════════
# Step 1: Sample a batch
batch = 32 samples
x = (32, 20000) # 32 samples, each with 20K gene counts
# Step 2: Encode each sample to tokens
# Each sample processed independently!
z = encoder(x) # (32, 64, 256)
# ↑ ↑ ↑
# | | └─ Features per token
# | └───── Tokens per sample
# └───────── Batch size
# Step 3: Add positional encoding (per sample)
z = z + pos_embed # pos_embed: (1, 64, 256), broadcasts across batch
# Step 4: Transformer processes each sample's token sequence
# Attention operates WITHIN each sample's 64 tokens
# (Can also attend across samples if desired, but typically within)
t = timesteps # (32,) - Diffusion timestep for each sample in batch
# e.g., [500, 732, 123, ..., 891]
# Controls noise level (high t = more noise)
condition = conditions # (32, cond_dim) - Optional conditioning per sample
# Examples:
# - Cell type: [CD4+, B_cell, NK, ..., Monocyte]
# - Perturbation: [CRISPR_gene1, drug_A, ..., control]
# - Disease state: [healthy, disease_A, ..., healthy]
# Can be None for unconditional generation
z_out = transformer(z, t, condition) # (32, 64, 256)
# Transformer uses:
# - z: Token sequences to process
# - t: Time conditioning (via AdaLN - modulates features based on noise level)
# - condition: Biological conditioning (affects what to generate)
# Step 5: Decode back to gene space
# Each sample's tokens → a gene-space prediction
# (the clean expression or the noise, depending on the parameterization)
x_pred = decoder(z_out) # (32, 20000)
Understanding the Transformer Inputs¶
Let's break down what transformer(z, t, condition) means:
1. Token Sequences (z)¶
z = (32, 64, 256)
# Main input: The latent token sequences to process
# - 32 samples in batch
# - Each sample has 64 tokens
# - Each token has 256 features
# Think of each sample's 64 tokens as a "sentence"
# where each token represents a semantic gene module
2. Timestep (t)¶
t = (32,) # One timestep per sample in batch
# Example values: [500, 732, 123, 891, ..., 445]
# What does t mean?
# - Diffusion timestep: ranges from 0 (clean) to T (pure noise)
# - t=0: Nearly clean gene expression
# - t=500: Medium noise level
# - t=1000: Almost pure noise
# How is t used?
# - Embedded via sinusoidal encoding: t → time_embed (256-dim)
# - Used for Adaptive LayerNorm (AdaLN):
# γ, β = MLP(time_embed)
# h_modulated = γ * LayerNorm(h) + β
# - Tells model "how much noise is present"
# - Model adjusts its behavior based on noise level
# Why different t per sample?
# - During training: randomly sample t for each sample
# - Makes model learn to denoise at ALL noise levels
# - More efficient training (diverse timesteps per batch)
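The sinusoidal embedding and AdaLN modulation described in the comments above can be sketched end to end. A minimal, hedged implementation (the function and class names are illustrative, not from the architecture document; dimensions follow the running example of 64 tokens × 256 features):

```python
import math
import torch
import torch.nn as nn

def sinusoidal_embedding(t, dim=256):
    """Map integer timesteps (batch,) to (batch, dim) sinusoidal features."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / (half - 1))
    args = t[:, None].float() * freqs[None, :]                    # (batch, half)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (batch, dim)

class AdaLN(nn.Module):
    """Adaptive LayerNorm: scale/shift predicted from the time embedding."""
    def __init__(self, dim=256):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_gamma_beta = nn.Linear(dim, 2 * dim)

    def forward(self, h, time_embed):
        # h: (batch, num_tokens, dim); time_embed: (batch, dim)
        gamma, beta = self.to_gamma_beta(time_embed).chunk(2, dim=-1)
        # Broadcast the per-sample (gamma, beta) across that sample's tokens
        return gamma[:, None, :] * self.norm(h) + beta[:, None, :]

t = torch.randint(0, 1000, (32,))       # one random timestep per sample
h = torch.randn(32, 64, 256)            # token sequences
out = AdaLN()(h, sinusoidal_embedding(t))  # (32, 64, 256)
```

This is how "the model adjusts its behavior based on noise level": the same tokens are normalized identically, but the scale and shift differ per sample according to t.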
3. Condition (condition)¶
condition = (32, cond_dim) or None
# Optional conditioning information per sample
# Examples of what condition could be:
# A) Cell type conditioning
condition = ['CD4+ T cell', 'B cell', 'NK cell', ...]
# After embedding: (32, 256)
# Purpose: Generate expression for specific cell types
# B) Perturbation conditioning
condition = {
'gene_knockout': ['FOXO1', 'TP53', None, ...],
'drug': ['drug_A', None, 'drug_B', ...],
'dose': [1.0, 0.0, 0.5, ...]
}
# After embedding: (32, 512) # Combined embeddings
# Purpose: Generate response to perturbations
# C) Disease state conditioning
condition = ['healthy', 'diabetes', 'healthy', 'cancer', ...]
# After embedding: (32, 256)
# Purpose: Generate disease-specific expression
# D) Multi-modal conditioning
condition = {
'cell_type': ['T_cell', ...],
'tissue': ['liver', ...],
'age': [45, ...],
'sex': ['M', ...]
}
# After embedding: (32, 1024) # All concatenated
# Purpose: Complex, multi-factor conditioning
# E) No conditioning (unconditional generation)
condition = None
# Purpose: Generate generic, diverse gene expression
How to Represent Conditions of Different Complexity¶
Key Question: How do we go from raw conditioning information (strings, dictionaries, numbers) to the tensor that the transformer uses?
Strategy 1: Simple Categorical Conditions¶
Example: Cell type conditioning
# Raw input
cell_types = ['CD4+ T cell', 'B cell', 'NK cell', 'Monocyte', ...] # (32,)
# Step 1: Convert to indices
cell_type_to_idx = {
'CD4+ T cell': 0,
'B cell': 1,
'NK cell': 2,
'Monocyte': 3,
...
}
indices = [cell_type_to_idx[ct] for ct in cell_types] # [0, 1, 2, 3, ...]
# Step 2: Embed via embedding layer
class CellTypeEmbedder(nn.Module):
def __init__(self, num_cell_types=50, embed_dim=256):
super().__init__()
# Learned embedding for each cell type
self.embedding = nn.Embedding(num_cell_types, embed_dim)
def forward(self, cell_type_indices):
# Input: (batch,) - integer indices
# Output: (batch, embed_dim)
return self.embedding(cell_type_indices)
# Result
embedder = CellTypeEmbedder(num_cell_types=50, embed_dim=256)
condition_embed = embedder(torch.tensor(indices)) # (32, 256)
Why this works:
- Cell types are discrete categories
- Each gets a learnable vector representation
- Similar to word embeddings in NLP
Strategy 2: Complex Multi-Component Conditions¶
Example: Perturbation conditioning (gene knockout + drug + dose)
# Raw input (different types of information)
perturbations = {
'gene_knockout': ['FOXO1', 'TP53', None, 'MYC', ...], # Categorical or None
'drug': ['drug_A', None, 'drug_B', 'drug_A', ...], # Categorical or None
'dose': [1.0, 0.0, 0.5, 1.0, ...] # Continuous
}
# ════════════════════════════════════════════════════════════
# Approach A: Separate Embeddings + Concatenation
# ════════════════════════════════════════════════════════════
class PerturbationEmbedder(nn.Module):
def __init__(
self,
num_genes=1000, # Number of possible knockout targets
num_drugs=100, # Number of drugs
embed_dim=128 # Embedding dimension per component
):
super().__init__()
# Separate embeddings for each component
self.gene_embed = nn.Embedding(num_genes + 1, embed_dim) # +1 for "none"
self.drug_embed = nn.Embedding(num_drugs + 1, embed_dim) # +1 for "none"
# MLP for continuous dose
self.dose_encoder = nn.Sequential(
nn.Linear(1, embed_dim),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim)
)
# Final projection (concatenated → combined)
self.combiner = nn.Sequential(
nn.Linear(3 * embed_dim, 512),
nn.SiLU(),
nn.Linear(512, 512)
)
# Special indices for "none"
self.gene_none_idx = num_genes
self.drug_none_idx = num_drugs
def forward(self, gene_indices, drug_indices, doses):
"""
Args:
gene_indices: (batch,) - indices of knockout genes (or none_idx)
drug_indices: (batch,) - indices of drugs (or none_idx)
doses: (batch,) - continuous dose values
Returns:
(batch, 512) - combined perturbation embedding
"""
# Embed each component separately
gene_emb = self.gene_embed(gene_indices) # (batch, 128)
drug_emb = self.drug_embed(drug_indices) # (batch, 128)
dose_emb = self.dose_encoder(doses[:, None]) # (batch, 128)
# Concatenate
combined = torch.cat([gene_emb, drug_emb, dose_emb], dim=-1) # (batch, 384)
# Project to final dimension
condition_embed = self.combiner(combined) # (batch, 512)
return condition_embed
# Usage (instantiate the embedder first, so its "none" indices are available)
embedder = PerturbationEmbedder()
gene_to_idx = {'FOXO1': 0, 'TP53': 1, 'MYC': 2, None: embedder.gene_none_idx}
drug_to_idx = {'drug_A': 0, 'drug_B': 1, None: embedder.drug_none_idx}
gene_indices = torch.tensor([gene_to_idx[g] for g in perturbations['gene_knockout']])
drug_indices = torch.tensor([drug_to_idx[d] for d in perturbations['drug']])
doses = torch.tensor(perturbations['dose'])
condition_embed = embedder(gene_indices, drug_indices, doses) # (32, 512)
Why this works:
- Each component embedded separately (preserves semantics)
- Concatenation combines all information
- MLP learns interactions between components
- Handles missing values naturally (None → special index)
Strategy 3: Variable-Length Complex Conditions¶
Example: Multiple perturbations per sample (variable number)
# Raw input: Some samples have multiple perturbations
perturbations = [
['FOXO1_knockout', 'drug_A_0.5'], # Sample 1: 2 perturbations
['TP53_knockout'], # Sample 2: 1 perturbation
['drug_B_1.0', 'drug_C_0.3', 'MYC_knockout'], # Sample 3: 3 perturbations
[], # Sample 4: No perturbation
...
]
# ════════════════════════════════════════════════════════════
# Approach B: Set Embedding (Order-Invariant)
# ════════════════════════════════════════════════════════════
class SetPerturbationEmbedder(nn.Module):
def __init__(self, num_perturbations=500, embed_dim=256, max_perts=10):
super().__init__()
# Embedding for each perturbation type
self.pert_embed = nn.Embedding(num_perturbations + 1, embed_dim) # +1 for padding
# Transformer encoder (order-invariant aggregation)
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4),
num_layers=2
)
# Pooling to a fixed size is done via masked mean in forward()
# (attention pooling is another option)
self.pad_idx = num_perturbations
self.max_perts = max_perts
def forward(self, perturbation_indices, mask):
"""
Args:
perturbation_indices: (batch, max_perts) - padded perturbation indices
mask: (batch, max_perts) - True where valid, False for padding
Returns:
(batch, embed_dim) - aggregated perturbation embedding
"""
# Embed each perturbation
pert_embs = self.pert_embed(perturbation_indices) # (batch, max_perts, embed_dim)
# Transformer (handles variable length via masking)
pert_embs = pert_embs.transpose(0, 1) # (max_perts, batch, embed_dim)
encoded = self.transformer(pert_embs, src_key_padding_mask=~mask)
encoded = encoded.transpose(0, 1) # (batch, max_perts, embed_dim)
# Pool to fixed size (masked mean over valid perturbations)
encoded = encoded * mask.unsqueeze(-1)
# Clamp the count so samples with zero perturbations don't divide by zero
condition_embed = encoded.sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(min=1)
return condition_embed # (batch, embed_dim)
# Usage: Convert to padded format
pert_to_idx = {'FOXO1_knockout': 0, 'drug_A_0.5': 1, ...}
def collate_perturbations(pert_lists, max_perts=10, pad_idx=500):
"""Convert variable-length lists to padded tensors"""
batch_size = len(pert_lists)
# Create padded tensor
indices = torch.full((batch_size, max_perts), pad_idx, dtype=torch.long)
mask = torch.zeros(batch_size, max_perts, dtype=torch.bool)
for i, perts in enumerate(pert_lists):
num_perts = min(len(perts), max_perts)
indices[i, :num_perts] = torch.tensor([pert_to_idx[p] for p in perts[:num_perts]])
mask[i, :num_perts] = True
return indices, mask
indices, mask = collate_perturbations(perturbations)
embedder = SetPerturbationEmbedder()
condition_embed = embedder(indices, mask) # (32, 256)
Why this works:
- Handles variable number of perturbations
- Order-invariant (set semantics, not sequence)
- Masking handles different lengths
- Aggregation (mean/max/attention) gives fixed-size output
Strategy 4: Hierarchical Conditions¶
Example: Multi-level biological context (organism → tissue → cell type → state)
# Raw input: Hierarchical structure
conditions = {
'organism': ['human', 'human', 'mouse', ...],
'tissue': ['liver', 'brain', 'liver', ...],
'cell_type': ['hepatocyte', 'neuron', 'hepatocyte', ...],
'state': ['healthy', 'healthy', 'diseased', ...]
}
# ════════════════════════════════════════════════════════════
# Approach C: Hierarchical Embeddings
# ════════════════════════════════════════════════════════════
class HierarchicalConditionEmbedder(nn.Module):
def __init__(self):
super().__init__()
# Separate embeddings for each level
self.organism_embed = nn.Embedding(10, 64) # Few organisms
self.tissue_embed = nn.Embedding(50, 128) # More tissues
self.cell_type_embed = nn.Embedding(200, 256) # Many cell types
self.state_embed = nn.Embedding(20, 128) # Few states
# Hierarchical combination (bottom-up)
self.combine_tissue_cell = nn.Sequential(
nn.Linear(128 + 256, 256),
nn.SiLU()
)
self.combine_org_tissue_cell = nn.Sequential(
nn.Linear(64 + 256, 256),
nn.SiLU()
)
self.combine_all = nn.Sequential(
nn.Linear(256 + 128, 512),
nn.SiLU(),
nn.Linear(512, 512)
)
def forward(self, organism_idx, tissue_idx, cell_type_idx, state_idx):
"""Hierarchical composition: organism > tissue > cell_type, + state"""
# Embed each level
org_emb = self.organism_embed(organism_idx) # (batch, 64)
tis_emb = self.tissue_embed(tissue_idx) # (batch, 128)
cel_emb = self.cell_type_embed(cell_type_idx) # (batch, 256)
sta_emb = self.state_embed(state_idx) # (batch, 128)
# Hierarchical composition
# Level 1: Tissue + Cell type (cell exists in tissue)
tissue_cell = self.combine_tissue_cell(
torch.cat([tis_emb, cel_emb], dim=-1)
) # (batch, 256)
# Level 2: Organism + (Tissue + Cell)
org_tissue_cell = self.combine_org_tissue_cell(
torch.cat([org_emb, tissue_cell], dim=-1)
) # (batch, 256)
# Level 3: Add state (orthogonal to hierarchy)
final = self.combine_all(
torch.cat([org_tissue_cell, sta_emb], dim=-1)
) # (batch, 512)
return final
# Usage
organism_to_idx = {'human': 0, 'mouse': 1, 'rat': 2}
tissue_to_idx = {'liver': 0, 'brain': 1, ...}
# ... similar for cell_type and state
embedder = HierarchicalConditionEmbedder()
condition_embed = embedder(
torch.tensor([organism_to_idx[o] for o in conditions['organism']]),
torch.tensor([tissue_to_idx[t] for t in conditions['tissue']]),
torch.tensor([cell_type_to_idx[c] for c in conditions['cell_type']]),
torch.tensor([state_to_idx[s] for s in conditions['state']])
) # (32, 512)
Why this works:
- Respects biological hierarchy
- Bottom-up composition (cell → tissue → organism)
- Different embedding sizes reflect complexity
- Can incorporate prior knowledge about relationships
Comparison of Strategies¶
| Complexity | Example | Strategy | Output Shape | Best For |
|---|---|---|---|---|
| Simple Categorical | Cell type | Embedding layer | (batch, 256) | Single discrete condition |
| Multi-Component | Gene KO + Drug + Dose | Separate embeds + concat | (batch, 512) | Fixed set of heterogeneous conditions |
| Variable-Length | Multiple perturbations | Set embedding + pooling | (batch, 256) | Variable number of conditions |
| Hierarchical | Organism → Tissue → Cell | Hierarchical composition | (batch, 512) | Nested/structured conditions |
| Very Complex | Text descriptions | Pre-trained encoder (CLIP, T5) | (batch, 768) | Natural language |
Practical Implementation Template¶
class UniversalConditionEmbedder(nn.Module):
"""Handles conditions of varying complexity"""
def __init__(self):
super().__init__()
# Simple categorical
self.cell_type_embed = nn.Embedding(50, 256)
# Multi-component
self.perturbation_embed = PerturbationEmbedder()
# Continuous
self.dose_encoder = nn.Sequential(
nn.Linear(1, 128),
nn.SiLU(),
nn.Linear(128, 128)
)
# Combiner (if using multiple condition types)
self.combiner = nn.Linear(256 + 512 + 128, 512) # Adjust based on what you use
def forward(self, batch):
"""
Flexible forward pass based on what conditions are present
Args:
batch: Dictionary containing various condition types
Returns:
Combined condition embedding
"""
embeds = []
# Handle cell type if present
if 'cell_type' in batch:
cell_emb = self.cell_type_embed(batch['cell_type'])
embeds.append(cell_emb)
# Handle perturbation if present
if 'perturbation' in batch:
pert_emb = self.perturbation_embed(
batch['perturbation']['gene'],
batch['perturbation']['drug'],
batch['perturbation']['dose']
)
embeds.append(pert_emb)
# Handle continuous dose if present
if 'dose' in batch:
dose_emb = self.dose_encoder(batch['dose'][:, None])
embeds.append(dose_emb)
# Combine all present conditions
if len(embeds) == 0:
return None # Unconditional
elif len(embeds) == 1:
return embeds[0]
else:
# NOTE: self.combiner expects all three embeddings present (256 + 512 + 128 inputs);
# if only a subset can occur, use a projection sized for that subset
combined = torch.cat(embeds, dim=-1)
return self.combiner(combined)
Summary: From Raw Data to Condition Tensor¶
The general pipeline:
Raw Condition Data
↓
[Convert to appropriate format]
↓ (strings → indices, numbers → tensors, etc.)
Processable Format
↓
[Embed each component]
↓ (embeddings, MLPs, encoders)
Component Embeddings
↓
[Combine components]
↓ (concatenation, addition, hierarchical composition)
Final Condition Tensor
↓ (batch, condition_dim)
[Input to Transformer]
Key principles:
1. Each component gets its own embedding strategy
2. Categorical → Embedding layers
3. Continuous → MLPs
4. Complex → Combination of the above
5. Variable-length → Masking + pooling
6. Hierarchical → Compositional embeddings
7. Combine via concatenation, addition, or learned fusion
The condition tensor shape (batch, condition_dim) is then used by the transformer via AdaLN, cross-attention, or other conditioning mechanisms.
# How is condition used?
# Option 1: AdaLN (like time)
# γ_cond, β_cond = MLP(condition_embed)
# h = γ_time * γ_cond * LayerNorm(h) + β_time + β_cond
#
# Option 2: Cross-attention
# h_out = CrossAttention(query=h, key=condition, value=condition)
#
# Option 3: Concatenation
# combined = concat([time_embed, condition_embed])
# γ, β = MLP(combined)
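Option 2 above can be sketched with `nn.MultiheadAttention`. A minimal, hedged illustration assuming a single pooled condition vector per sample (the class name and the key/value projection are illustrative choices, not the document's prescribed design):

```python
import torch
import torch.nn as nn

class CrossAttentionConditioning(nn.Module):
    """Option 2 from the comments above: tokens attend to the condition embedding."""
    def __init__(self, dim=256, cond_dim=256, nhead=4):
        super().__init__()
        self.to_cond = nn.Linear(cond_dim, dim)  # project condition into token space
        self.attn = nn.MultiheadAttention(dim, nhead, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, h, condition_embed):
        # h: (batch, num_tokens, dim); condition_embed: (batch, cond_dim)
        cond = self.to_cond(condition_embed)[:, None, :]  # (batch, 1, dim)
        attended, _ = self.attn(query=h, key=cond, value=cond)
        return self.norm(h + attended)                    # residual + norm

h = torch.randn(32, 64, 256)      # token sequences
cond = torch.randn(32, 256)       # e.g., a cell-type embedding
out = CrossAttentionConditioning()(h, cond)  # (32, 64, 256)
```

With a richer condition (e.g., per-perturbation embeddings), the key/value sequence can simply be longer than 1, letting each token attend to individual condition components.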
Complete Example with All Inputs¶
# Training step
for batch in dataloader:
# Load data
x_0 = batch['expression'] # (32, 20000) - Clean gene expression
cell_types = batch['cell_type'] # (32,) - ['T_cell', 'B_cell', ...]
# Sample random timesteps (different for each sample)
t = torch.randint(0, 1000, (32,)) # e.g., [234, 789, 12, 901, ...]
# Add noise according to timestep
noise = torch.randn_like(x_0)
x_t = torch.sqrt(alpha_bar[t])[:, None] * x_0 + torch.sqrt(1 - alpha_bar[t])[:, None] * noise
# x_t: (32, 20000) - Noisy gene expression
# Embed conditions
condition = cell_type_embedder(cell_types) # (32, 256)
# Forward pass
z = encoder(x_t) # (32, 64, 256) - Encode to tokens
z_out = transformer(z, t, condition) # (32, 64, 256) - Process with context
noise_pred = decoder(z_out) # (32, 20000) - Predict noise
# Loss
loss = F.mse_loss(noise_pred, noise)
Question 2: Is This Like a "Black and White Image"?¶
Question: Is the token representation (num_tokens, token_dim) almost like a "black and white image" with only 1 channel of size (num_tokens, token_dim)?
Reference: the positional embedding step in the architecture: z = z + pos_embed, where pos_embed has shape (1, 64, 256) and broadcasts across the batch.
Answer: No, It's a Sequence, Not an Image¶
Great intuition, but not quite! Let me explain the dimensionality:
Image vs Token Representation¶
# GRAYSCALE IMAGE (what you're thinking of)
image = (batch, 1, height, width)
= (32, 1, 224, 224)
# - 1 channel (grayscale)
# - Spatial dimensions: height × width
# - Still has 2D spatial structure
# LATENT TOKENS (what we have)
tokens = (batch, num_tokens, token_dim)
= (32, 64, 256)
# - 64 tokens (like 64 "patches")
# - 256 features per token
# - NO spatial structure! Just a SEQUENCE of tokens
Better Analogy: Sequence, Not Image¶
The token representation is more like:
TEXT SEQUENCE (NLP):
────────────────────
sentence = ["The", "cat", "sat", "on", "mat"]
embeddings = (batch, seq_len, embed_dim)
= (32, 5, 768)
# - 5 words in sequence
# - 768-dimensional embedding per word
GENE EXPRESSION TOKENS:
───────────────────────
gene_profile → [token1, token2, ..., token64]
tokens = (batch, num_tokens, token_dim)
= (32, 64, 256)
# - 64 tokens in sequence
# - 256-dimensional features per token
# - Each token represents some "semantic cluster" of genes
Visual Comparison¶
IMAGE REPRESENTATION:
═══════════════════════
┌─────────────┐
│ ░░░░░░░░░░░ │ Channel 1 (R)
│ ░░░▓▓░░░░░░ │
│ ░░▓▓▓▓░░░░░ │
└─────────────┘
┌─────────────┐
│ ░░░░░░░░░░░ │ Channel 2 (G)
│ ░░░▓▓░░░░░░ │
└─────────────┘
┌─────────────┐
│ ░░░░░░░░░░░ │ Channel 3 (B)
└─────────────┘
Shape: (height, width, channels)
Structure: 2D spatial grid
TOKEN REPRESENTATION:
═══════════════════════
Token 1: [0.5, -0.2, 0.8, ..., 0.3] ← 256 features
Token 2: [0.1, 0.7, -0.4, ..., 0.9]
Token 3: [-0.3, 0.2, 0.6, ..., -0.1]
...
Token 64: [0.4, -0.5, 0.1, ..., 0.7]
Shape: (num_tokens, token_dim)
Structure: 1D sequence (like words in a sentence)
What Each Token Might Represent¶
Unlike pixels in an image, each token captures semantic information:
# Hypothetical learned representation
Token 1 → "Cell cycle genes" (high for proliferating cells)
Token 2 → "Immune response genes" (high in activated immune cells)
Token 3 → "Metabolic genes" (high in metabolically active cells)
Token 4 → "Housekeeping genes" (stable across conditions)
...
Token 64 → "Rare pathway genes"
# Each token is 256-dimensional, encoding complex patterns
Token 1 = [0.5, -0.2, 0.8, ..., 0.3]
↑ ↑ ↑ ↑
Features capturing different aspects of "cell cycle-ness"
Dimensionality Breakdown¶
Comparison Table¶
| Dimension | Name | Size | Meaning |
|---|---|---|---|
| Batch | `batch_size` | 32 | Number of samples processed together |
| Tokens | `num_tokens` | 64 | Number of semantic "chunks" per sample |
| Features | `token_dim` | 256 | Features describing each token |
Cross-Domain Comparison¶
| Data Type | Dimension 1 | Dimension 2 | Dimension 3 | Dimension 4 |
|---|---|---|---|---|
| Image | batch=32 | channels=3 | height=224 | width=224 |
| Tokens | batch=32 | num_tokens=64 | token_dim=256 | - |
| Text | batch=32 | seq_len=50 | embed_dim=768 | - |
Key difference:
- Images have 2D spatial structure (height × width)
- Tokens have 1D sequence structure (just num_tokens)
- The positional encoding provides ordering info (like in NLP)
Why Not Actually Like an Image?¶
If we tried to make it image-like:
# ❌ WRONG: Treating as image
tokens_as_image = (batch, channels=1, height=8, width=8)
= (32, 1, 8, 8) # 64 "pixels" arranged spatially
# ✅ CORRECT: Treating as sequence
tokens_as_sequence = (batch, seq_len=64, features=256)
= (32, 64, 256) # 64 tokens with 256 features each
Why Sequence is Better¶
1. No inherent spatial structure in gene expression
- Gene order in genome ≠ meaningful for expression patterns
- Unlike pixels, where neighbors have spatial meaning
2. Transformers work on sequences
- Self-attention doesn't assume spatial locality
- Can capture any gene-gene interactions
3. More like NLP than vision
- Genes are like "words" in a biological "sentence"
- Tokens are semantic clusters, not spatial patches
Example: What Spatial Structure Would Mean¶
# If we treated as 2D image (8×8 grid of tokens)
# ❌ This would imply:
# - Token at position (0,0) is "near" token at (0,1)
# - Token at position (0,0) is "far" from token at (7,7)
# - Spatial locality matters
# But in gene expression:
# - Token 1 (cell cycle) might interact strongly with Token 50 (DNA repair)
# - No reason to assume nearby tokens are more related
# - Attention should be free to connect any tokens
Practical Implications¶
For Model Design¶
1. Use sequence models, not convolutional models
# ✅ Good: Transformer (no spatial bias)
transformer = nn.TransformerEncoder(...)
# ❌ Bad: CNN (assumes spatial locality)
cnn = nn.Conv2d(...) # Wrong inductive bias for gene expression
2. Positional encoding is flexible
# Can use learned or sinusoidal
self.pos_embed = nn.Parameter(torch.randn(1, num_tokens, token_dim))
# Or sinusoidal (as in the original Transformer)
self.pos_embed = SinusoidalPositionEmbeddings(token_dim)
3. Attention is unrestricted
# Each token can attend to all other tokens
# No spatial locality assumption
attn_scores = Q @ K.T # (num_tokens, num_tokens)
# All-to-all attention
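The all-to-all property can be checked directly. A tiny sketch using the document's token shapes, with identity "projections" standing in for learned Q/K weights:

```python
import torch

num_tokens, token_dim = 64, 256
tokens = torch.randn(num_tokens, token_dim)  # one sample's 64 tokens

# Identity projections for illustration; a real layer learns W_Q and W_K
Q = K = tokens
scores = (Q @ K.T) / token_dim ** 0.5        # (64, 64): every token vs. every token
weights = torch.softmax(scores, dim=-1)      # each row sums to 1

# Nothing constrains token 1 to favor its neighbors:
# weights[0, 49] can be as large as weights[0, 1]
```

Compare this with a convolution, whose receptive field would only ever mix a token with its immediate neighbors in the (arbitrary) token ordering.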
For Interpretation¶
1. Tokens are semantic, not spatial
- Analyze what biological patterns each token captures
- Use gene loadings to interpret tokens
- Compare to known pathways/modules
2. Token order doesn't matter (much)
- Positional encoding adds ordering info
- But tokens aren't inherently ordered like pixels
- Could potentially shuffle and retrain
3. Visualization differs from images
# For images: Show as 2D grid
plt.imshow(image)
# For tokens: Show as heatmap or t-SNE
plt.imshow(tokens.T) # (token_dim, num_tokens)
# Or project to 2D
tsne = TSNE(n_components=2)
tokens_2d = tsne.fit_transform(tokens)
plt.scatter(tokens_2d[:, 0], tokens_2d[:, 1])
Summary¶
Question 1: Handling Thousands of Samples¶
The architecture processes one sample at a time (with batching for efficiency). The "thousands of samples" are your dataset, processed in minibatches during training, just like images in computer vision.
Key points:
- Each sample: 20K genes → 64 tokens (256-dim each)
- Batching: Process 32 samples in parallel
- Training: Iterate through dataset in minibatches
- Same as standard deep learning practice
Question 2: Is This Like a Black-and-White Image?¶
The token representation is NOT like a black-and-white image. It's more like:
Text embeddings: a sequence of semantic tokens
- Each token is a 256-dimensional feature vector
- No 2D spatial structure, just a 1D sequence
- Position info added via positional encoding
Key intuition: Think of it as compressing a 20,000-dimensional gene expression vector into a sequence of 64 semantic tokens (each 256-dim), where each token represents some learned biological pattern/module.
Correct Mental Model¶
Gene Expression → Encoder → Sequence of Semantic Tokens → Transformer → Decoder → Prediction
NOT:
Gene Expression → Encoder → 2D Image → CNN → Decoder → Prediction
Think: NLP (BERT, GPT) not Computer Vision (ResNet, ViT)
Related Questions¶
Q: Why 64 tokens specifically?¶
A: A hyperparameter choice balancing capacity and cost:
- Fewer tokens (e.g., 32): faster, but less capacity
- More tokens (e.g., 128): more capacity, but slower
- 64 tokens: a sweet spot for most applications
Experiment to find optimal for your data.
Q: Can tokens attend across samples in a batch?¶
A: Typically no, but possible:
# Standard: Within-sample attention
# Each sample's 64 tokens attend to each other
attn_mask = None # Full attention within sample
# Advanced: Cross-sample attention
# Tokens can attend across samples (rare for gene expression)
# Useful for: batch effects, sample relationships
Q: How do I interpret learned tokens?¶
A: Several approaches:
# 1. Gene loadings
# Which genes contribute most to each token?
encoder_weights = model.encoder[0].weight # (2048, 20000)
# Analyze top genes per token
# 2. Activation patterns
# When is each token active?
z = model.encode(x_batch) # (batch, 64, 256)
token_activations = z.mean(dim=-1) # (batch, 64)
# Correlate with cell types, conditions
# 3. Pathway enrichment
# Do tokens align with known pathways?
# Use GSEA on gene loadings per token
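Approach 1 can be made concrete. A hedged sketch, using a single `nn.Linear` as a stand-in for the encoder's first layer (sizes reduced from 20,000 genes / 256-dim tokens to keep the demo light; `top_genes_per_token` is a hypothetical name):

```python
import torch
import torch.nn as nn

# Stand-in for the encoder's first layer: genes -> flattened tokens.
# Sizes reduced (2,000 genes, 32-dim tokens) so the demo stays small.
num_genes, num_tokens, token_dim = 2000, 64, 32
first_layer = nn.Linear(num_genes, num_tokens * token_dim)

# View the weight matrix as per-token gene loadings:
# (num_tokens * token_dim, num_genes) -> (num_tokens, token_dim, num_genes)
loadings = first_layer.weight.view(num_tokens, token_dim, num_genes)

# Rank genes by mean absolute contribution to each token
gene_importance = loadings.abs().mean(dim=1)                      # (64, 2000)
top_genes_per_token = gene_importance.topk(k=20, dim=-1).indices  # (64, 20)
# These index lists are what you'd feed into pathway-enrichment tools (e.g., GSEA)
```

In a trained model the loadings come from learned weights rather than random initialization, and a deeper encoder would need attribution methods (e.g., gradients) instead of a direct weight reshape.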
Related Documents¶
- 02a_diffusion_arch_gene_expression.md — Architecture options
- 02_ddpm_training.md — Training strategies
- ../DiT/01_dit_foundations.md — Transformer details
- ../DiT/open_research_tokenization.md — Tokenization deep dive
References¶
Sequence models for biology:
- Theodoris et al. (2023): "Transfer learning enables predictions in network biology" (Geneformer)
- Cui et al. (2024): "scGPT: Toward Building a Foundation Model for Single-Cell Multi-omics"
Transformers and attention:
- Vaswani et al. (2017): "Attention Is All You Need"
- Devlin et al. (2019): "BERT: Pre-training of Deep Bidirectional Transformers"
Latent representations:
- Rombach et al. (2022): "High-Resolution Image Synthesis with Latent Diffusion Models"
- Kingma & Welling (2014): "Auto-Encoding Variational Bayes"