Skip to content

[Model] Add GraphCare — KG-enhanced EHR predictions (ICLR 2024)#830

Open
joshuasteier wants to merge 2 commits intosunlabuiuc:masterfrom
joshuasteier:model/graphcare
Open

[Model] Add GraphCare — KG-enhanced EHR predictions (ICLR 2024)#830
joshuasteier wants to merge 2 commits intosunlabuiuc:masterfrom
joshuasteier:model/graphcare

Conversation

@joshuasteier
Copy link
Contributor

@joshuasteier joshuasteier commented Feb 8, 2026

[Model] Add GraphCare — KG-enhanced EHR predictions (ICLR 2024)

Contributor

  • Name: Josh Steier

Contribution Type

New Model

Description

Implements GraphCare from "GraphCare: Enhancing Healthcare Predictions with Personalized Knowledge Graphs" (Jiang et al., ICLR 2024).

GraphCare constructs personalized patient-level knowledge graphs from EHR codes (conditions, procedures, drugs) and applies a GNN with bi-attention pooling for downstream prediction tasks. This is a faithful port of the original implementation adapted for PyHealth conventions.

Key features:

  • Three GNN backbones: BAT (Bi-Attention GNN, from the paper), GAT, GIN
  • Three patient representation modes: joint, graph, node
  • Visit-level (alpha) and node-level (beta) attention with temporal decay
  • Optional edge attention and edge dropout
  • Attention weight extraction for interpretability (store_attn=True)
  • Support for pre-trained or learned node/relation embeddings

Supported tasks: mortality, readmission, drug recommendation, length-of-stay

Paper: https://openreview.net/forum?id=tVTN7Zs0ml
Original repo: https://github.com/pat-jj/GraphCare

⚠️ Points for Maintainer Discussion

1. torch-geometric dependency

This is the first PyHealth model that requires torch-geometric. It is handled as a lazy optional dependency — imported inside a try/except with a clear error message if missing. No changes to setup.py or requirements.txt are needed, but I wanted to flag this for your review. Happy to adjust the approach.

2. Does not inherit from BaseModel

GraphCare operates on torch_geometric.data.Data objects (per-patient subgraphs) rather than PyHealth's standard SampleDataset collation. This means it cannot use PyHealth's Trainer directly. The model is a standalone nn.Module with its own data pipeline utilities (graphcare_utils.py). This mirrors how the original paper's pipeline works. Open to suggestions on integration.

3. Synthetic data in tests and tutorial

The unit tests and tutorial notebook use synthetic patient graphs rather than real MIMIC-III data. This is intentional — the real KG artifacts require MIMIC access plus running the full LLM-prompted KG construction pipeline from the original repo. The synthetic data validates architecture correctness (shapes, gradients, no NaN across all 9 GNN×mode combinations). Section 9 of the tutorial shows the real-data workflow using graphcare_utils.

Files to Review

File Path Description
Model pyhealth/models/graphcare.py GraphCare + BiAttentionGNNConv (507 lines)
Data utils pyhealth/models/graphcare_utils.py KG loading, subgraph extraction, DataLoader helpers
Unit tests tests/core/test_graphcare.py 28 tests covering all GNN/mode combos, attention flags, edge dropout, embeddings, numerical sanity
Tutorial examples/graphcare_tutorial.ipynb End-to-end: synthetic data → train → evaluate → attention visualization
Docs docs/api/models/pyhealth.models.GraphCare.rst Sphinx autodoc page
Docs index docs/api/models.rst Added toctree entry (between GAMENet and GRASP)
Init pyhealth/models/init.py Added import line

How to Test

# Run unit tests (requires torch-geometric)
pip install torch-geometric
python -m pytest tests/test_graphcare.py -v

Run tutorial notebook

jupyter notebook examples/graphcare_tutorial.ipynb

Checklist

  • PEP8 compliant, 88-char line limit
  • Google-style docstrings with type hints
  • Unit tests with dummy inputs and expected outputs (28 tests)
  • Example notebook with end-to-end usage
  • Sphinx documentation page
  • Rebased with sunlabuiuc/PyHealth main
  • "Allow edits by maintainers" enabled

Steier added 2 commits February 8, 2026 10:48
Tensor is from torch (always available), not torch_geometric.
Having it inside the try block caused NameError at import time
when torch_geometric is not installed, breaking the entire
pyhealth.models import chain in CI.
Copy link
Collaborator

@jhnwu3 jhnwu3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, by any chance, do you think its possible once you get compute access and dataset access to do a MIMIC native version of this? I am really interested in moving PyHealth's sample datasets and dataloaders towards compatibility with PyG such that we can natively do this type of stuff in PyHealth.

Let me think about this more in terms of design. Because, it is a really cool idea to be able to directly construct patient graphs with PyHealth. I think it'll change how easy it is to deploy graph-based models on patient populations.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why this only has "white space only changes" here. Maybe a typo?

@joshuasteier
Copy link
Contributor Author

Hi, @jhnwu3:

Definitely, I'd love to build a MIMIC-native version of this once I have data/compute access. I think the natural path is something like:

  1. This PR as-is (standalone with its own PyG pipeline, works for people who already have KG artifacts from the original repo)
  2. A PyHealthGraphDataset adapter that takes a standard SampleDataset and converts patient visits into PyG Data objects — subgraph extraction, code-to-KG mapping, DataLoader-ready batches. This is probably currently the most reasonable next step.
  3. Longer term, native PyG support in PyHealth's core so any EHR dataset can be projected onto a KG and consumed by graph models (GraphCare, G-BERT, KAME, etc.)

I'd want to take on step 2. The main design question is where KG construction fits — pre-computed artifacts vs. on-the-fly from code mappings — and how that interacts with SampleDataset. Let me know if you'd prefer to hash that out async or on a call.


Reply 2: Whitespace-only changes

Good catch, will fix.

Thank you!

@jhnwu3
Copy link
Collaborator

jhnwu3 commented Feb 9, 2026

re: I think the better approach is probably building a pyhealth.graph module and adding pyhealth graph processors to pyhealth.processors. Let me see if I can't find Patrick's old branch.

https://github.com/sunlabuiuc/PyHealth/tree/kg_embedding/pyhealth/datasets
https://github.com/sunlabuiuc/PyHealth/blob/kg_embedding/pyhealth/datasets/base_kg_dataset.py

Patrick had an old base_kg_dataset that's basically this idea I believe before we had all of these backend changes. Let's talk more later this week since I don't exactly have a perfect idea of how best to define these graph structures or store them within pyhealth. (Maybe we just assume the user has to provide it in some format).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants