Survival Analysis with EHR Sequencing¶
Discrete-time survival analysis for predicting time-to-event outcomes from Electronic Health Records.
Overview¶
EHR Sequencing provides tools for survival analysis on patient sequences, enabling prediction of:
- Hospital Readmission - 30-day readmission risk
- Mortality - In-hospital, 30-day, and 1-year mortality
- Disease Onset - Time to diabetes, heart failure, stroke, etc.
- Treatment Response - Time to treatment efficacy or failure
Models¶
LSTM Baseline¶
Recurrent neural network for survival analysis.
Features:
- Learns embeddings from scratch
- Unidirectional (causal) processing
- Fast training, good baseline performance
Usage:
from ehrsequencing.models.survival_lstm import DiscreteTimeSurvivalLSTM
model = DiscreteTimeSurvivalLSTM(
    vocab_size=5000,
    embedding_dim=128,
    hidden_dim=256,
    num_layers=2,
    dropout=0.3
)
Training:
python examples/survival_analysis/train_lstm.py \
    --data-dir data/synthea/ \
    --outcome synthetic \
    --epochs 100 \
    --batch-size 64
BEHRT for Survival (Coming Soon)¶
Transformer-based model with pre-trained representations.
Features:
- Pre-trained with Masked Language Modeling (MLM)
- Bidirectional context within visits
- EHR-specific embeddings (age, visit, segment)
- Transfer learning from large unlabeled data
Expected Advantages:
- 10-20% improvement in C-index over LSTM
- Faster convergence with pre-training
- Better generalization (smaller train-val gap)
- Richer representations for downstream tasks
Usage (Planned):
from ehrsequencing.models.behrt_survival import BEHRTForSurvival
from ehrsequencing.models.behrt import BEHRTForMLM
# Load pre-trained BEHRT
pretrained = BEHRTForMLM.from_pretrained('checkpoints/behrt_mlm/')
# Create survival model
model = BEHRTForSurvival(
    behrt_encoder=pretrained.behrt,
    hidden_dim=256,
    dropout=0.1
)
# Fine-tune for survival (the options below are alternatives; choose one)
# Option 1: Freeze encoder, train only head
for param in model.behrt.parameters():
    param.requires_grad = False
# Option 2: LoRA fine-tuning (efficient)
from ehrsequencing.models.lora import apply_lora_to_behrt
model.behrt = apply_lora_to_behrt(model.behrt, rank=16)
# Option 3: Full fine-tuning
# All parameters trainable (slower, potentially better)
Discrete-Time Survival Framework¶
Hazard Function¶
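At each visit \(t\), the model predicts a hazard:
\[ h_t = P(T = t \mid T \ge t, x) \]
where \(T\) is the visit at which the event occurs and \(x\) is the patient's sequence. The hazard is the probability that the event occurs at visit \(t\), given that it has not occurred before visit \(t\).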
Survival Function¶
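The survival probability is computed from hazards:
\[ S(t) = \prod_{s=1}^{t} (1 - h_s) \]
Here \(S(t)\) is the probability that no event has occurred through visit \(t\), and \(h_s\) is the hazard at visit \(s\).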
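The relationship between hazards and survival can be sketched in a few lines of plain Python. This is an illustrative sketch, not the library's implementation: the function name `survival_curve` is hypothetical, and it assumes the convention that the survival probability at visit \(t\) is the probability of no event through visit \(t\).

```python
def survival_curve(hazards):
    """Return the running product of (1 - h_s) for each visit, in order."""
    survival = []
    running = 1.0
    for h in hazards:
        if not 0.0 <= h <= 1.0:
            raise ValueError(f"hazard must lie in [0, 1], got {h}")
        running *= 1.0 - h  # survive this visit as well as all earlier ones
        survival.append(running)
    return survival


# Hypothetical per-visit hazards for one patient
hazards = [0.1, 0.2, 0.5]
survival = survival_curve(hazards)

# Cumulative event risk by each visit is the complement of survival
cumulative_risk = [1.0 - s for s in survival]
```

The cumulative risk at a chosen visit is what a fixed-horizon question such as readmission risk would read off this curve.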
Loss Function¶
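In its standard form, the negative log-likelihood for discrete-time survival is:
\[ \mathcal{L} = -\sum_{i} \left[ \delta_i \log h_{i,t_i} + (1 - \delta_i) \log (1 - h_{i,t_i}) + \sum_{s=1}^{t_i - 1} \log (1 - h_{i,s}) \right] \]
where \(h_{i,s}\) is the predicted hazard for patient \(i\) at visit \(s\), \(t_i\) is the visit at which patient \(i\) had the event or was censored, and \(\delta_i\) is the event indicator (1 if event occurred, 0 if censored).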
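For a single patient, the standard discrete-time likelihood can be written out directly. The sketch below is an assumption about the general technique, not the code behind `DiscreteTimeSurvivalLoss`: the function name `discrete_time_nll` is hypothetical, `event_time` is taken to be a 1-based visit index, and a censored patient is assumed to have survived the censoring visit itself.

```python
import math


def discrete_time_nll(hazards, event_time, event_indicator, eps=1e-7):
    """Negative log-likelihood for one patient.

    hazards: predicted hazard for each visit, in order.
    event_time: 1-based index of the event or censoring visit (assumed convention).
    event_indicator: 1 if the event occurred, 0 if censored.
    """
    if not 1 <= event_time <= len(hazards):
        raise ValueError("event_time must index one of the visits")
    # Clip to avoid log(0) when a hazard is exactly 0 or 1
    clipped = [min(max(h, eps), 1.0 - eps) for h in hazards]
    # The patient survived every visit before the observed one
    nll = -sum(math.log(1.0 - h) for h in clipped[:event_time - 1])
    h_last = clipped[event_time - 1]
    if event_indicator == 1:
        nll -= math.log(h_last)        # event occurred at this visit
    else:
        nll -= math.log(1.0 - h_last)  # censored: survived this visit too
    return nll


hazards = [0.1, 0.2, 0.5]
nll_event = discrete_time_nll(hazards, event_time=3, event_indicator=1)
nll_censored = discrete_time_nll(hazards, event_time=2, event_indicator=0)
```

Summing this quantity over patients gives the batch loss; hazards after the observed visit contribute nothing, which is why a sequence mask is needed in batched implementations.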
Evaluation Metrics¶
Concordance Index (C-index)¶
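Measures discrimination: the probability that the model assigns the higher risk to the patient who experiences the event earlier, taken over all comparable pairs of patients. A pair is comparable when the earlier of the two times is an observed event rather than a censoring.
- C-index = 0.5: Random predictions
- C-index = 1.0: Perfect discrimination
- C-index > 0.7: Generally considered good discrimination
Interpretation:
- C-index of 0.75 means the model correctly ranks 75% of comparable patient pairs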
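To make the pair counting concrete, here is a small pure-Python sketch of Harrell's C-index. It is not the `concordance_index` shipped with the library: the name `harrell_c_index` is hypothetical, pairs with tied times are skipped for simplicity, and tied risk scores count as half-concordant.

```python
def harrell_c_index(risk_scores, event_times, event_indicators):
    """Fraction of comparable pairs in which the earlier event has the higher risk."""
    concordant = 0.0
    comparable = 0
    n = len(risk_scores)
    for i in range(n):
        if event_indicators[i] != 1:
            continue  # a censored patient cannot be the earlier member of a pair
        for j in range(n):
            if event_times[i] < event_times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Four hypothetical patients: two observed events, two censored
risks = [0.9, 0.2, 0.7, 0.4]
times = [2, 5, 3, 6]
events = [1, 1, 0, 0]
c_index = harrell_c_index(risks, times, events)
```

In this toy data there are four comparable pairs and three are ranked correctly, so the result is 0.75. The double loop is quadratic in the number of patients, which is fine for illustration but slow on large cohorts.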
Calibration¶
Measures agreement between predicted and observed event rates.
- Calibration plot: Predicted risk vs observed frequency
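- Brier score: Mean squared error between predicted event probabilities and observed outcomes (lower is better)
- Well-calibrated: Predicted risks match observed event rates (for example, about 20% of patients assigned a 20% risk experience the event)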
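The calibration quantities above can be computed for a single time horizon as follows. This is a simplified sketch under a stated assumption: every patient has complete follow-up to the horizon, so each outcome is a plain 0 or 1 and censoring is ignored. The function names are hypothetical and not part of the library.

```python
def brier_score(predicted, observed):
    """Mean squared difference between predicted risk and the 0/1 outcome."""
    return sum((p - y) ** 2 for p, y in zip(predicted, observed)) / len(predicted)


def calibration_bins(predicted, observed, n_bins=5):
    """Return (mean predicted risk, observed event rate, count) per non-empty bin."""
    bins = [[] for _ in range(n_bins)]
    for p, y in zip(predicted, observed):
        idx = min(int(p * n_bins), n_bins - 1)  # equal-width bins on [0, 1]
        bins[idx].append((p, y))
    rows = []
    for members in bins:
        if members:
            mean_pred = sum(p for p, _ in members) / len(members)
            obs_rate = sum(y for _, y in members) / len(members)
            rows.append((mean_pred, obs_rate, len(members)))
    return rows


# Hypothetical 30-day risks and observed outcomes for four patients
predicted = [0.1, 0.15, 0.85, 0.9]
observed = [0, 0, 1, 1]
score = brier_score(predicted, observed)
rows = calibration_bins(predicted, observed)
```

Plotting the first element of each row against the second gives the calibration plot; points near the diagonal indicate good calibration. With censored data, a censoring-aware estimator would be needed instead.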
Time-Dependent AUC¶
AUC at specific time horizons (e.g., 7, 14, 30 days).
Useful for evaluating early vs late prediction performance.
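One simple way to compute an AUC at a fixed horizon is sketched below. It is a naive version, offered as an assumption rather than the library's method: patients with an observed event by the horizon are cases, patients still under follow-up after the horizon are controls, and patients censored before the horizon are dropped without any reweighting. The name `auc_at_horizon` is hypothetical.

```python
def auc_at_horizon(risk_scores, event_times, event_indicators, horizon):
    """Probability that a case has a higher risk score than a control at `horizon`."""
    cases = [
        r for r, t, d in zip(risk_scores, event_times, event_indicators)
        if d == 1 and t <= horizon
    ]
    # Anyone followed beyond the horizon is event-free at the horizon
    controls = [r for r, t in zip(risk_scores, event_times) if t > horizon]
    if not cases or not controls:
        raise ValueError("need at least one case and one control at this horizon")
    wins = 0.0
    for case_risk in cases:
        for control_risk in controls:
            if case_risk > control_risk:
                wins += 1.0
            elif case_risk == control_risk:
                wins += 0.5
    return wins / (len(cases) * len(controls))


# Hypothetical follow-up times in days
risks = [0.8, 0.5, 0.3, 0.6, 0.1]
times = [5, 10, 20, 40, 50]
events = [1, 0, 1, 1, 0]
auc_7 = auc_at_horizon(risks, times, events, horizon=7)
auc_30 = auc_at_horizon(risks, times, events, horizon=30)
```

Here the 7-day AUC is 1.0 and the 30-day AUC is 0.75, showing how performance can differ between early and late horizons. Dropping early-censored patients can bias the estimate when censoring is heavy.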
Clinical Use Cases¶
1. Hospital Readmission Prediction¶
Goal: Identify patients at high risk of 30-day readmission
Clinical Value:
- Reduce readmission rates (CMS penalty avoidance)
- Target interventions (follow-up calls, home health)
- Optimize discharge planning
Metrics:
- C-index for discrimination
- Calibration at 7, 14, 30 days
- Cost-benefit analysis (prevented readmissions)
Example:
python examples/survival_analysis/train_lstm.py \
    --data-dir data/synthea/ \
    --outcome readmission \
    --time-horizon 30 \
    --epochs 100
2. Mortality Risk Prediction¶
Goal: Estimate survival probability for ICU/ED triage
Clinical Value:
- ICU resource allocation
- Goals-of-care discussions
- Clinical trial enrollment criteria
Time Horizons:
- In-hospital mortality
- 30-day mortality
- 1-year mortality
Example:
python examples/survival_analysis/train_lstm.py \
    --data-dir data/synthea/ \
    --outcome mortality \
    --time-horizon 365 \
    --epochs 100
3. Disease Onset Prediction¶
Goal: Predict time to disease development for preventive care
Clinical Value:
- Identify pre-diabetic patients
- Predict cardiovascular events
- Screen for cancer risk
Target Diseases:
- Diabetes (high prevalence, preventable)
- Heart failure (high cost, manageable)
- Stroke (severe outcome, preventable)
Example:
python examples/survival_analysis/train_lstm.py \
    --data-dir data/synthea/ \
    --outcome disease_onset \
    --target-disease diabetes \
    --time-horizon 365 \
    --epochs 100
Data Requirements¶
Input Format¶
Patient sequences with visit-level structure:
{
    'patient_id': '12345',
    'visits': [
        {
            'visit_date': '2020-01-15',
            'codes': ['LOINC:4548-4', 'SNOMED:44054006'],
            'age': 45.2
        },
        {
            'visit_date': '2020-06-15',
            'codes': ['LOINC:2339-0', 'RXNORM:860975'],
            'age': 45.6
        }
    ],
    'outcome': {
        'event_time': 3,        # Event at visit 3
        'event_indicator': 1,   # 1 = event occurred, 0 = censored
        'event_type': 'readmission'
    }
}
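The `outcome` fields map naturally onto per-visit training targets. The sketch below shows one way to expand them, assuming `event_time` is a 1-based visit index that falls within the patient's recorded visits. The helper name `visit_level_targets` is hypothetical, and the library's own preprocessing may differ.

```python
def visit_level_targets(n_visits, event_time, event_indicator):
    """Expand one outcome into per-visit labels and a loss mask.

    targets[t] is 1 only at the visit where the event occurred.
    mask[t] is 1 for visits up to and including event_time; later visits
    are excluded from the loss because the patient is no longer at risk.
    """
    if not 1 <= event_time <= n_visits:
        raise ValueError("event_time must be between 1 and n_visits")
    targets = [0] * n_visits
    mask = [1 if t < event_time else 0 for t in range(n_visits)]
    if event_indicator == 1:
        targets[event_time - 1] = 1
    return targets, mask


# A patient with four visits and an event at visit 3
event_targets, event_mask = visit_level_targets(4, event_time=3, event_indicator=1)

# The same patient, censored at visit 3 instead
cens_targets, cens_mask = visit_level_targets(4, event_time=3, event_indicator=0)
```

A censored patient produces all-zero targets over the masked visits, which is how censoring enters a per-visit binary loss.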
Synthetic Data Generation¶
For development and testing:
from ehrsequencing.synthetic.survival import DiscreteTimeSurvivalGenerator
generator = DiscreteTimeSurvivalGenerator(
    risk_correlation=-0.5,  # Negative correlation (higher risk → earlier event)
    censoring_rate=0.3,
    time_scale=10
)
# Generate outcomes for patient sequences
outcomes = generator.generate_outcomes(patient_sequences)
Training Pipeline¶
1. Data Preparation¶
from ehrsequencing.data import SyntheaAdapter, VisitGrouper, PatientSequenceBuilder
# Load data
adapter = SyntheaAdapter('data/synthea/')
patients = adapter.load_patients()
events = adapter.load_events()
# Group into visits
grouper = VisitGrouper(strategy='hybrid')
visits = grouper.group_by_patient(events)
# Build sequences
builder = PatientSequenceBuilder()
sequences = builder.build_sequences(visits)
2. Model Training¶
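import torch
from ehrsequencing.models.survival_lstm import DiscreteTimeSurvivalLSTM
from ehrsequencing.models.losses import DiscreteTimeSurvivalLoss
model = DiscreteTimeSurvivalLSTM(vocab_size=5000)
loss_fn = DiscreteTimeSurvivalLoss()
# Adam is an assumption here; the learning rate follows the LSTM suggestion under Hyperparameter Tuning
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Training loop (train_loader and epochs are assumed to be defined elsewhere)
model.train()
for epoch in range(epochs):
    for batch in train_loader:
        optimizer.zero_grad()  # clear gradients left over from the previous step
        hazards = model(batch['codes'], batch['visit_mask'], batch['sequence_mask'])
        loss = loss_fn(hazards, batch['event_time'], batch['event_indicator'])
        loss.backward()
        optimizer.step()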
3. Evaluation¶
from ehrsequencing.models.losses import concordance_index
# Compute C-index
c_index = concordance_index(
    risk_scores=predicted_risks,
    event_times=true_event_times,
    event_indicators=true_event_indicators
)
print(f"C-index: {c_index:.4f}")
Best Practices¶
Model Selection¶
- Start with the LSTM baseline - Fast to train, good baseline performance
- Try BEHRT once it is available - Pre-training may improve performance
- Compare multiple approaches - Use the benchmarking framework
Hyperparameter Tuning¶
- Embedding dimension: 128-256 for LSTM, pre-trained for BEHRT
- Hidden dimension: 256-512
- Dropout: 0.2-0.4 (higher for small datasets)
- Learning rate: 1e-3 for LSTM, 1e-5 for BEHRT fine-tuning
- Early stopping: Patience 10-20 epochs
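Early stopping with a patience window can be sketched as a small helper. This is an illustrative sketch, not part of the library: the class name `EarlyStopping` and the `min_delta` parameter are hypothetical, and the monitored metric is assumed to be one where higher is better, such as validation C-index.

```python
class EarlyStopping:
    """Signal a stop after `patience` consecutive epochs without improvement."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.bad_epochs = 0

    def step(self, metric):
        """Record a validation metric; return True when training should stop."""
        if self.best is None or metric > self.best + self.min_delta:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Hypothetical validation C-index per epoch, with a short patience for the demo
stopper = EarlyStopping(patience=2)
stopped_at = None
for epoch, c_index in enumerate([0.70, 0.72, 0.71, 0.72, 0.73]):
    if stopper.step(c_index):
        stopped_at = epoch
        break
```

With a patience of 2, training stops at epoch index 3 and never sees the later improvement to 0.73, which illustrates why the suggested range of 10-20 epochs is more forgiving of noisy validation curves.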
Avoiding Overfitting¶
- Use dropout and weight decay (L2 regularization)
- Apply early stopping on validation C-index
- Use cross-validation for small datasets
- For BEHRT, limit the number of trainable parameters with LoRA
Clinical Validation¶
- Calibration plots (predicted vs observed)
- Subgroup analysis (age, gender, comorbidities)
- Temporal validation (train on old data, test on new)
- External validation (different hospital/dataset)
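Temporal validation can be set up by splitting patients on a calendar cutoff. The sketch below assumes records in the input format shown under Data Requirements and assigns each patient by the date of their first visit. The helper name `temporal_split` and the choice of first visit as the index date are assumptions.

```python
from datetime import date


def temporal_split(records, cutoff):
    """Patients whose first visit is before `cutoff` go to train, the rest to test."""
    train, test = [], []
    for record in records:
        first_visit = min(
            date.fromisoformat(visit['visit_date']) for visit in record['visits']
        )
        (train if first_visit < cutoff else test).append(record)
    return train, test


# Two hypothetical patients, one on each side of the cutoff
records = [
    {'patient_id': 'A', 'visits': [{'visit_date': '2019-03-01'},
                                   {'visit_date': '2019-09-10'}]},
    {'patient_id': 'B', 'visits': [{'visit_date': '2021-02-20'}]},
]
train, test = temporal_split(records, cutoff=date(2020, 1, 1))
```

Splitting by patient keeps each person entirely in one set. A training patient whose later visits fall after the cutoff would still carry post-cutoff information, so a stricter setup would also truncate training sequences at the cutoff.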
References¶
Papers¶
- Discrete-Time Survival: Tutz & Schmid (2016). "Modeling Discrete Time-to-Event Data"
- C-index: Harrell et al. (1982). "Evaluating the Yield of Medical Tests"
- BEHRT: Li et al. (2020). "BEHRT: Transformer for Electronic Health Records"
Code Examples¶
- examples/survival_analysis/train_lstm.py - LSTM training script
- examples/survival_analysis/train_lstm_demo.py - Quick demo
- notebooks/01_discrete_time_survival_lstm.ipynb - Interactive tutorial
Support¶
For questions or issues:
- Check the tutorials
- Review the example scripts
- See the benchmarking guide for model comparison
Status: LSTM baseline complete | BEHRT survival in development
Updated: February 2026