Skip to content

Time Embeddings in Diffusion Transformers: Deep Dive

Related: See Adaptive LayerNorm in diffusion_transformer.md for the broader context of how time embeddings are used in DiT.


Overview

Time embeddings are a crucial component of diffusion models, allowing the network to adapt its behavior based on the current noise level. This document explains:

  1. What time embeddings are
  2. Why they don't "perturb ordering" when passed through MLPs
  3. How they differ from positional embeddings
  4. How they're used in Diffusion Transformers (DiT)

The Core Question

When looking at the AdaLN formulation in DiT:

# From diffusion_transformer.md, lines 118-124
τ = TimeEmbed(t)           # Time embedding
(γ, β) = MLP(τ)            # Pass through MLP
h_modulated = γ · LN(h) + β  # Modulate features

Common confusion: "Doesn't passing the time embedding through an MLP perturb the temporal ordering?"

Answer: No! The ordering is preserved because it's already encoded in the time embedding itself. The MLP learns to transform time information into useful modulation parameters without losing temporal relationships.


Time Embeddings vs Positional Embeddings

These are different concepts that are often confused:

Positional Embeddings (NLP Transformers)

Purpose: Tell the model "where in the sequence am I?"

# Sequence position
tokens = ["The", "cat", "sat"]  # positions 0, 1, 2

# Positional embeddings
pos_embed_0 = sinusoidal_embedding(0)  # "The" is at position 0
pos_embed_1 = sinusoidal_embedding(1)  # "cat" is at position 1
pos_embed_2 = sinusoidal_embedding(2)  # "sat" is at position 2

# Added directly to token embeddings
token_with_pos = token_embed + pos_embed

Use case: Preserving sequence order in transformers (which otherwise treat input as a set)

Time Embeddings (Diffusion Models)

Purpose: Tell the model "how much noise is in the data?"

# Diffusion timestep
t = 500  # Current noise level (0=clean, 1000=pure noise)

# Time embedding
time_embed = sinusoidal_embedding(t)  # Encodes noise level

# Used to CONDITION model behavior (not added to tokens!)
γ, β = MLP(time_embed)
h_modulated = γ * LayerNorm(h) + β

Use case: Adapting model behavior to different noise levels

Key Differences

Aspect Positional Embedding Time Embedding
Encodes Sequence position Noise level
Applied to Each token separately Entire image/data
Integration Added to token embeddings Used for feature modulation
Varies per Token Timestep
Example Token 0 vs Token 1 t=100 vs t=500

How Time Embeddings Work

Step 1: Sinusoidal Encoding

Time embeddings convert a scalar timestep into a high-dimensional vector:

