BEHRT Embeddings as Structured Inductive Bias¶
Topics: Embedding design, age binning, visit vs positional embedding, sinusoidal vs learned, visit ID embedding vs aggregated visit embedding
Reference code: src/ehrsequencing/models/embeddings.py
Table of Contents¶
- Embedding Sum as Model Prior
- Age Embedding: Why Binned, Not Continuous
- Visit Embedding vs Positional Embedding
- Sinusoidal vs Learned Positional Encoding
- Visit ID Embedding vs Aggregated Visit Embedding
- Practical Tuning Notes
Embedding Sum as Model Prior¶
BEHRT uses additive embeddings: each token's input representation is the sum of a code, age, visit, and position embedding.
This is not just an implementation detail — it encodes structural assumptions about what matters:
- code identity — which medical concept this token represents
- age context — the patient's biological stage at this code
- visit membership — which clinical encounter this code belongs to
- sequence order — absolute position in the flattened token stream
Each component is a learned lookup table (or fixed sinusoidal function for position). The sum is passed through LayerNorm and dropout before entering the transformer.
Why summation? Additive composition is a standard prior in NLP (BERT uses the same pattern). It assumes the contributions of each axis are approximately independent and additive in embedding space. See dev/models/pretrain_finetune/05_embedding_summation_and_quality_analysis.md for mathematical justification.
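The composition above can be sketched in code. This is a minimal illustration, not a copy of embeddings.py — the class and argument names are assumptions:

```python
import torch
import torch.nn as nn

class AdditiveEmbedding(nn.Module):
    """Sketch of BEHRT-style additive input embeddings (names are illustrative)."""

    def __init__(self, vocab_size, n_age_bins, max_visits, max_len, dim, dropout=0.1):
        super().__init__()
        self.code_embedding = nn.Embedding(vocab_size, dim)      # code identity
        self.age_embedding = nn.Embedding(n_age_bins, dim)       # age context
        self.visit_embedding = nn.Embedding(max_visits, dim)     # visit membership
        self.position_embedding = nn.Embedding(max_len, dim)     # sequence order
        self.layer_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, codes, age_bins, visit_ids, positions):
        # The four axes are summed, then normalized and regularized,
        # before entering the transformer encoder.
        x = (
            self.code_embedding(codes)
            + self.age_embedding(age_bins)
            + self.visit_embedding(visit_ids)
            + self.position_embedding(positions)
        )
        return self.dropout(self.layer_norm(x))
```

Because the lookup tables share the same dimensionality, the sum stays a single `(B, L, H)` tensor, which is what makes this prior cheap compared to concatenation.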
Age Embedding: Why Binned, Not Continuous¶
AgeEmbedding discretizes continuous age into bins before embedding:
# Default: 5-year bins
age_bin = age // age_bin_size # e.g., age=67 → bin=13
embedding = self.embedding(age_bin)
Pros:
- Robust to noisy or imprecise age values in EHR data
- Easier optimization than raw continuous feature injection
- Captures coarse life-stage effects (pediatric / adult / elderly)
- Consistent with how BEHRT was originally designed

Tradeoff:
- Loses within-bin resolution (all ages 65–69 map to the same embedding)
When to adjust: If the outcome hazard is sensitive to narrow age windows (e.g., pediatric dosing thresholds), reduce age_bin_size to 1–2 years. Watch for sparsity in older age bins with small cohorts.
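Wrapping the binning logic above in a module makes the tradeoff concrete: any two ages in the same bin produce identical vectors. A minimal sketch (the constructor signature is an assumption):

```python
import torch
import torch.nn as nn

class AgeEmbedding(nn.Module):
    """Sketch of binned age embedding; constructor details are assumptions."""

    def __init__(self, max_age=120, age_bin_size=5, dim=64):
        super().__init__()
        self.age_bin_size = age_bin_size
        n_bins = max_age // age_bin_size + 1
        self.embedding = nn.Embedding(n_bins, dim)

    def forward(self, ages):
        # Integer division discretizes age: with 5-year bins, age 67 -> bin 13.
        age_bins = ages // self.age_bin_size
        return self.embedding(age_bins)
```

Here ages 67 and 69 map to bin 13 and share one embedding vector, while age 70 falls into bin 14 and gets a different one.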
Visit Embedding vs Positional Embedding¶
These two embeddings are not redundant — they encode different axes:
| Embedding | Question answered | Granularity |
|---|---|---|
| VisitEmbedding | "Which clinical encounter does this token belong to?" | Visit-level |
| PositionalEmbedding | "Where is this token in the flattened stream?" | Token-level |
Why both matter:
Two tokens can share the same visit ID but differ in token position (e.g., the 1st vs 3rd code within visit 2). Two tokens can have similar position indices but belong to different visits or different patients.
Removing either weakens temporal structure. Visit embedding helps the transformer learn intra-visit vs inter-visit attention patterns. Positional embedding preserves token-level ordering within a visit.
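A toy example makes the two axes concrete (the codes and visit layout are illustrative):

```python
import torch

# Three visits flattened into one token stream.
# visit_ids answers "which encounter?"; positions answers "where in the stream?"
codes     = ["I10", "E11", "N18", "I10", "Z79"]   # placeholder diagnosis codes
visit_ids = torch.tensor([0, 0, 1, 1, 2])         # visit membership per token
positions = torch.arange(len(codes))              # 0, 1, 2, 3, 4

# Tokens 0 and 1 share a visit but differ in position;
# tokens 1 and 2 are adjacent in position but belong to different visits.
```

Neither tensor can be derived from the other, which is why dropping either one removes information the model would otherwise have to infer from content alone.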
Sinusoidal vs Learned Positional Encoding¶
BEHRTEmbedding supports both modes:
# Learned (default): position embedding is a trainable nn.Embedding
use_sinusoidal=False
# Fixed: position encoding is a non-trainable sinusoidal buffer
use_sinusoidal=True
With sinusoidal encoding, the parameter split is:
- Fixed: positional basis (not updated during training)
- Learned: code, age, visit embeddings + transformer weights + task heads
When to use sinusoidal:
- Stronger generalization to sequence lengths not seen during training
- Less positional overfitting on small cohorts
- Domain shift across institutions (sinusoidal is more stable)

When to use learned:
- Larger cohorts where positional patterns are worth learning
- Sequences with consistent length distributions
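The fixed variant follows the standard sinusoidal construction from the transformer literature; a self-contained sketch of what such a non-trainable buffer would contain:

```python
import math
import torch

def sinusoidal_encoding(max_len: int, dim: int) -> torch.Tensor:
    """Standard fixed sinusoidal position encoding (sin on even dims, cos on odd)."""
    position = torch.arange(max_len).unsqueeze(1).float()
    # Geometric frequency progression across embedding dimensions.
    div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
    pe = torch.zeros(max_len, dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # registered as a buffer, so it is saved but never updated
```

Because the encoding is a deterministic function of position, it extrapolates to positions beyond `max_len` seen in training, which is the source of the generalization advantage noted above.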
Visit ID Embedding vs Aggregated Visit Embedding¶
This is the most important conceptual distinction in the BEHRT survival architecture. There are two different things called "visit embedding":
Visit ID Embedding (Input-Side)¶
# In BEHRTEmbedding.forward()
visit_emb = self.visit_embedding(visit_ids) # lookup table
token_emb = code_emb + age_emb + visit_emb + pos_emb
- What it is: A learned lookup table mapping visit index → embedding vector
- When it runs: Before the transformer, as part of input construction
- What it encodes: "This token belongs to visit #3" — purely structural metadata
- Medical content: None — it knows nothing about what codes are in visit #3
Aggregated Visit Embedding (Output-Side)¶
# In BEHRTForSurvival.aggregate_visits()
visit_embeddings = scatter_add(contextualized_code_embeddings, visit_ids)
visit_embeddings /= visit_counts # mean pooling
- What it is: Mean pool of BEHRT-contextualized code embeddings within a visit
- When it runs: After the transformer, as part of hazard prediction
- What it encodes: "Visit #3 contained these conditions, in this patient context"
- Medical content: Rich — it reflects the full bidirectional context from the transformer
The Full Data Flow¶
codes (tokens)
↓
BEHRTEmbedding:
code_emb + age_emb + visit_emb + pos_emb ← visit_emb is structural only
↓
Transformer encoder (self-attention across all codes)
↓
code_embeddings: (B, L, H) ← each code now "knows about" neighboring codes
↓
aggregate_visits (scatter_add mean-pool)
↓
visit_embeddings: (B, V, H) ← content-aware visit representation
↓
hazard head → hazards: (B, V)
VisitEmbedding helps the transformer learn intra-visit vs inter-visit attention patterns during encoding. aggregate_visits produces the actual content-rich visit representation used for hazard prediction. They serve complementary roles at different stages.
Practical Tuning Notes¶
Small cohort (< 2K patients):
- Keep embedding defaults
- Avoid over-parameterizing (large embedding_dim with little data)
- Use sinusoidal positional encoding to reduce overfitting
Larger cohort (> 10K patients):
- Learned positional embeddings often pay off
- Consider increasing embedding_dim if vocabulary is large
Domain shift across institutions:
- Sinusoidal positional encoding is more stable
- Consider freezing age and visit embeddings if source/target age distributions differ
See also: 01a_visit_embeddings.md for a deep treatment of the visit ID vs aggregated visit embedding distinction, including scatter_add mechanics and gradient flow analysis.
Next: 04_pretraining_objectives.md — what MLM and Next-Visit Prediction teach the encoder.