Tutorial 3: Loss Function Formulation¶
Part of: Discrete-Time Survival Analysis for EHR Sequences
Audience: Researchers implementing survival models with deep learning
Table of Contents¶
- Overview
- Discrete-Time Survival Framework
- Likelihood Formulation
- Implementation Details
- Training Considerations
- Common Issues and Solutions
Overview¶
The Goal¶
Train a neural network to predict the hazard at each time point:

\(h_t = P(T = t \mid T \geq t, H_t)\)

Where:
- \(T\): time of the event
- \(h_t\): probability of an event at time \(t\) given survival to \(t\)
- \(H_t\): patient history up to time \(t\)
Why Not Binary Cross-Entropy?¶
Binary classification approach (WRONG):

loss = BCE(hazards, labels)  # one independent label per visit

Problems:
1. Ignores survival information (the patient survived to this visit)
2. Doesn't handle censoring properly
3. Treats all time points independently
Survival approach (CORRECT):

loss = -log_likelihood(hazards, event_times, event_indicators)

Benefits:
1. Uses survival information from all visits before the event
2. Handles censoring naturally
3. Respects temporal dependencies
Discrete-Time Survival Framework¶
Hazard Function¶
Definition: Probability of an event at time \(t\) given survival to \(t\):

\(h_t = P(T = t \mid T \geq t)\)

Properties:
- \(0 \leq h_t \leq 1\) (it's a probability)
- Can vary arbitrarily over time (no parametric assumptions)
- Predicted by a neural network with a sigmoid output
Example:
# Patient trajectory
h_1 = 0.05 # Low risk at visit 1
h_2 = 0.08 # Slightly higher at visit 2
h_3 = 0.15 # Increasing risk at visit 3
h_4 = 0.40 # High risk at visit 4 (event occurs)
Survival Function¶
Definition: Probability of surviving past time \(t\):

\(S(t) = P(T > t) = \prod_{i=1}^{t} (1 - h_i)\)

Interpretation:
- Survival = not having the event at any time up to \(t\)
- Product of \((1 - h_i)\) for all times up to \(t\)
Example:
S(1) = (1 - h_1) = 0.95
S(2) = (1 - h_1)(1 - h_2) = 0.95 × 0.92 = 0.874
S(3) = (1 - h_1)(1 - h_2)(1 - h_3) = 0.874 × 0.85 = 0.743
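A running product like this is a single `torch.cumprod` call; a quick check of the numbers above (a standalone sketch, not part of the tutorial's code):

```python
import torch

hazards = torch.tensor([0.05, 0.08, 0.15])
# S(t) = prod_{i <= t} (1 - h_i), computed as a cumulative product
survival = torch.cumprod(1 - hazards, dim=0)
print(survival)  # tensor([0.9500, 0.8740, 0.7429])
```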
Probability Mass Function¶
Definition: Probability of the event occurring exactly at time \(t\):

\(P(T = t) = h_t \prod_{i=1}^{t-1} (1 - h_i)\)

Interpretation:
- Survive to \(t - 1\): \(\prod_{i=1}^{t-1} (1 - h_i)\)
- Then have the event at \(t\): \(h_t\)
Example:
# Probability of event at visit 3
P(T = 3) = S(2) × h_3
= (1 - h_1)(1 - h_2) × h_3
= 0.95 × 0.92 × 0.15
= 0.131
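The shift-and-multiply structure vectorizes cleanly; a sketch verifying \(P(T = 3)\) from the example (the tensor names are illustrative):

```python
import torch

hazards = torch.tensor([0.05, 0.08, 0.15, 0.40])
# P(T = t) = h_t * prod_{i < t} (1 - h_i)
# Shift the cumulative survival right by one so index t sees S(t - 1)
prior_survival = torch.cat([torch.ones(1), torch.cumprod(1 - hazards, dim=0)[:-1]])
pmf = hazards * prior_survival
print(pmf[2])      # P(T = 3) in 1-based time, approx 0.1311
print(pmf.sum())   # below 1; the remainder is P(survive past visit 4)
```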
Likelihood Formulation¶
For a Single Patient¶
Case 1: Event Observed (δ = 1)

Patient has the event at time \(T\):

\(L = h_T \prod_{t=1}^{T-1} (1 - h_t)\)

Interpretation:
- Survive through visits \(1, 2, \ldots, T-1\): \(\prod_{t=1}^{T-1} (1 - h_t)\)
- Have the event at visit \(T\): \(h_T\)

Log-likelihood:

\(\log L = \sum_{t=1}^{T-1} \log(1 - h_t) + \log(h_T)\)
Case 2: Censored (δ = 0)

Patient is censored at time \(T\) (no event observed):

\(L = \prod_{t=1}^{T} (1 - h_t)\)

Interpretation:
- Survive through all observed visits: \(\prod_{t=1}^{T} (1 - h_t)\)
- We don't know what happens afterwards

Log-likelihood:

\(\log L = \sum_{t=1}^{T} \log(1 - h_t)\)
Combined Formulation¶
For any patient:

\(\log L = \sum_{t=1}^{T-1} \log(1 - h_t) + \delta \log(h_T) + (1 - \delta) \log(1 - h_T)\)

Where:
- First sum: survival contribution from visits before \(T\) (all patients)
- \(\delta \log(h_T)\): event contribution (only if \(\delta = 1\))
- \((1 - \delta) \log(1 - h_T)\): survival contribution at visit \(T\) (only if censored)

Batch Loss:

\(\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log L_i\)
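Both cases can be folded into one small helper; this `patient_log_likelihood` function is a hypothetical sketch using 1-based times, not the tutorial's loss class:

```python
import math

def patient_log_likelihood(hazards, T, delta):
    """log L for one patient; hazards[t-1] is h_t, T is the event/censoring time (1-based)."""
    # Survival through visits 1..T-1 (contributed by all patients)
    ll = sum(math.log(1 - h) for h in hazards[:T - 1])
    # Visit T: event term if delta = 1, otherwise one more survival term
    ll += math.log(hazards[T - 1]) if delta == 1 else math.log(1 - hazards[T - 1])
    return ll

h = [0.05, 0.08, 0.15, 0.40]
print(patient_log_likelihood(h, 4, 1))  # event at visit 4
print(patient_log_likelihood(h, 4, 0))  # censored at visit 4
```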
Implementation Details¶
PyTorch Implementation¶
import torch
import torch.nn as nn

class DiscreteTimeSurvivalLoss(nn.Module):
    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps  # Numerical stability

    def forward(self, hazards, event_times, event_indicators, sequence_mask):
        """
        Args:
            hazards: [batch_size, max_visits] in (0, 1)
            event_times: [batch_size] - index of the event/censoring visit
            event_indicators: [batch_size] - 1 if event, 0 if censored
            sequence_mask: [batch_size, max_visits] - 1 for valid visits

        Returns:
            Scalar loss (negative log-likelihood)
        """
        batch_size, max_visits = hazards.shape

        # Clamp hazards for numerical stability
        hazards = torch.clamp(hazards, min=self.eps, max=1 - self.eps)

        # Create time index tensor
        time_idx = torch.arange(max_visits, device=hazards.device).unsqueeze(0)
        event_times_expanded = event_times.unsqueeze(1)

        # Mask for visits before the event/censoring visit
        before_event_mask = (time_idx < event_times_expanded).float() * sequence_mask

        # Mask for the event/censoring visit itself
        at_event_mask = (time_idx == event_times_expanded).float() * sequence_mask

        # Survival log-likelihood: sum of log(1 - h_t) for t < T
        survival_ll = torch.sum(torch.log(1 - hazards) * before_event_mask, dim=1)

        # Event log-likelihood: log(h_T) if the event occurred
        event_ll = torch.sum(
            torch.log(hazards) * at_event_mask, dim=1
        ) * event_indicators.float()

        # Censored patients also survive their final observed visit: log(1 - h_T)
        censored_ll = torch.sum(
            torch.log(1 - hazards) * at_event_mask, dim=1
        ) * (1 - event_indicators.float())

        # Total log-likelihood per patient
        log_likelihood = survival_ll + event_ll + censored_ll

        # Return negative log-likelihood (to minimize)
        return -torch.mean(log_likelihood)
Step-by-Step Example¶
Patient data:
hazards = [0.05, 0.08, 0.15, 0.40, 0.0] # Padded to max_visits=5
event_time = 3 # Event at visit 3 (0-indexed)
event_indicator = 1 # Event observed
sequence_mask = [1, 1, 1, 1, 0] # 4 valid visits
Step 1: Create masks
time_idx = [0, 1, 2, 3, 4]
event_time_expanded = 3
before_event_mask = [1, 1, 1, 0, 0] # Visits 0, 1, 2
at_event_mask = [0, 0, 0, 1, 0] # Visit 3
Step 2: Survival log-likelihood
survival_ll = log(1 - 0.05) + log(1 - 0.08) + log(1 - 0.15)
= log(0.95) + log(0.92) + log(0.85)
= -0.051 - 0.083 - 0.163
= -0.297
Step 3: Event log-likelihood
event_ll = log(hazards[3]) × event_indicator
         = log(0.40) × 1
         = -0.916
Step 4: Total log-likelihood
log_likelihood = survival_ll + event_ll
               = -0.297 + (-0.916)
               = -1.213
Step 5: Loss (negative log-likelihood)
loss = -log_likelihood = 1.213
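The arithmetic above can be re-derived in a few lines of plain Python (a standalone check, independent of the loss class):

```python
import math

# Reproduce the worked example end to end
hazards = [0.05, 0.08, 0.15, 0.40]
survival_ll = sum(math.log(1 - h) for h in hazards[:3])  # visits 0, 1, 2
event_ll = math.log(hazards[3])                          # event at visit 3
loss = -(survival_ll + event_ll)
print(loss)  # approx 1.213
```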
Numerical Stability¶
Problem: log(0) is undefined
Solution: Clamp hazards
eps = 1e-7
hazards = torch.clamp(hazards, min=eps, max=1 - eps)
# Now:
# log(1 - hazards) is safe (1 - hazards >= eps)
# log(hazards) is safe (hazards >= eps)
Why this works:
- eps = 1e-7 is tiny (0.0000001)
- Doesn't affect predictions (hazards are typically 0.01-0.9)
- Prevents numerical overflow/underflow
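A minimal demonstration of what the clamp prevents (the example values are illustrative):

```python
import torch

eps = 1e-7
raw = torch.tensor([0.0, 0.5, 1.0])           # extreme sigmoid outputs
print(torch.log(raw))                          # -inf at 0.0: would poison the loss
safe = torch.clamp(raw, min=eps, max=1 - eps)
print(torch.log(safe), torch.log(1 - safe))    # finite everywhere
```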
Training Considerations¶
1. Masking for Variable-Length Sequences¶
Problem: Patients have different numbers of visits
Solution: Use sequence mask
# Only compute loss for valid visits
survival_ll = torch.sum(
torch.log(1 - hazards) * before_event_mask * sequence_mask,
dim=1
)
Example:
# Patient 1: 10 visits
sequence_mask[0] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, ...]
# Patient 2: 5 visits
sequence_mask[1] = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, ...]
# Loss only computed for valid visits (mask = 1)
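In practice `sequence_mask` is typically built from a tensor of per-patient visit counts; a sketch (the `lengths` tensor is an assumed input):

```python
import torch

lengths = torch.tensor([10, 5])                            # visits per patient
max_visits = 12
time_idx = torch.arange(max_visits).unsqueeze(0)           # [1, max_visits]
sequence_mask = (time_idx < lengths.unsqueeze(1)).float()  # [batch, max_visits]
print(sequence_mask)
```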
2. Handling Censoring¶
Censored patient (δ = 0): contributes only survival terms:

\(\log L = \sum_{t=1}^{T} \log(1 - h_t)\)

Observed event (δ = 1): contributes survival terms plus an event term:

\(\log L = \sum_{t=1}^{T-1} \log(1 - h_t) + \log(h_T)\)

Key insight: censored patients still provide information through the survival term!
3. Gradient Flow¶
Survival term gradient: \(\frac{\partial}{\partial h_t} \log(1 - h_t) = -\frac{1}{1 - h_t}\)

- Encourages low hazard before the event
- Stronger gradient when \(h_t\) is high (bad prediction)

Event term gradient: \(\frac{\partial}{\partial h_T} \log(h_T) = \frac{1}{h_T}\)

- Encourages high hazard at the event time
- Stronger gradient when \(h_T\) is low (bad prediction)

Combined effect:
- The model learns to predict low hazard early
- High hazard at the event time
- A smooth transition between them
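These gradient directions can be confirmed with autograd; a two-visit sketch with hypothetical hazard values:

```python
import torch

# One hazard before the event, one at the event
h = torch.tensor([0.1, 0.4], requires_grad=True)
# log-likelihood: survive visit 1, event at visit 2
log_likelihood = torch.log(1 - h[0]) + torch.log(h[1])
log_likelihood.backward()
print(h.grad)  # [-1/(1 - 0.1), 1/0.4]: pushes h[0] down and h[1] up
```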
4. Batch Size Considerations¶
Small batches (16-32):
- Faster iterations
- More noise in the gradient
- Better for small datasets

Large batches (64-128):
- Stable gradients
- Slower iterations
- Better for large datasets

Recommendation: start with 32 and adjust based on dataset size.
Common Issues and Solutions¶
Issue 1: Loss Not Decreasing¶
Symptoms:
- Loss stays constant or increases
- Model predicts the same hazard for all patients

Possible causes:
- Learning rate too high
- Numerical instability
- Weak correlation in the synthetic data
Issue 2: Exploding Gradients¶
Symptoms:
- Loss becomes NaN
- Gradients become very large

Solutions:
- Gradient clipping
- Check for log(0)
- Reduce the learning rate
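Gradient clipping is a one-liner with `torch.nn.utils.clip_grad_norm_`; a sketch of a single training step on a toy model (the model and the artificially inflated loss are stand-ins):

```python
import torch
import torch.nn as nn

# Toy model and an inflated loss to produce large gradients
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(8, 4)
loss = (model(x) ** 2).mean() * 1e6

optimizer.zero_grad()
loss.backward()
# Rescale all gradients so their global L2 norm is at most 1.0
pre_clip_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
print(pre_clip_norm)  # norm before clipping; afterwards the norm is <= 1.0
```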
Issue 3: Overfitting¶
Symptoms:
- Training loss decreases while validation loss increases
- Large gap between train and validation C-index

Solutions:
- Dropout
- L2 regularization
- Early stopping
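Dropout and L2 regularization are both one-line changes in PyTorch; a sketch with assumed hyperparameters (`p=0.3` and `weight_decay=1e-4` are illustrative, not recommendations):

```python
import torch
import torch.nn as nn

# Hypothetical hazard head with dropout
model = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Dropout(p=0.3),   # randomly zeroes 30% of activations during training
    nn.Linear(32, 1),
)
# weight_decay adds an L2 penalty on the parameters
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

out = torch.sigmoid(model(torch.randn(4, 16)))  # hazards in (0, 1)
print(out.shape)
```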
Issue 4: Poor C-index Despite Low Loss¶
Symptoms:
- Loss decreases normally
- C-index stays around 0.5 (random)

Possible causes:
- Length bias in the risk score
- Weak correlation in the synthetic data
- Model not learning from the data
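One way length bias arises is summing hazards over visits when forming a risk score: longer sequences accumulate larger sums regardless of per-visit risk. Averaging over valid visits removes the length dependence. A sketch with two hypothetical patients of equal per-visit risk:

```python
import torch

hazards = torch.tensor([[0.10, 0.10, 0.10, 0.10, 0.10],
                        [0.10, 0.10, 0.00, 0.00, 0.00]])
mask = torch.tensor([[1., 1., 1., 1., 1.],
                     [1., 1., 0., 0., 0.]])
risk_sum = (hazards * mask).sum(dim=1)                     # grows with visit count
risk_mean = (hazards * mask).sum(dim=1) / mask.sum(dim=1)  # length-invariant
print(risk_sum)   # longer patient looks "riskier" despite identical hazards
print(risk_mean)  # identical risk recovered
```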
Comparison with Other Losses¶
Binary Cross-Entropy (WRONG)¶
# Treats each visit independently
loss = BCE(hazards, labels)
# Problems:
# - Doesn't use survival information
# - Doesn't handle censoring
# - Ignores temporal structure
Mean Squared Error (WRONG)¶
# Regression on event time
loss = MSE(predicted_time, actual_time)
# Problems:
# - Doesn't handle censoring
# - Assumes Gaussian errors (wrong for time-to-event)
# - Doesn't model hazard process
Discrete-Time Survival (CORRECT)¶
# Models survival process explicitly
loss = -log_likelihood(hazards, event_times, event_indicators)
# Benefits:
# - Uses all survival information
# - Handles censoring naturally
# - Respects temporal dependencies
# - Theoretically grounded
Key Takeaways¶
- Discrete-time survival loss models the survival process explicitly
- Two components: Survival term (all patients) + Event term (observed events only)
- Handles censoring naturally through likelihood formulation
- Numerical stability requires epsilon clamping
- Masking is essential for variable-length sequences
- C-index evaluation requires careful risk score formulation to avoid length bias
Further Reading¶
Theory¶
- Singer & Willett (2003). Applied Longitudinal Data Analysis - Chapter 10
- Tutz & Schmid (2016). Modeling Discrete Time-to-Event Data
Implementation¶
- PyHealth: pyhealth.models.DeepSurv
- PyCox: pycox.models.LogisticHazard
Related Approaches¶
- DeepSurv (continuous-time with Cox model)
- DeepHit (competing risks)
- DRSA (deep recurrent survival analysis)
Previous Tutorial: Synthetic Data Design
Notebook: 01_discrete_time_survival_lstm.ipynb