Resource-Aware Model Configurations¶
Date: January 20, 2026
Focus: Small/Medium/Large model presets for M1 MacBook (16GB) and RunPod deployment
Overview¶
This document defines resource-aware model configurations for the EHR Sequencing project, optimized for different hardware constraints:
- SMALL: M1 MacBook Pro 16GB (local development, fast iteration)
- MEDIUM: RunPod with 24GB GPU (A10, RTX 4090)
- LARGE: Cloud instances with 40GB+ GPU (A40, A100)
Key Principle: All models are logically equivalent but scaled for available resources. This enables:
- Fast iteration on M1 MacBook
- Realistic training on RunPod
- Production-scale deployment on cloud
Part 1: LSTM Baseline for Visit-Grouped Sequences¶
Should We Use LSTM as a Baseline?¶
Answer: YES - LSTM is an excellent baseline for visit-grouped sequences.
Rationale¶
1. Natural Fit for Visit Sequences
# Visit sequence is naturally sequential
patient_trajectory = [visit1, visit2, visit3, ..., visitN]
# LSTM processes sequences step-by-step
hidden_state = lstm(visit_embeddings)
2. Computational Efficiency
- Much faster than Transformers for long sequences
- Lower memory footprint
- Suitable for M1 MacBook development
3. Strong Baseline Performance
- LSTMs have proven effective for EHR sequences
- Captures temporal dependencies
- Simpler architecture, easier to debug
4. Comparison Point
- Establishes a baseline performance level
- Compare against Transformer-based models
- Validates whether the added complexity of attention is justified
LSTM vs Transformer for Visit Sequences¶
| Aspect | LSTM | Transformer |
|---|---|---|
| Sequence Length | Efficient for long sequences (50+ visits) | Quadratic complexity O(n²) |
| Memory | O(n) - Linear | O(n²) - Quadratic |
| Training Speed | Fast | Slower |
| Long-range Deps | Limited (vanishing gradients) | Excellent (attention) |
| Interpretability | Moderate (hidden states) | High (attention weights) |
| M1 MacBook | ✅ Runs well | ⚠️ Slower, more memory |
| Best For | Baseline, fast iteration | Production, best performance |
Recommendation: Implement both the LSTM and the Transformer, and use the LSTM as the baseline.
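A quick back-of-envelope check of the memory row (illustrative counts only, not a benchmark; the sizes roughly match the medium presets):
batch, heads, hidden = 32, 8, 512
for n in (50, 200):  # number of visits
    lstm_floats = batch * n * hidden       # O(n): per-layer hidden states
    attn_floats = batch * heads * n * n    # O(n^2): per-layer attention matrix
    print(f"n={n}: LSTM {lstm_floats:,} vs attention {attn_floats:,} floats")
# n=50:  LSTM 819,200 vs attention 640,000 floats   (comparable)
# n=200: LSTM 3,276,800 vs attention 10,240,000 floats (attention dominates)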
Part 2: LSTM Baseline Architecture¶
Two-Level LSTM for Visit-Grouped Sequences¶
import torch
import torch.nn as nn
class LSTMVisitEncoder(nn.Module):
"""
Two-level LSTM for visit-grouped sequences.
Level 1: Encode codes within each visit
Level 2: Model sequence of visits
"""
def __init__(
self,
vocab_size: int,
code_embed_dim: int = 128,
visit_embed_dim: int = 256,
hidden_dim: int = 512,
num_layers: int = 2,
dropout: float = 0.1,
bidirectional: bool = False
):
super().__init__()
# Code embeddings (can use pre-trained)
self.code_embeddings = nn.Embedding(vocab_size, code_embed_dim, padding_idx=0)
# Level 1: Within-visit LSTM
self.visit_lstm = nn.LSTM(
input_size=code_embed_dim,
hidden_size=code_embed_dim,
num_layers=1,
batch_first=True,
dropout=0,
bidirectional=False
)
# Project to visit embedding
self.visit_projection = nn.Linear(code_embed_dim, visit_embed_dim)
# Level 2: Visit sequence LSTM
self.sequence_lstm = nn.LSTM(
input_size=visit_embed_dim + 2, # +2 for time features
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0,
bidirectional=bidirectional
)
self.dropout = nn.Dropout(dropout)
self.hidden_dim = hidden_dim * (2 if bidirectional else 1)
def encode_visit(self, visit_codes, visit_mask):
"""
Encode a single visit.
Args:
visit_codes: [batch, max_codes_per_visit]
visit_mask: [batch, max_codes_per_visit]
Returns:
visit_embedding: [batch, visit_embed_dim]
"""
# Embed codes
code_embeds = self.code_embeddings(visit_codes) # [batch, codes, embed_dim]
# LSTM over codes
lstm_out, _ = self.visit_lstm(code_embeds)
# Masked mean pooling so padded code slots do not skew the visit summary
mask = visit_mask.unsqueeze(-1).float() # [batch, codes, 1]
visit_repr = (lstm_out * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
# Project to visit space
return self.visit_projection(visit_repr)
def forward(self, patient_visits, time_features, visit_mask):
"""
Encode patient visit sequence.
Args:
patient_visits: [batch, num_visits, max_codes_per_visit]
time_features: [batch, num_visits, 2]
visit_mask: [batch, num_visits, max_codes_per_visit]
Returns:
sequence_output: [batch, num_visits, hidden_dim]
final_hidden: [num_layers, batch, hidden_dim]
"""
batch_size, num_visits, max_codes = patient_visits.shape
# Encode each visit (Python loop kept for clarity; could be vectorized by flattening [batch, visits] into one batch dimension)
visit_embeds = []
for i in range(num_visits):
visit_embed = self.encode_visit(
patient_visits[:, i, :],
visit_mask[:, i, :]
)
visit_embeds.append(visit_embed)
visit_embeds = torch.stack(visit_embeds, dim=1) # [batch, visits, visit_dim]
# Concatenate time features
visit_embeds_with_time = torch.cat([visit_embeds, time_features], dim=-1)
# LSTM over visit sequence
sequence_output, (final_hidden, final_cell) = self.sequence_lstm(
visit_embeds_with_time
)
return sequence_output, final_hidden
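A minimal smoke test of the encoder on synthetic tensors, continuing from the class above (all values are random; shapes follow the docstrings):
encoder = LSTMVisitEncoder(vocab_size=10000)  # defaults: hidden_dim=512, num_layers=2
patient_visits = torch.randint(0, 10000, (4, 10, 20))  # [batch, visits, codes]
time_features = torch.rand(4, 10, 2)                   # e.g. days-since-last-visit, age
visit_mask = (patient_visits != 0).long()              # 0 is the padding index
seq_out, final_hidden = encoder(patient_visits, time_features, visit_mask)
print(seq_out.shape)       # torch.Size([4, 10, 512])
print(final_hidden.shape)  # torch.Size([2, 4, 512])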
Disease Progression Model with LSTM¶
class LSTMProgressionModel(nn.Module):
"""
LSTM-based disease progression model.
Predicts disease stage and time to progression.
"""
def __init__(
self,
vocab_size: int,
code_embed_dim: int = 128,
visit_embed_dim: int = 256,
hidden_dim: int = 512,
num_stages: int = 5,
num_layers: int = 2,
dropout: float = 0.1
):
super().__init__()
# Visit encoder
self.encoder = LSTMVisitEncoder(
vocab_size=vocab_size,
code_embed_dim=code_embed_dim,
visit_embed_dim=visit_embed_dim,
hidden_dim=hidden_dim,
num_layers=num_layers,
dropout=dropout,
bidirectional=False
)
# Prediction heads
self.stage_classifier = nn.Linear(hidden_dim, num_stages)
self.time_to_progression = nn.Linear(hidden_dim, 1)
self.dropout = nn.Dropout(dropout)
def forward(self, patient_visits, time_features, visit_mask):
"""
Predict disease progression.
Returns:
stage_logits: [batch, num_visits, num_stages]
time_pred: [batch, num_visits, 1]
"""
# Encode sequence
sequence_output, _ = self.encoder(patient_visits, time_features, visit_mask)
sequence_output = self.dropout(sequence_output)
# Predict at each visit
stage_logits = self.stage_classifier(sequence_output)
time_pred = torch.relu(self.time_to_progression(sequence_output))
return stage_logits, time_pred
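A sketch of the corresponding training objective, assuming per-visit stage labels and observed progression times are available (names such as stage_labels, time_targets, and visit_level_mask are illustrative, not part of the codebase):
import torch.nn.functional as F

def progression_loss(stage_logits, time_pred, stage_labels, time_targets,
                     visit_level_mask, time_weight=0.1):
    """Masked cross-entropy on stages plus masked MSE on time-to-progression.

    stage_logits: [batch, visits, num_stages]   time_pred: [batch, visits, 1]
    stage_labels: [batch, visits] (long)        time_targets: [batch, visits]
    visit_level_mask: [batch, visits], 1 for real visits, 0 for padding
    """
    mask = visit_level_mask.float()
    denom = mask.sum().clamp(min=1.0)
    ce = F.cross_entropy(stage_logits.flatten(0, 1), stage_labels.flatten(),
                         reduction="none")
    ce = (ce * mask.flatten()).sum() / denom
    mse = (time_pred.squeeze(-1) - time_targets) ** 2
    mse = (mse * mask).sum() / denom
    return ce + time_weight * mse  # time_weight is a tunable assumption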
Part 3: Resource-Aware Model Configurations¶
Configuration System¶
# src/ehrsequencing/models/configs/model_configs.py
from dataclasses import dataclass
from typing import Optional
@dataclass
class EHRModelConfig:
"""Configuration for EHR sequence models."""
# Model architecture
code_embed_dim: int
visit_embed_dim: int
hidden_dim: int
num_layers: int
dropout: float = 0.1
# Model type
model_type: str = "lstm" # "lstm" or "transformer"
# Transformer-specific
num_heads: Optional[int] = None
use_flash_attention: bool = False
# Vocabulary
vocab_size: int = 10000
max_visits: int = 50
max_codes_per_visit: int = 100
# Training
batch_size: int = 32
gradient_accumulation_steps: int = 1
mixed_precision: bool = True
use_checkpoint: bool = False
# Task-specific
num_stages: int = 5 # For disease progression
@property
def effective_batch_size(self) -> int:
return self.batch_size * self.gradient_accumulation_steps
@property
def total_params_millions(self) -> float:
"""Estimate total parameters in millions."""
if self.model_type == "lstm":
# LSTM parameters, per layer: 4 gates x (input-to-hidden + hidden-to-hidden weights)
# = 4 * (input_size * hidden_size + hidden_size^2); biases are ignored
code_lstm_params = 4 * (self.code_embed_dim * self.code_embed_dim +
self.code_embed_dim ** 2)
visit_lstm_params = 4 * (self.visit_embed_dim * self.hidden_dim +
self.hidden_dim ** 2) * self.num_layers
# Embeddings
embed_params = self.vocab_size * self.code_embed_dim
# Projection layers
proj_params = self.code_embed_dim * self.visit_embed_dim
# Prediction heads
head_params = self.hidden_dim * (self.num_stages + 1)
total = (code_lstm_params + visit_lstm_params + embed_params +
proj_params + head_params) / 1e6
elif self.model_type == "transformer":
# Transformer parameters (rough estimate)
attn_params = 4 * self.visit_embed_dim * self.visit_embed_dim * self.num_layers
ffn_params = 2 * self.visit_embed_dim * self.hidden_dim * self.num_layers
embed_params = self.vocab_size * self.code_embed_dim
total = (attn_params + ffn_params + embed_params) / 1e6
else:
raise ValueError(f"Unknown model_type: {self.model_type}")
return round(total, 2)
def memory_estimate_gb(self, dtype_bytes: int = 4) -> float:
"""Estimate memory usage in GB."""
# Model parameters
model_memory = self.total_params_millions * 1e6 * dtype_bytes / 1e9
# Optimizer states (Adam: 2x)
optimizer_memory = model_memory * 2
# Activations (LSTM is more memory efficient than Transformer)
if self.model_type == "lstm":
activation_memory = (
self.batch_size * self.max_visits * self.hidden_dim *
self.num_layers * dtype_bytes / 1e9
)
else: # transformer
activation_memory = (
self.batch_size * self.max_visits * self.max_visits *
self.num_heads * dtype_bytes / 1e9
)
# Gradients
gradient_memory = model_memory
total = model_memory + optimizer_memory + activation_memory + gradient_memory
total *= 1.2 # 20% overhead
return round(total, 2)
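The estimators can be sanity-checked directly (values are rough by construction and ignore framework and data-loading overhead):
cfg = EHRModelConfig(code_embed_dim=128, visit_embed_dim=256, hidden_dim=256, num_layers=2)
print(cfg.total_params_millions)              # 2.49 for these sizes
print(cfg.memory_estimate_gb(dtype_bytes=2))  # fp16 training estimate
print(cfg.effective_batch_size)               # 32 (batch_size 32 x 1 accumulation step)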
Preset Configurations¶
# src/ehrsequencing/models/configs/presets.py
from .model_configs import EHRModelConfig
# ============================================================================
# SMALL: M1 MacBook Pro 16GB
# ============================================================================
SMALL_LSTM_CONFIG = EHRModelConfig(
# Architecture
code_embed_dim=128,
visit_embed_dim=256,
hidden_dim=256,
num_layers=2,
dropout=0.1,
# Model type
model_type="lstm",
# Vocabulary
vocab_size=10000,
max_visits=50,
max_codes_per_visit=50,
# Training (optimized for M1 16GB)
batch_size=4,
gradient_accumulation_steps=8, # Effective batch = 32
mixed_precision=True,
use_checkpoint=True, # Gradient checkpointing
# Task
num_stages=5,
)
SMALL_TRANSFORMER_CONFIG = EHRModelConfig(
# Architecture
code_embed_dim=128,
visit_embed_dim=256,
hidden_dim=512,
num_layers=4,
num_heads=4,
dropout=0.1,
# Model type
model_type="transformer",
use_flash_attention=False, # Not available on M1
# Vocabulary
vocab_size=10000,
max_visits=30, # Shorter for memory
max_codes_per_visit=50,
# Training
batch_size=2,
gradient_accumulation_steps=16, # Effective batch = 32
mixed_precision=True,
use_checkpoint=True,
# Task
num_stages=5,
)
# ============================================================================
# MEDIUM: RunPod 24GB GPU (A10, RTX 4090)
# ============================================================================
MEDIUM_LSTM_CONFIG = EHRModelConfig(
# Architecture
code_embed_dim=256,
visit_embed_dim=512,
hidden_dim=512,
num_layers=3,
dropout=0.1,
# Model type
model_type="lstm",
# Vocabulary
vocab_size=20000,
max_visits=100,
max_codes_per_visit=100,
# Training (optimized for 24GB GPU)
batch_size=32,
gradient_accumulation_steps=1,
mixed_precision=True,
use_checkpoint=False,
# Task
num_stages=5,
)
MEDIUM_TRANSFORMER_CONFIG = EHRModelConfig(
# Architecture
code_embed_dim=256,
visit_embed_dim=512,
hidden_dim=1024,
num_layers=6,
num_heads=8,
dropout=0.1,
# Model type
model_type="transformer",
use_flash_attention=True,
# Vocabulary
vocab_size=20000,
max_visits=50,
max_codes_per_visit=100,
# Training
batch_size=16,
gradient_accumulation_steps=2, # Effective batch = 32
mixed_precision=True,
use_checkpoint=False,
# Task
num_stages=5,
)
# ============================================================================
# LARGE: Cloud 40GB+ GPU (A40, A100)
# ============================================================================
LARGE_LSTM_CONFIG = EHRModelConfig(
# Architecture
code_embed_dim=512,
visit_embed_dim=768,
hidden_dim=1024,
num_layers=4,
dropout=0.1,
# Model type
model_type="lstm",
# Vocabulary
vocab_size=50000,
max_visits=200,
max_codes_per_visit=150,
# Training (optimized for 40GB+ GPU)
batch_size=64,
gradient_accumulation_steps=1,
mixed_precision=True,
use_checkpoint=False,
# Task
num_stages=5,
)
LARGE_TRANSFORMER_CONFIG = EHRModelConfig(
# Architecture
code_embed_dim=512,
visit_embed_dim=768,
hidden_dim=2048,
num_layers=12,
num_heads=12,
dropout=0.1,
# Model type
model_type="transformer",
use_flash_attention=True,
# Vocabulary
vocab_size=50000,
max_visits=100,
max_codes_per_visit=150,
# Training
batch_size=32,
gradient_accumulation_steps=2, # Effective batch = 64
mixed_precision=True,
use_checkpoint=False,
# Task
num_stages=5,
)
def get_model_config(size: str = "small", model_type: str = "lstm") -> EHRModelConfig:
"""
Get a preset model configuration.
Args:
size: "small", "medium", or "large"
model_type: "lstm" or "transformer"
Returns:
EHRModelConfig instance
Examples:
>>> # For M1 MacBook development
>>> config = get_model_config("small", "lstm")
>>> print(f"Memory: {config.memory_estimate_gb(dtype_bytes=2)}GB")
>>> # For RunPod training
>>> config = get_model_config("medium", "transformer")
"""
configs = {
("small", "lstm"): SMALL_LSTM_CONFIG,
("small", "transformer"): SMALL_TRANSFORMER_CONFIG,
("medium", "lstm"): MEDIUM_LSTM_CONFIG,
("medium", "transformer"): MEDIUM_TRANSFORMER_CONFIG,
("large", "lstm"): LARGE_LSTM_CONFIG,
("large", "transformer"): LARGE_TRANSFORMER_CONFIG,
}
key = (size, model_type)
if key not in configs:
raise ValueError(
f"Unknown config: size='{size}', model_type='{model_type}'. "
f"Valid sizes: small, medium, large. Valid types: lstm, transformer."
)
return configs[key]
Part 4: Configuration Comparison¶
Memory and Performance Estimates¶
| Config | Model Type | Params (M) | Memory (GB) | Batch | Max Visits | Hardware |
|---|---|---|---|---|---|---|
| Small LSTM | LSTM | ~5M | ~3GB | 4 (32) | 50 | M1 16GB |
| Small Transformer | Transformer | ~8M | ~5GB | 2 (32) | 30 | M1 16GB |
| Medium LSTM | LSTM | ~15M | ~8GB | 32 | 100 | RunPod 24GB |
| Medium Transformer | Transformer | ~25M | ~12GB | 16 (32) | 50 | RunPod 24GB |
| Large LSTM | LSTM | ~40M | ~18GB | 64 | 200 | A40/A100 40GB+ |
| Large Transformer | Transformer | ~80M | ~32GB | 32 (64) | 100 | A40/A100 40GB+ |
Notes:
- Memory estimates assume fp16 mixed precision
- Batch shows the actual per-step batch size (effective batch size in parentheses)
- Small configs use gradient checkpointing
- LSTM is roughly 2-3x more memory efficient than the Transformer at these settings
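The table above can be regenerated from the presets, which keeps the documentation and code in sync. A small helper (the printed estimates come from the rough formulas in EHRModelConfig, so expect them to differ somewhat from the hand-rounded table values):
from ehrsequencing.models.configs import get_model_config

for size in ("small", "medium", "large"):
    for model_type in ("lstm", "transformer"):
        cfg = get_model_config(size, model_type)
        print(
            f"{size:>6} {model_type:>11}: ~{cfg.total_params_millions}M params, "
            f"~{cfg.memory_estimate_gb(dtype_bytes=2)}GB fp16, "
            f"batch {cfg.batch_size} (effective {cfg.effective_batch_size})"
        )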
Part 5: Usage Examples¶
Training Script with Config System¶
# scripts/train_disease_progression.py
import argparse
import torch
from ehrsequencing.models.configs import get_model_config
from ehrsequencing.models import LSTMProgressionModel, TransformerProgressionModel
from ehrsequencing.data import load_sequences
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--size", default="small", choices=["small", "medium", "large"])
parser.add_argument("--model-type", default="lstm", choices=["lstm", "transformer"])
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--data-path", required=True)
args = parser.parse_args()
# Get configuration
config = get_model_config(args.size, args.model_type)
print(f"\n{'='*80}")
print(f"Training {args.model_type.upper()} model ({args.size} config)")
print(f"{'='*80}")
print(f"Parameters: ~{config.total_params_millions}M")
print(f"Memory estimate: ~{config.memory_estimate_gb(dtype_bytes=2)}GB (fp16)")
print(f"Effective batch size: {config.effective_batch_size}")
print(f"Max visits: {config.max_visits}")
print(f"{'='*80}\n")
# Load data
sequences = load_sequences(args.data_path)
# Create model
if args.model_type == "lstm":
model = LSTMProgressionModel(
vocab_size=config.vocab_size,
code_embed_dim=config.code_embed_dim,
visit_embed_dim=config.visit_embed_dim,
hidden_dim=config.hidden_dim,
num_stages=config.num_stages,
num_layers=config.num_layers,
dropout=config.dropout
)
else:
model = TransformerProgressionModel(
vocab_size=config.vocab_size,
embed_dim=config.visit_embed_dim,
hidden_dim=config.hidden_dim,
num_stages=config.num_stages,
num_layers=config.num_layers,
num_heads=config.num_heads,
dropout=config.dropout
)
# Training loop
# Prefer CUDA, then Apple Silicon (MPS), then CPU
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
model = model.to(device)
# ... training code ...
if __name__ == "__main__":
main()
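The elided training loop should honor config.mixed_precision and config.gradient_accumulation_steps. A minimal CUDA-oriented sketch (loss_fn, train_loader, and the batch layout are assumptions; on M1/MPS the GradScaler disables itself with a warning):
def train_epoch(model, train_loader, optimizer, loss_fn, config, device):
    scaler = torch.cuda.amp.GradScaler(enabled=config.mixed_precision)
    optimizer.zero_grad()
    for step, batch in enumerate(train_loader):
        # Assumed batch layout: (visits, time_features, mask, targets)
        visits, times, mask, targets = (t.to(device) for t in batch)
        with torch.cuda.amp.autocast(enabled=config.mixed_precision):
            stage_logits, time_pred = model(visits, times, mask)
            loss = loss_fn(stage_logits, time_pred, targets)
        # Divide so accumulated gradients average over the accumulation window
        scaler.scale(loss / config.gradient_accumulation_steps).backward()
        if (step + 1) % config.gradient_accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()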
Local Development Workflow¶
# On M1 MacBook: Fast iteration with small LSTM
python scripts/train_disease_progression.py \
--size small \
--model-type lstm \
--epochs 10 \
--data-path data/ckd_sequences.pt
# Verify logic works, then scale up on RunPod
RunPod Training Workflow¶
# On RunPod A10 24GB: Medium transformer
python scripts/train_disease_progression.py \
--size medium \
--model-type transformer \
--epochs 100 \
--data-path /workspace/ehr-sequencing/data/ckd_sequences.pt
Part 6: Pre-trained Model Integration¶
Using Pre-trained Embeddings with Resource Configs¶
from ehrsequencing.embeddings import CEHRBERTWrapper
from ehrsequencing.models.configs import get_model_config
# Load pre-trained embeddings (assumed to come back as an nn.Embedding aligned with the project's code vocabulary)
cehrbert = CEHRBERTWrapper.from_pretrained('cehrbert-base')
pretrained_embeddings = cehrbert.get_code_embeddings()
# Get config
config = get_model_config("small", "lstm")
# Create model with pre-trained embeddings
model = LSTMProgressionModel(
vocab_size=config.vocab_size,
code_embed_dim=config.code_embed_dim,
# ... other params
)
# Replace embeddings
model.encoder.code_embeddings = pretrained_embeddings
# Optionally freeze embeddings
model.encoder.code_embeddings.weight.requires_grad = False
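If the pre-trained width does not match config.code_embed_dim, one option is a small trainable adapter instead of retraining the embeddings (a sketch; it assumes get_code_embeddings() returns an nn.Embedding):
import torch.nn as nn

class AdaptedEmbedding(nn.Module):
    """Frozen pre-trained embedding followed by a trainable linear projection."""
    def __init__(self, pretrained: nn.Embedding, out_dim: int):
        super().__init__()
        self.embed = pretrained
        self.embed.weight.requires_grad = False
        self.proj = nn.Linear(pretrained.embedding_dim, out_dim)

    def forward(self, codes):
        return self.proj(self.embed(codes))

# e.g. adapt 768-dim BERT-style vectors down to the small config's 128 dims
model.encoder.code_embeddings = AdaptedEmbedding(pretrained_embeddings, config.code_embed_dim)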
Part 7: Recommendations¶
Implementation Priority¶
Phase 1: LSTM Baseline (Week 1-2)
1. Implement LSTMVisitEncoder (small config)
2. Implement LSTMProgressionModel
3. Train on M1 MacBook with small dataset
4. Validate logic and pipeline
Phase 2: Scale to RunPod (Week 3)
1. Use medium LSTM config
2. Train on full dataset
3. Establish baseline performance
Phase 3: Transformer Comparison (Week 4+)
1. Implement Transformer variant
2. Compare LSTM vs Transformer
3. Decide on production model
Why Start with LSTM¶
- Fast Development - Works well on M1 MacBook
- Strong Baseline - Proven for EHR sequences
- Memory Efficient - 2-3x less memory than Transformer
- Interpretable - Easier to debug
- Comparison Point - Validate if Transformer complexity is needed
When to Use Transformer¶
- After LSTM baseline is established
- When training on RunPod (24GB+)
- When long-range dependencies are critical
- For production deployment (best performance)
Part 8: File Structure¶
src/ehrsequencing/models/
├── configs/
│ ├── __init__.py
│ ├── model_configs.py # EHRModelConfig dataclass
│ └── presets.py # SMALL/MEDIUM/LARGE configs
├── lstm/
│ ├── __init__.py
│ ├── visit_encoder.py # LSTMVisitEncoder
│ └── progression_model.py # LSTMProgressionModel
├── transformer/
│ ├── __init__.py
│ ├── visit_encoder.py # TransformerVisitEncoder
│ └── progression_model.py # TransformerProgressionModel
└── __init__.py
scripts/
├── train_disease_progression.py # Main training script
└── compare_models.py # LSTM vs Transformer comparison
Summary¶
Key Decisions¶
- ✅ LSTM as baseline - Excellent fit for visit-grouped sequences
- ✅ Resource-aware configs - Small/Medium/Large for different hardware
- ✅ M1 MacBook support - Small configs for local development
- ✅ RunPod scaling - Medium configs for realistic training
- ✅ Logically equivalent - Same architecture, different scales
Expected Performance¶
LSTM Baseline:
- M1 MacBook: ~10 min/epoch (small dataset)
- RunPod 24GB: ~2 min/epoch (full dataset)
- A40 40GB: ~1 min/epoch (full dataset)
Memory Usage:
- Small LSTM: ~3GB (M1 safe)
- Medium LSTM: ~8GB (RunPod safe)
- Large LSTM: ~18GB (A40 safe)
Next Steps¶
- Implement model config system
- Implement LSTM baseline
- Test on M1 MacBook (small config)
- Scale to RunPod (medium config)
- Compare with Transformer (optional)
Document Version: 1.0
Last Updated: January 20, 2026
Next Review: After LSTM baseline implementation