Data Shape Transformations: From Raw EHR to LSTM Input¶
This document provides a comprehensive reference for all data shape transformations in the EHR sequence modeling pipeline, from raw Synthea CSV files to LSTM model predictions.
Prediction Task: Binary classification - predicting diabetes diagnosis from patient EHR sequences.
Table of Contents¶
- Overview
- Stage-by-Stage Transformations
- Detailed Shape Specifications
- Memory Considerations
- Common Pitfalls
Overview¶
Pipeline Summary¶
Raw CSV Files
↓ SyntheaAdapter.load_events()
List[MedicalEvent]
↓ VisitGrouper.group_events()
Dict[str, List[Visit]]
↓ PatientSequenceBuilder.build_sequences()
List[PatientSequence]
↓ PatientSequenceBuilder.encode_sequence()
Encoded Sequences (with padding)
↓ Add labels + collate_fn()
Batched PyTorch Tensors
↓ LSTM Model forward()
Predictions
Key Dimensions¶
| Symbol | Meaning | Typical Value |
|---|---|---|
| `N` | Number of patients | 50-10000 |
| `V` | Number of visits per patient | 2-100 |
| `C` | Number of codes per visit | 1-50 |
| `B` | Batch size | 16-128 |
| `E` | Embedding dimension | 128-512 |
| `H` | Hidden dimension | 256-1024 |
| `vocab_size` | Vocabulary size | 500-5000 |
Stage-by-Stage Transformations¶
Stage 1: Raw CSV Files → Medical Events¶
Input: CSV files on disk
- patients.csv
- encounters.csv
- conditions.csv
- observations.csv
- medications.csv
- procedures.csv
Transformation: SyntheaAdapter.load_events(patient_ids)
Output: List[MedicalEvent]
# Data structure
MedicalEvent(
patient_id: str, # UUID
timestamp: datetime, # Timezone-naive
code: str, # Medical code (SNOMED-CT, LOINC, RxNorm, etc.)
code_type: str, # 'diagnosis', 'lab', 'medication', 'procedure'
encounter_id: Optional[str],
metadata: Optional[Dict]
)
# Shape
List with length = total number of events across all patients
Example: 4,372 events for 50 patients
Key Points:
- Events are sorted by timestamp
- All timestamps are timezone-naive for consistency
- Each event represents a single medical code occurrence
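For orientation, a minimal loading sketch is shown below. Only `load_events(patient_ids)` is documented in this section; the import path and constructor argument are assumptions.

```python
# Minimal usage sketch. The import path and the constructor argument
# (a directory of Synthea CSVs) are assumptions; only
# load_events(patient_ids) is documented above.
from ehrsequencing.data import SyntheaAdapter  # hypothetical import path

adapter = SyntheaAdapter("data/synthea_csv")   # assumed constructor
events = adapter.load_events(patient_ids=["patient-456"])

# events is a flat, timestamp-sorted List[MedicalEvent]
print(len(events), events[0].code, events[0].code_type)
```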
Stage 2: Medical Events → Visit Groups¶
Input: List[MedicalEvent]
Transformation: VisitGrouper.group_events(events, patient_id)
Output: Dict[str, List[Visit]]
# Data structure
{
patient_id_1: [Visit_1, Visit_2, ..., Visit_n1],
patient_id_2: [Visit_1, Visit_2, ..., Visit_n2],
...
}
Visit(
visit_id: str,
patient_id: str,
timestamp: datetime,
encounter_id: Optional[str],
codes_by_type: Dict[str, List[str]], # e.g., {'diagnosis': [...], 'lab': [...]}
codes_flat: List[str],
metadata: Optional[Dict]
)
# Shape
Dict with:
- Keys: N patient IDs (strings)
- Values: Lists of Visit objects
- Total visits: sum(len(visits) for visits in dict.values())
Example:
- 50 patients
- 421 total visits
- Average 8.4 visits per patient
Key Points:
- Visits group events that occur within the same encounter or time window
- codes_by_type preserves semantic structure
- codes_flat provides simple list for basic models
- Visit timestamps represent the start of the visit
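The grouping rule can be illustrated with a short sketch (not `VisitGrouper`'s actual implementation): events sharing an `encounter_id` form one visit, and events without one fall back to a time window, assumed here to be one calendar day.

```python
from collections import defaultdict

def group_into_visit_buckets(events):
    """Illustrative sketch of the grouping rule, not VisitGrouper itself.
    Constructing Visit objects from each bucket is omitted."""
    buckets = defaultdict(list)
    for event in sorted(events, key=lambda e: e.timestamp):
        # Same encounter -> same visit; otherwise group by calendar day
        # (the one-day window is an assumption for this sketch).
        key = event.encounter_id if event.encounter_id is not None else event.timestamp.date()
        buckets[key].append(event)
    return list(buckets.values())
```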
Stage 3: Visit Groups → Patient Sequences¶
Input: Dict[str, List[Visit]]
Transformation: PatientSequenceBuilder.build_sequences(patient_visits, min_visits=2)
Output: List[PatientSequence]
# Data structure
PatientSequence(
patient_id: str,
visits: List[Visit],
sequence_length: int,
metadata: Optional[Dict]
)
# Shape
List with length = number of patients with >= min_visits
Example: 42 sequences (filtered from 50 patients)
# Each sequence contains:
- visits: List[Visit] with length = sequence_length
- sequence_length: int (number of visits)
Key Points:
- Filters out patients with insufficient visits
- Visits are chronologically ordered
- Sequences can have variable length
- No padding at this stage
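A sketch of the documented filtering and ordering behavior (the real `build_sequences()` may differ in details):

```python
def build_sequences(patient_visits, min_visits=2):
    # Keep patients with at least min_visits visits and order each
    # patient's visits chronologically. PatientSequence is specified
    # under "Detailed Shape Specifications" below.
    sequences = []
    for patient_id, visits in patient_visits.items():
        if len(visits) < min_visits:
            continue  # filtered out, as in the 50 -> 42 example above
        visits = sorted(visits, key=lambda v: v.timestamp)
        sequences.append(PatientSequence(
            patient_id=patient_id,
            visits=visits,
            sequence_length=len(visits),
            metadata=None,
        ))
    return sequences
```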
Stage 4: Patient Sequences → Encoded Sequences¶
Input: PatientSequence
Transformation: PatientSequenceBuilder.encode_sequence(sequence, return_tensors=False)
Output: Dict[str, Any]
# Data structure
{
'patient_id': str,
'visit_codes': List[List[int]], # [num_visits, max_codes_per_visit]
'visit_mask': List[List[int]], # [num_visits, max_codes_per_visit]
'sequence_mask': List[int], # [num_visits]
'time_deltas': List[float], # [num_visits - 1]
'sequence_length': int
}
# Shape details
visit_codes: [V, C]
- V = max_visits (longer sequences keep their most recent max_visits visits; shorter ones are padded with all-[PAD] visits)
- C = max_codes_per_visit
- Values: integer IDs from vocabulary (0 = [PAD], 1 = [UNK], 5+ = regular codes; 2-4 are reserved for [MASK], [CLS], [SEP])
- Padded with 0s
visit_mask: [V, C]
- Same shape as visit_codes
- Values: 1 for real codes, 0 for padding
- Used to ignore padding in aggregation
sequence_mask: [V]
- Values: 1 for real visits, 0 for padding visits
- Used to handle variable-length sequences in LSTM
time_deltas: [V-1]
- Time between consecutive visits in days
- Padded with 0.0 for padding visits
# Example with max_visits=50, max_codes_per_visit=100
visit_codes: [50, 100] # 5,000 integers
visit_mask: [50, 100] # 5,000 integers
sequence_mask: [50] # 50 integers
time_deltas: [49] # 49 floats
Key Points:
- Codes converted from strings to integer IDs via vocabulary
- Padding applied to standardize dimensions
- Masks track real vs. padded data
- Most recent visits kept if sequence exceeds max_visits
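The following sketch reproduces the documented encoding behavior (truncation to the most recent visits, `[PAD]`/`[UNK]` handling, masks, and delta padding); the library's `encode_sequence()` may differ in details:

```python
def encode_sequence(sequence, vocab, max_visits=50, max_codes=100):
    """Sketch of the documented encoding, not the library's exact code."""
    visits = sequence.visits[-max_visits:]  # keep the most recent visits
    visit_codes, visit_mask = [], []
    for visit in visits:
        ids = [vocab.get(code, 1) for code in visit.codes_flat[:max_codes]]  # 1 = [UNK]
        pad = max_codes - len(ids)
        visit_codes.append(ids + [0] * pad)            # 0 = [PAD]
        visit_mask.append([1] * len(ids) + [0] * pad)
    sequence_mask = [1] * len(visits) + [0] * (max_visits - len(visits))
    deltas = [float((b.timestamp - a.timestamp).days)
              for a, b in zip(visits, visits[1:])]
    deltas += [0.0] * (max_visits - 1 - len(deltas))   # 0.0 for padding visits
    # Pad the visit dimension with all-[PAD] visits
    for _ in range(max_visits - len(visits)):
        visit_codes.append([0] * max_codes)
        visit_mask.append([0] * max_codes)
    return {
        'patient_id': sequence.patient_id,
        'visit_codes': visit_codes,        # [max_visits, max_codes]
        'visit_mask': visit_mask,          # [max_visits, max_codes]
        'sequence_mask': sequence_mask,    # [max_visits]
        'time_deltas': deltas,             # [max_visits - 1]
        'sequence_length': len(visits),
    }
```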
Stage 5: Encoded Sequences → Labeled Dataset¶
Input: List[Dict] (encoded sequences from Stage 4), with the original PatientSequence objects still available for label derivation
Transformation: Add labels based on the prediction task
Output: List[Dict]
# Data structure
[
{
'patient_id': str,
'visit_codes': List[List[int]],
'visit_mask': List[List[int]],
'sequence_mask': List[int],
'time_deltas': List[float],
'label': int # 0 or 1 for binary classification
},
...
]
# Shape
List with length = number of sequences
Each item is a dict with encoded sequence + label
# Label creation (diabetes example)
diabetes_codes = {'44054006', '46635009', '73211009', ...}
label = 1 if any code in diabetes_codes appears in any visit else 0
# Example label distribution
Total sequences: 42
Positive (has diabetes): 8 (19.0%)
Negative (no diabetes): 34 (81.0%)
Key Points:
- Labels derived from medical codes in the sequence
- For diabetes: check if any visit contains a diabetes diagnosis code
- Other tasks: readmission (time-based), mortality (death_date), etc.
- Labels can be binary, multi-class, or continuous
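A sketch of the diabetes labeling rule, reusing `encode_sequence()` from the Stage 4 sketch and a `vocab` as in the Vocabulary section below. The code set is deliberately truncated (the '...' above); treat it as illustrative, not exhaustive.

```python
DIABETES_CODES = {'44054006', '46635009', '73211009'}  # truncated example set

def label_sequence(sequence) -> int:
    # 1 if any visit contains a diabetes diagnosis code, else 0
    return int(any(
        code in DIABETES_CODES
        for visit in sequence.visits
        for code in visit.codes_flat
    ))

labeled_dataset = [
    {**encode_sequence(seq, vocab), 'label': label_sequence(seq)}
    for seq in sequences
]
```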
Stage 6: Labeled Dataset → Batched Tensors¶
Input: List[Dict] (dataset items)
Transformation: collate_fn(batch) in DataLoader
Output: Dict[str, torch.Tensor]
# Data structure
{
'visit_codes': torch.Tensor, # [B, V_max, C_max]
'visit_mask': torch.Tensor, # [B, V_max, C_max]
'sequence_mask': torch.Tensor, # [B, V_max]
'labels': torch.Tensor # [B, 1]
}
# Shape details
visit_codes: [B, V_max, C_max]
- B = batch_size (e.g., 32)
- V_max = max number of visits in batch
- C_max = max number of codes per visit in batch
- dtype: torch.long
- Values: 0 (padding) or vocabulary IDs
visit_mask: [B, V_max, C_max]
- Same shape as visit_codes
- dtype: torch.bool
- Values: True for real codes, False for padding
sequence_mask: [B, V_max]
- dtype: torch.bool
- Values: True for real visits, False for padding
labels: [B, 1]
- dtype: torch.float32
- Values: 0.0 or 1.0 for binary classification
# Example with batch_size=32
visit_codes: [32, 45, 87] # 125,280 values
visit_mask: [32, 45, 87] # 125,280 values
sequence_mask: [32, 45] # 1,440 values
labels: [32, 1] # 32 values
# Memory footprint (batch_size=32)
visit_codes: ~977 KB (int64)
visit_mask: ~122 KB (bool)
sequence_mask: ~1.4 KB (bool)
labels: ~0.1 KB (float32)
Total: ~1.1 MB per batch
Key Points:
- Dynamic padding: V_max and C_max determined by batch contents
- Efficient packing: only pad to max in current batch, not global max
- Masks essential for proper gradient computation
- DataLoader handles batching automatically
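A dynamic-padding `collate_fn` might look like the sketch below. It assumes each item's `visit_codes` is still an unpadded `List[List[int]]` (i.e., the fixed-size padding from Stage 4 was skipped or trimmed), so `V_max` and `C_max` come from the batch itself:

```python
import torch

def collate_fn(batch):
    # Pad each field only to the max sizes present in this batch.
    v_max = max(len(item['visit_codes']) for item in batch)
    c_max = max(len(codes) for item in batch for codes in item['visit_codes'])

    B = len(batch)
    visit_codes = torch.zeros(B, v_max, c_max, dtype=torch.long)
    visit_mask = torch.zeros(B, v_max, c_max, dtype=torch.bool)
    sequence_mask = torch.zeros(B, v_max, dtype=torch.bool)
    labels = torch.tensor([[float(item['label'])] for item in batch])  # [B, 1]

    for i, item in enumerate(batch):
        for v, codes in enumerate(item['visit_codes']):
            visit_codes[i, v, :len(codes)] = torch.tensor(codes)
            visit_mask[i, v, :len(codes)] = True
        sequence_mask[i, :len(item['visit_codes'])] = True

    return {'visit_codes': visit_codes, 'visit_mask': visit_mask,
            'sequence_mask': sequence_mask, 'labels': labels}
```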
Stage 7: Batched Tensors → Model Output¶
Input: Batched tensors from Stage 6
Transformation: model.forward(visit_codes, visit_mask, sequence_mask)
Output: Dict[str, torch.Tensor]
# Model architecture flow
visit_codes [B, V, C]
↓ Embedding layer
code_embeddings [B, V, C, E]
↓ Visit encoder (aggregation)
visit_vectors [B, V, E]
↓ LSTM
lstm_output [B, V, H]
↓ Take final hidden state
final_hidden [B, H]
↓ Linear + activation
predictions [B, 1]
# Output structure
{
'logits': torch.Tensor, # [B, 1]
'predictions': torch.Tensor, # [B, 1]
'hidden_states': torch.Tensor # [B, V, H] (if return_hidden=True)
}
# Shape details
logits: [B, 1]
- dtype: torch.float32
- Raw predictions before activation
- Range: (-∞, +∞)
predictions: [B, 1]
- dtype: torch.float32
- After sigmoid activation
- Range: (0, 1) - interpreted as probabilities
- P(patient has diabetes)
hidden_states: [B, V, H]
- dtype: torch.float32
- LSTM hidden states for each visit
- Can be used for attention, interpretability, etc.
# Example with batch_size=32, hidden_dim=256
logits: [32, 1] # 32 values
predictions: [32, 1] # 32 values (probabilities)
hidden_states: [32, 45, 256] # 368,640 values
# Memory footprint
hidden_states: ~1.4 MB (float32)
Key Points:
- Embedding layer maps code IDs to dense vectors
- Visit encoder aggregates codes within each visit (mean/sum/attention)
- LSTM captures temporal dependencies across visits
- Final prediction comes from the last real hidden state
- Sigmoid activation for binary classification
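A minimal end-to-end sketch of this flow, assuming mean aggregation over codes; the actual model lives in src/ehrsequencing/models/lstm_baseline.py and may differ:

```python
import torch
import torch.nn as nn

class LSTMBaselineSketch(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, 1)

    def forward(self, visit_codes, visit_mask, sequence_mask):
        emb = self.embedding(visit_codes)                       # [B, V, C, E]
        # Masked mean over codes within each visit
        emb = emb * visit_mask.unsqueeze(-1)
        visit_vectors = emb.sum(dim=2) / visit_mask.sum(dim=2, keepdim=True).clamp(min=1)  # [B, V, E]
        lstm_out, _ = self.lstm(visit_vectors)                  # [B, V, H]
        # Index of the last real visit in each sequence
        last = sequence_mask.sum(dim=1).long() - 1              # [B]
        final_hidden = lstm_out[torch.arange(lstm_out.size(0)), last]  # [B, H]
        logits = self.classifier(final_hidden)                  # [B, 1]
        return {'logits': logits, 'predictions': torch.sigmoid(logits)}
```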
Detailed Shape Specifications¶
Vocabulary¶
vocab: Dict[str, int]
- Keys: Medical codes (strings)
- Values: Integer IDs
- Special tokens:
• [PAD]: 0
• [UNK]: 1
• [MASK]: 2
• [CLS]: 3
• [SEP]: 4
- Regular codes: 5, 6, 7, ...
Example:
{
'[PAD]': 0,
'[UNK]': 1,
'[MASK]': 2,
'[CLS]': 3,
'[SEP]': 4,
'44054006': 5, # Type 2 diabetes
'8302-2': 6, # Body height
...
}
Typical size: 500-5000 codes
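Building such a vocabulary is straightforward; a sketch (the library may expose its own builder):

```python
SPECIAL_TOKENS = ['[PAD]', '[UNK]', '[MASK]', '[CLS]', '[SEP]']

def build_vocab(sequences):
    # Assign IDs 0-4 to the special tokens, then number codes in
    # first-seen order (a sketch; the actual builder may differ).
    vocab = {token: idx for idx, token in enumerate(SPECIAL_TOKENS)}
    for sequence in sequences:
        for visit in sequence.visits:
            for code in visit.codes_flat:
                if code not in vocab:
                    vocab[code] = len(vocab)
    return vocab
```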
Visit Object¶
Visit:
- visit_id: str (UUID)
- patient_id: str (UUID)
- timestamp: datetime (timezone-naive)
- encounter_id: Optional[str]
- codes_by_type: Dict[str, List[str]]
• Keys: 'diagnosis', 'lab', 'medication', 'procedure'
• Values: Lists of code strings
- codes_flat: List[str]
• Flattened list of all codes
- metadata: Optional[Dict]
Methods:
- num_codes() -> int
- get_all_codes() -> List[str]
- get_ordered_codes(type_order) -> List[str]
Example:
Visit(
visit_id='abc-123',
patient_id='patient-456',
timestamp=datetime(2024, 5, 5),
codes_by_type={
'diagnosis': ['44054006', '73211009'],
'lab': ['8302-2', '29463-7', '8867-4'],
'medication': ['197361']
},
codes_flat=['44054006', '73211009', '8302-2', '29463-7', '8867-4', '197361']
)
PatientSequence Object¶
PatientSequence:
- patient_id: str
- visits: List[Visit]
- sequence_length: int (len(visits))
- metadata: Optional[Dict]
Methods:
- get_code_sequence(use_semantic_order: bool) -> List[List[str]]
- get_flat_code_sequence() -> List[str]
- get_time_deltas() -> List[float]
Example:
PatientSequence(
patient_id='patient-456',
visits=[Visit_1, Visit_2, ..., Visit_10],
sequence_length=10,
metadata={'age': 45, 'gender': 'M'}
)
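As a sketch, `get_time_deltas()` presumably computes the day gaps that Stage 4 pads into `time_deltas`; the actual implementation may differ:

```python
from typing import List

def get_time_deltas(sequence) -> List[float]:
    # Day gaps between consecutive visits (a sketch of the documented
    # method, not necessarily the library's exact implementation).
    return [
        (later.timestamp - earlier.timestamp).total_seconds() / 86400.0
        for earlier, later in zip(sequence.visits, sequence.visits[1:])
    ]
```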
Memory Considerations¶
Per-Sequence Memory¶
For a single encoded sequence with max_visits=50, max_codes_per_visit=100:
visit_codes: 50 × 100 × 8 bytes (int64) = 40 KB
visit_mask: 50 × 100 × 1 byte (bool) = 5 KB
sequence_mask: 50 × 1 byte (bool) = 50 bytes
time_deltas: 49 × 4 bytes (float32) = 196 bytes
Total: ≈ 45 KB per sequence
Per-Batch Memory¶
For a batch of 32 sequences:
visit_codes: 32 × 50 × 100 × 8 bytes = 1.28 MB
visit_mask: 32 × 50 × 100 × 1 byte = 160 KB
sequence_mask: 32 × 50 × 1 byte = 1.6 KB
labels: 32 × 1 × 4 bytes = 128 bytes
Total: ≈ 1.44 MB per batch
Model Memory¶
For LSTM baseline (small) with vocab_size=1000, embedding_dim=128, hidden_dim=256:
Embedding: 1000 × 128 × 4 bytes = 512 KB
LSTM: ~500K parameters = 2 MB
Linear: 256 × 1 × 4 bytes = 1 KB
Total parameters: ≈ 2.5 MB
Forward pass (batch_size=32):
- Embeddings: 32 × 50 × 100 × 128 × 4 bytes ≈ 82 MB
- Visit vectors: 32 × 50 × 128 × 4 bytes ≈ 819 KB
- LSTM hidden: 32 × 50 × 256 × 4 bytes ≈ 1.64 MB
- Gradients: ~2× forward pass ≈ 170 MB
Total: ≈ 250 MB per batch
Scaling Considerations¶
| Dataset Size | Sequences | Batches (B=32) | Memory | Training Time |
|---|---|---|---|---|
| Small | 100 | 4 | ~6 MB | Minutes |
| Medium | 1,000 | 32 | ~50 MB | Hours |
| Large | 10,000 | 313 | ~450 MB | Days |
| Very Large | 100,000 | 3,125 | ~4.5 GB | Weeks |
Recommendations:
- Use max_visits=50 and max_codes_per_visit=100 for most tasks
- Reduce dimensions if memory-constrained
- Use gradient accumulation for larger effective batch sizes
- Consider mixed precision training (fp16) to reduce memory by 50%
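The mixed-precision recommendation maps onto PyTorch's AMP utilities; a training-loop sketch (`model`, `optimizer`, `loader`, and `criterion`, e.g. `nn.BCEWithLogitsLoss`, are assumed to be defined as in the stages above):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for batch in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():      # run the forward pass in reduced precision
        out = model(batch['visit_codes'], batch['visit_mask'], batch['sequence_mask'])
        loss = criterion(out['logits'], batch['labels'])
    scaler.scale(loss).backward()        # scale loss to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()
```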
Common Pitfalls¶
1. Forgetting Masks¶
Problem: Not using masks leads to incorrect aggregation and gradient computation.
# ❌ Wrong - includes padding codes in the mean over codes
visit_vector = code_embeddings.mean(dim=2)  # [B, V, C, E] -> [B, V, E]
# ✅ Correct - masks out padding before averaging over the code dimension (dim=2)
masked_embeddings = code_embeddings * visit_mask.unsqueeze(-1)  # [B, V, C, E]
visit_vector = masked_embeddings.sum(dim=2) / visit_mask.sum(dim=2, keepdim=True).clamp(min=1)
2. Incorrect Padding Direction¶
Problem: Padding at the beginning instead of the end.
# ❌ Wrong - pads at start (shifts temporal order)
padded = [PAD, PAD, code1, code2, code3]
# ✅ Correct - pads at end (preserves temporal order)
padded = [code1, code2, code3, PAD, PAD]
3. Timezone Issues¶
Problem: Mixing timezone-aware and timezone-naive timestamps.
# ❌ Wrong - causes comparison errors
timestamp1 = pd.to_datetime('2024-01-01') # naive
timestamp2 = pd.to_datetime('2024-01-01').tz_localize('UTC') # aware
# ✅ Correct - all timestamps naive
timestamp = pd.to_datetime(row['DATE']).tz_localize(None)
4. Dictionary vs. Dataclass Access¶
Problem: Treating dataclasses as dictionaries.
# ❌ Wrong - PatientSequence is a dataclass
patient_id = sequence['patient_id']
# ✅ Correct - use attribute access
patient_id = sequence.patient_id
5. Batch Dimension Confusion¶
Problem: Forgetting batch dimension in reshaping.
# ❌ Wrong - loses batch structure
embeddings = embeddings.view(-1, embedding_dim)
# ✅ Correct - preserves batch
embeddings = embeddings.view(batch_size, num_visits, -1, embedding_dim)
6. Variable Length Handling¶
Problem: Not using packed sequences for efficiency.
# ❌ Inefficient - the LSTM processes padded visits too
lstm_output, _ = lstm(visit_vectors)
# ✅ Efficient - pack real visits only, then re-pad the output
lengths = sequence_mask.sum(dim=1).cpu()  # real visit count per sequence (must be on CPU)
packed = nn.utils.rnn.pack_padded_sequence(visit_vectors, lengths, batch_first=True, enforce_sorted=False)
packed_output, _ = lstm(packed)
lstm_output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
Summary¶
This document provides a complete reference for data shapes throughout the EHR sequence modeling pipeline. Key points:
- Consistent shapes: All transformations maintain clear input/output contracts
- Proper masking: Essential for handling variable-length sequences
- Memory efficiency: Dynamic padding and packed sequences reduce waste
- Type safety: Clear distinction between lists, dicts, dataclasses, and tensors
- Scalability: Pipeline handles datasets from 100 to 100,000+ patients
For implementation details, see:
- 01_synthea_data_exploration.ipynb - Data loading and exploration
- 01a_lstm_data_preparation.ipynb - LSTM input preparation
- examples/train_lstm_baseline.py - Full training pipeline
- src/ehrsequencing/models/lstm_baseline.py - Model architecture