def sinusoidal_time_embedding(t, dim=256):
    """
    Converts scalar timestep t into a high-dimensional vector.

    Key property: Similar timesteps → similar embeddings

    Args:
        t: Timestep(s), shape (batch_size,)
        dim: Embedding dimension (default 256)

    Returns:
        Embedding of shape (batch_size, dim)
    """
    half_dim = dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim) * -emb)
    emb = t[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    return emb

Why sinusoidal? Uses multiple frequencies to encode the timestep value smoothly:

# Frequency components
freq_1 = sin(t * ω), cos(t * ω)  # Low frequency (captures large changes)
freq_2 = sin(t * ω), cos(t * ω)  # Medium frequency
...
freq_n = sin(t * ωₙ), cos(t * ωₙ)  # High frequency (captures fine changes)

# Concatenate all frequencies
time_embed = [freq_1, freq_2, ..., freq_n]

Step 2: Ordering Preservation

The key property: similar timesteps have similar embeddings

# Example: Close timesteps
emb_500 = sinusoidal_embedding(500)
emb_501 = sinusoidal_embedding(501)

cosine_similarity(emb_500, emb_501)  0.999  # Very similar!

# Example: Distant timesteps
emb_100 = sinusoidal_embedding(100)
emb_900 = sinusoidal_embedding(900)

cosine_similarity(emb_100, emb_900)  0.2   # Very different!

Visualization:

Timestep Space          Embedding Space
──────────────          ───────────────
t=0    ───────────→     emb₀   = [0.0,  1.0,  0.0,  1.0, ...]
t=100  ───────────→     emb₁₀₀ = [0.9,  0.4, -0.3,  0.8, ...]
t=500  ───────────→     emb₅₀₀ = [0.2, -0.8,  0.4, -0.1, ...]
t=1000 ───────────→     emb₁₀₀₀= [-0.5, 0.8,  0.2, -0.9, ...]

Distance relationships preserved:
dist(0, 100) < dist(0, 500) < dist(0, 1000)
dist(emb₀, emb₁₀₀) < dist(emb₀, emb₅₀₀) < dist(emb₀, emb₁₀₀₀)

Why MLPs Don't "Destroy" Ordering

The MLP's Role

The MLP transforms time embeddings into modulation parameters:

# Time embedding encodes "what time is it"
time_embed = sinusoidal_embedding(t)  # (batch, 256)

# MLP learns "how should I adjust model behavior for this time?"
adaLN_params = MLP(time_embed)  # (batch, 1024)
γ, β = adaLN_params.chunk(2, dim=-1)  # (batch, 512), (batch, 512)

# Modulate features based on time
h_modulated = γ * LayerNorm(h) + β

Why Ordering is Preserved

Reason 1: Input distinguishability

Different timesteps have different embeddings:

emb_100 = [0.9, 0.1, -0.3, ...]   # t=100 → unique vector
emb_500 = [0.2, -0.8, 0.4, ...]   # t=500 → different vector
emb_900 = [-0.5, 0.8, 0.2, ...]   # t=900 → another different vector

Reason 2: Function preserves distinguishability

An MLP maps different inputs to different outputs:

γ₁₀₀, β₁₀₀ = MLP(emb_100)  # Produces one set of parameters
γ₅₀₀, β₅₀₀ = MLP(emb_500)  # Produces different parameters
γ₉₀₀, β₉₀₀ = MLP(emb_900)  # Produces yet different parameters

# As long as emb_100 ≠ emb_500 ≠ emb_900 (which they are!)
# Then γ₁₀₀ ≠ γ₅₀₀ ≠ γ₉₀₀ (assuming the MLP isn't degenerate)

Reason 3: Smooth functions preserve smoothness

MLPs with smooth activations (SiLU, GELU) are continuous:

# If inputs are similar
emb_500  emb_501  # Close timesteps

# Then outputs are similar (by continuity)
MLP(emb_500)  MLP(emb_501)

# Temporal relationships are preserved

What the MLP Actually Learns

The MLP learns time-dependent behavior:

# Early timesteps (high noise, t≈1000)
# Model needs to focus on global structure
γ_early = [0.1, 0.1, 0.1, ...]  # Small scale → suppress details
β_early = [0, 0, 0, ...]        # No shift

# Middle timesteps (medium noise, t≈500)
# Model balances structure and detail
γ_mid = [0.5, 0.8, 0.3, ...]    # Mixed scale
β_mid = [0.1, -0.2, 0.3, ...]   # Some shift

# Late timesteps (low noise, t≈100)
# Model needs to refine fine details
γ_late = [1.2, 1.5, 0.8, ...]   # Large scale → amplify details
β_late = [0.5, -0.3, 0.9, ...]  # Large shift → fine-tune features

The learned mapping:

Time Embedding       →  MLP  →  Modulation Strategy
──────────────          ───     ───────────────────
emb(t=1000, high noise) → γ_suppress_details, β_zero
emb(t=500, mid noise)   → γ_mixed, β_adjust
emb(t=100, low noise)   → γ_amplify_details, β_finetune

Complete Pipeline in DiT

Here's exactly what happens in a DiT block:

class DiTBlock(nn.Module):
    def __init__(self, hidden_dim=512, time_embed_dim=256):
        self.norm = LayerNorm(hidden_dim)
        self.attention = MultiHeadAttention(hidden_dim)

        # MLP to convert time embedding → modulation parameters
        self.adaLN_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, 2 * hidden_dim),
        )

    def forward(self, h, t):
        """
        Args:
            h: Token features, shape (batch, num_tokens, hidden_dim)
            t: Timesteps, shape (batch,)

        Returns:
            Output features, shape (batch, num_tokens, hidden_dim)
        """
        # ──────────────────────────────────────────────────────
        # STEP 1: Create time embedding (preserves ordering!)
        # ──────────────────────────────────────────────────────
        time_embed = sinusoidal_embedding(t)  # (batch, 256)

        # ──────────────────────────────────────────────────────
        # STEP 2: Transform to modulation parameters
        # ──────────────────────────────────────────────────────
        adaLN_params = self.adaLN_mlp(time_embed)  # (batch, 1024)
        γ, β = adaLN_params.chunk(2, dim=-1)  # 2x (batch, 512)

        # ──────────────────────────────────────────────────────
        # STEP 3: Apply Adaptive LayerNorm
        # ──────────────────────────────────────────────────────
        h_norm = self.norm(h)  # (batch, num_tokens, 512)

        # Broadcast γ and β across tokens
        γ = γ.unsqueeze(1)  # (batch, 1, 512) → broadcast to all tokens
        β = β.unsqueeze(1)  # (batch, 1, 512) → broadcast to all tokens

        # Modulate features based on time
        h_modulated = γ * h_norm + β  # (batch, num_tokens, 512)

        # ──────────────────────────────────────────────────────
        # STEP 4: Standard attention (now time-conditioned!)
        # ──────────────────────────────────────────────────────
        h_out = self.attention(h_modulated)

        return h + h_out  # Residual connection

