[Contribution] JambaEHR: Hybrid Transformer-Mamba model for EHR prediction#848

Merged
jhnwu3 merged 1 commit into sunlabuiuc:master from joshuasteier:feature/jamba-ehr on Feb 13, 2026

[Contribution] JambaEHR: Hybrid Transformer-Mamba model for EHR prediction#848
jhnwu3 merged 1 commit intosunlabuiuc:masterfrom
joshuasteier:feature/jamba-ehr

Conversation

@joshuasteier
Contributor

Contributor Information

  • Name: Joshua Steier
  • Contribution Type: Model + Tests + Documentation

Description

Added JambaEHR, a hybrid Transformer-Mamba model for EHR clinical prediction, inspired by Jamba (AI21 Labs, ICLR 2025). It interleaves the existing PyHealth TransformerBlock and MambaBlock layers in a configurable ratio, combining attention's global context modeling with the SSM's linear-time efficiency for long patient histories.

Key features:

  • Reuses existing TransformerBlock and MambaBlock — zero code duplication
  • Configurable num_transformer_layers + num_mamba_layers parameters
  • Evenly distributes attention layers through the SSM stack (Jamba design); see the schedule sketch after this list
  • Degrades gracefully to a pure Transformer or pure Mamba when either layer count is set to 0
  • Compatible with all existing PyHealth processors and Trainer
  • Uses get_last_visit pooling (same as EHRMamba)
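
The even-spacing rule in the bullets above can be illustrated with a small schedule builder. This is only a sketch: build_layer_schedule and the exact placement rule are assumptions for illustration, not necessarily the code in jamba_ehr.py.

def build_layer_schedule(num_transformer_layers, num_mamba_layers):
    """Spread attention layers evenly through the Mamba stack (Jamba design)."""
    total = num_transformer_layers + num_mamba_layers
    if num_transformer_layers == 0:
        return ["mamba"] * total          # pure Mamba fallback
    if num_mamba_layers == 0:
        return ["transformer"] * total    # pure Transformer fallback
    schedule = ["mamba"] * total
    stride = total / num_transformer_layers
    for i in range(num_transformer_layers):
        # place each attention layer at the end of an equal-sized chunk
        schedule[int((i + 1) * stride) - 1] = "transformer"
    return schedule

# 2 Transformer + 6 Mamba layers, matching the Architecture diagram below:
print(build_layer_schedule(2, 6))
# ['mamba', 'mamba', 'mamba', 'transformer', 'mamba', 'mamba', 'mamba', 'transformer']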

Paper: Jamba: A Hybrid Transformer-Mamba Language Model (AI21 Labs, 2024)

This model is part of the multimodal embedding pipeline:

TemporalFeatureProcessor → Modality Encoders → JambaEHR backbone → Prediction Head

Files to Review

File                                           Description
pyhealth/models/jamba_ehr.py                   Model implementation (JambaLayer + JambaEHR)
pyhealth/models/__init__.py                    Added JambaEHR, JambaLayer exports
tests/core/test_jamba_ehr.py                   17 unit tests
docs/api/models/pyhealth.models.JambaEHR.rst   API docs
docs/api/models.rst                            Added to toctree

Testing

# Run model smoke test (tests 4 configurations)
python pyhealth/models/jamba_ehr.py

# Run unit tests

python -m pytest tests/models/test_jamba_ehr.py -v

Result: 17 passed

Architecture

Input: (B, S, E) per feature key
         │
    EmbeddingModel (shared)
         │
    JambaLayer per feature key:
      [Mamba, Mamba, Mamba, Transformer, Mamba, Mamba, Mamba, Transformer]
         │              (configurable schedule)
    get_last_visit pooling → concat → dropout → FC head
         │
Output: {loss, y_prob, y_true, logit}
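
A minimal sketch of the per-feature-key flow above, assuming the standard pyhealth.models.utils.get_last_visit(hidden_states, mask) helper; the shapes and the jamba_layer placeholder are illustrative, not the merged code.

import torch
from pyhealth.models.utils import get_last_visit

B, S, E = 4, 10, 128                       # batch size, visits, embedding dim
x = torch.randn(B, S, E)                   # embedded sequence for one feature key
mask = torch.ones(B, S, dtype=torch.long)  # 1 where a visit is present

hidden = x  # in the model: hidden = jamba_layer(x), shape preserved at (B, S, E)
pooled = get_last_visit(hidden, mask)      # (B, E): last-visit representation
print(pooled.shape)                        # torch.Size([4, 128])

# Per-key pooled vectors are concatenated, passed through dropout,
# and fed to the FC head to produce the logits in the output dict.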

Usage

from pyhealth.models import JambaEHR

model = JambaEHR(
    dataset=dataset,
    embedding_dim=128,
    num_transformer_layers=2,  # attention layers
    num_mamba_layers=6,        # SSM layers
    heads=4,
)
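
One typical way to train the model above, assuming the standard PyHealth get_dataloader and Trainer APIs and pre-split train_ds / val_ds sample datasets (both assumptions here, not part of this PR):

from pyhealth.datasets import get_dataloader
from pyhealth.trainer import Trainer

train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)

trainer = Trainer(model=model)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=5,
    monitor="pr_auc",  # pick a metric appropriate for the task
)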

@jhnwu3 (Collaborator) left a comment

lgtm, that was really fast! We'll definitely need to iterate quickly on the embedding models soon so we can update these files in a multimodal update once we can test these things with the right amount of compute.

@jhnwu3 jhnwu3 merged commit c063e47 into sunlabuiuc:master on Feb 13, 2026
1 check passed