Flow Diagram

Input: h (token features), t (scalar timestep)
═══════════════════════════════════════════════

t = 500 (scalar)
┌─────────────────────────────────────────────┐
│ Step 1: Sinusoidal Embedding                │
│ (Preserves temporal ordering)               │
└─────────────────────────────────────────────┘
time_embed = [0.2, -0.8, 0.4, ..., 0.1]  (256-dim vector)
┌─────────────────────────────────────────────┐
│ Step 2: MLP Transform                       │
│ (Learns time-dependent behavior)            │
└─────────────────────────────────────────────┘
γ = [1.2, 0.8, 1.5, ...]  (512-dim)
β = [0.1, -0.3, 0.5, ...]  (512-dim)
┌─────────────────────────────────────────────┐
│ Step 3: Modulate Features                   │
│ (Condition on time)                         │
└─────────────────────────────────────────────┘
h_modulated = γ ⊙ LayerNorm(h) + β
┌─────────────────────────────────────────────┐
│ Step 4: Attention & Processing              │
│ (Standard transformer operations)           │
└─────────────────────────────────────────────┘
Output

Key insight: The MLP doesn't operate on your data tokens (h). It operates on the time embedding to produce modulation parameters that affect how the model processes the tokens.


Concrete Example with Numbers

Let's trace through with actual values:

Example 1: High Noise (t=900)

# Input
t = 900

# Step 1: Time embedding
time_embed = sinusoidal_embedding(900)
# → [-0.448, 0.894, 0.732, -0.681, ..., 0.123]  (256 values)

# Step 2: MLP → modulation parameters
γ, β = MLP(time_embed)
# γ → [0.1, 0.1, 0.2, 0.1, ..., 0.1]  # Small values → suppress details
# β → [0.0, 0.0, 0.0, 0.0, ..., 0.0]  # Near zero → no shift

# Step 3: Modulate features
h = [[0.5, -0.2, 0.8, 0.3, ...],   # Token 1
     [0.3, 0.1, -0.4, 0.6, ...],   # Token 2
     ...]

h_modulated = γ * LayerNorm(h) + β
# Result: Features are suppressed (focus on structure, ignore details)

Example 2: Low Noise (t=100)

# Input
t = 100

# Step 1: Time embedding
time_embed = sinusoidal_embedding(100)
# → [0.985, 0.174, -0.342, 0.940, ..., -0.766]  (256 values)
# Note: Very different from t=900 embedding!

# Step 2: MLP → modulation parameters
γ, β = MLP(time_embed)
# γ → [1.5, 1.2, 1.8, 1.3, ..., 1.6]  # Large values → amplify details
# β → [0.3, -0.2, 0.5, -0.4, ..., 0.6]  # Non-zero → fine-tune

# Step 3: Modulate features
h_modulated = γ * LayerNorm(h) + β
# Result: Features are amplified and shifted (refine details)

Ordering is Preserved

# Distance in timestep space
|t_900 - t_100| = 800  # Large distance

# Distance in embedding space
||emb_900 - emb_100|| = 5.23  # Large distance (preserves ordering!)

# Distance in modulation space
||γ_900 - γ_100|| = 12.4  # Large distance (different behaviors!)

The key: The MLP preserves the distinguishability between different timesteps while learning useful transformations.


Why This Design is Better

Alternative 1: Direct Addition (Not Used)

# Add time info directly to tokens (like positional embedding)
h = h + time_embed.unsqueeze(1)  # Broadcast time to all tokens

Problems:

  • ❌ Same time signal added to all features
  • ❌ Can't selectively affect different channels
  • ❌ Less expressive

Alternative 2: Concatenation (Not Used)

# Concatenate time info with features
h = torch.cat([h, time_embed.expand(num_tokens, -1)], dim=-1)

Problems:

  • ❌ Increases dimensionality
  • ❌ Wastes computation
  • ❌ Less flexible than modulation

AdaLN (What DiT Uses) ✓

# Time-dependent scaling and shifting
γ, β = MLP(time_embed)
h = γ * LayerNorm(h) + β

Advantages:

  • ✅ Different features can be affected differently
  • ✅ Learned optimal time-dependent behavior
  • ✅ Powerful: can amplify/suppress features based on noise level
  • ✅ Efficient: no extra dimensions

Comparison: Different Conditioning Methods

Method How it Works Pros Cons
AdaLN (DiT) γ, β = MLP(t)
h = γ·LN(h) + β
✅ Flexible
✅ Powerful
✅ Efficient
Slightly complex
FiLM Same as AdaLN ✅ Same as AdaLN Same as AdaLN
Cross-Attention Attend to time token ✅ Very flexible ❌ Expensive (O(n²))
Addition h = h + emb(t) ✅ Simple ❌ Less expressive
Concatenation h = [h; emb(t)] ✅ Simple ❌ Increases dims

DiT uses AdaLN/FiLM because it offers the best balance of expressiveness and efficiency.


Implementation Tips

1. Time Embedding Dimension

# Common choices
time_embed_dim = 256  # Standard
time_embed_dim = 512  # More capacity
time_embed_dim = 128  # Smaller models

# Rule of thumb: ~1/2 to 1x of hidden_dim

2. MLP Architecture

# Simple (DiT-S)
MLP = nn.Sequential(
    nn.SiLU(),
    nn.Linear(time_embed_dim, 2 * hidden_dim)
)

# With intermediate layer (more capacity)
MLP = nn.Sequential(
    nn.SiLU(),
    nn.Linear(time_embed_dim, 4 * hidden_dim),
    nn.SiLU(),
    nn.Linear(4 * hidden_dim, 2 * hidden_dim)
)

3. Initialization

# Initialize γ to 1, β to 0 for identity at start
def init_adaln_mlp(module):
    if isinstance(module, nn.Linear):
        # Last layer: initialize to produce γ≈1, β≈0
        nn.init.zeros_(module.weight)
        nn.init.zeros_(module.bias)

4. Numerical Stability

# Clip γ to prevent extreme scaling
γ = γ.clamp(min=0.1, max=10.0)

# Or use softer constraints
γ = torch.exp(γ.clamp(min=-2, max=2))  # Ensures γ ∈ [0.14, 7.4]

For Gene Expression: Additional Considerations

When adapting time embeddings for gene expression data:

1. Timestep Range

# Images: Often T=1000
T = 1000

# Gene expression: May want different range
T = 100   # Fewer steps (faster sampling)
T = 1000  # Standard (better quality)

# Continuous time (rectified flow)
t  [0, 1]  # Normalized time

2. Domain-Specific Modulation

# Standard: modulate all features equally
h_modulated = γ * LN(h) + β

# Gene-specific: different modulation per gene group
γ_housekeeping = γ[:, :1000]   # Housekeeping genes
γ_variable = γ[:, 1000:]       # Variable genes

h_modulated[:, :1000] = γ_housekeeping * LN(h[:, :1000]) + β[:, :1000]
h_modulated[:, 1000:] = γ_variable * LN(h[:, 1000:]) + β[:, 1000:]

3. Biological Constraints

# Enforce positive expression after modulation
h_modulated = F.softplus(γ * LN(h) + β)

# Or use log-space
log_h_modulated = γ * LN(log_h) + β
h_modulated = torch.exp(log_h_modulated)

Analogy: Thermostat Control

To build intuition, think of time embeddings like a thermostat:

Time Embedding = Temperature Reading
─────────────────────────────────────
Summer (t=900, high noise): thermometer reads 90°F
  → time_embed = [high values]

Winter (t=100, low noise): thermometer reads 30°F
  → time_embed = [low values]

MLP = Control Logic
───────────────────
[90°F reading] → MLP → (γ=0.1, β=0)
                        "Turn AC on, suppress heating"

[30°F reading] → MLP → (γ=1.5, β=0.5)
                        "Turn heater on, boost warmth"

Result: Adaptive Behavior
─────────────────────────
High noise (90°F): Model focuses on structure (suppress detail features)
Low noise (30°F):  Model refines details (amplify detail features)

The MLP doesn't "lose" the temperature information. It learns: "Given THIS temperature, adjust controls THIS way."

Similarly, the MLP in DiT learns: "Given THIS noise level, modulate features THIS way."


Common Misconceptions

❌ Misconception 1: "MLP destroys temporal ordering"

Reality: The time embedding already encodes the ordering. The MLP just transforms it while preserving distinguishability.

❌ Misconception 2: "Time embeddings are like positional embeddings"

Reality: Different purposes: - Positional: Where in sequence (per token) - Time: How much noise (global)

❌ Misconception 3: "Should add time embedding to features directly"

Reality: Modulation (AdaLN) is more powerful and expressive than addition.

❌ Misconception 4: "Time embedding needs to be high-dimensional"

Reality: 256-dim is usually sufficient. More isn't always better (can lead to overfitting).


Further Reading

Papers

  • DDPM (Ho et al., 2020): Introduced noise prediction with time conditioning
  • Improved DDPM (Nichol & Dhariwal, 2021): Learned noise schedules
  • DiT (Peebles & Xie, 2023): Adaptive LayerNorm conditioning
  • FiLM (Perez et al., 2018): Feature-wise linear modulation

Summary

Time embeddings convert scalar timesteps into high-dimensional vectors that allow diffusion models to adapt their behavior to different noise levels.

Key points:

  1. Purpose: Encode noise level, not sequence position
  2. Method: Sinusoidal encoding preserves temporal ordering
  3. Integration: Used for feature modulation (AdaLN), not addition
  4. MLP role: Learns time-dependent behavior without destroying ordering
  5. Design choice: AdaLN is more expressive than addition or concatenation

The ordering is preserved because: - Sinusoidal encoding makes similar times → similar embeddings - MLP is a smooth, continuous function - Different times produce different modulation parameters

The result: The model learns to denoise differently at different noise levels, from focusing on global structure (high noise) to refining fine details (low noise).


Last Updated: January 12, 2026

Back to: DiT Documentation | Diffusion Models Overview